1
2
3
4 package yawn.nn.committee;
5
6 import java.lang.reflect.InvocationTargetException;
7 import java.util.ArrayList;
8 import java.util.Iterator;
9 import java.util.List;
10
11 import org.apache.commons.beanutils.BeanUtils;
12 import org.apache.commons.lang.builder.EqualsBuilder;
13
14 import yawn.YawnRuntimeException;
15 import yawn.config.ConfigurationException;
16 import yawn.config.NeuralNetworkConfig;
17 import yawn.nn.NeuralNetwork;
18 import yawn.nn.committee.functions.CommitteeFunction;
19 import yawn.util.InputOutputPattern;
20 import yawn.util.Pattern;
21
22 /***
23 * A committe or set of neural networks to be trained and used concurrently to
24 * provide a consistent learning experiment and predictions.
25 *
26 * <p>$Id: NetworkCommittee.java,v 1.7 2005/05/09 11:04:54 supermarti Exp $</p>
27 *
28 * @author Luis Martí (luis dot marti at uc3m dot es)
29 * @version $Revision: 1.7 $
30 */
31 public class NetworkCommittee extends NeuralNetwork {
32
33 /***
34 *
35 */
36 private static final long serialVersionUID = 3256438123029147696L;
37
38 /***
39 * the function used to combine the results from each committe member
40 *
41 * @uml.property name="combinationFunction"
42 * @uml.associationEnd multiplicity="(0 1)"
43 */
44 protected CommitteeFunction combinationFunction;
45
46 /***
47 *
48 * @uml.property name="committee"
49 * @uml.associationEnd multiplicity="(0 -1)" elementType="yawn.nn.NeuralNetwork"
50 */
51 protected List committee;
52
53 /***
54 *
55 * @uml.property name="serialProcessing"
56 */
57 protected boolean serialProcessing = false;
58
59
60 public NetworkCommittee() {
61 super();
62 committee = new ArrayList();
63 combinationFunction = null;
64 }
65
66 /***
67 *
68 */
69 public NetworkCommittee(CommitteeFunction combinationFunction) {
70 this();
71 this.combinationFunction = combinationFunction;
72 }
73
74 public void addCommitteeMembers(NeuralNetwork member) {
75 committee.add(member);
76 }
77
78 /***
79 *
80 * @uml.property name="committee"
81 */
82 public List getCommitteMembers() {
83 return committee;
84 }
85
86
87 /***
88 * this has no sense in this context, throws an
89 * <code>UnsupportedOperationException</code> by default.
90 */
91 public void oneLearningStep(Pattern input, Pattern output) {
92 throw new UnsupportedOperationException();
93 }
94
95 /***
96 *
97 * @see yawn.nn.NeuralNetwork#predict(yawn.util.Pattern)
98 */
99 public Pattern predict(Pattern input) {
100
101 ArrayList results = new ArrayList();
102
103 for (Iterator i = committee.iterator(); i.hasNext();) {
104 results.add(((NeuralNetwork) i.next()).predict(input));
105 }
106
107 return getCombinationFunction().assamble((Pattern[]) results.toArray(new Pattern[1]));
108 }
109
110 /***
111 *
112 * @see yawn.nn.NeuralNetwork#train(yawn.util.InputOutputPattern[])
113 */
114 public void train(InputOutputPattern[] iop) {
115 ThreadGroup threadGroup = new ThreadGroup("a-group-of:" + this.toString());
116
117 for (Iterator i = committee.iterator(); i.hasNext();) {
118 TrainRunner trainRunner = new TrainRunner((NeuralNetwork) i.next(), InputOutputPattern
119 .randomOrderList(iop));
120
121 Thread aThread = new Thread(threadGroup, trainRunner);
122 aThread.start();
123
124 if (isSerialProcessing()) {
125 while (aThread.isAlive()) {
126 Thread.yield();
127 }
128 }
129 }
130
131 while (threadGroup.activeCount() > 0) {
132 Thread.yield();
133 }
134
135 }
136
137 /***
138 * @return Returns the combinationFunction.
139 *
140 * @uml.property name="combinationFunction"
141 */
142 public CommitteeFunction getCombinationFunction() {
143 return combinationFunction;
144 }
145
146
147 /***
148 * @return Returns the committeSize.
149 */
150 public int getCommitteSize() {
151 return committee.size();
152 }
153
154
155
156
157
158
159 public String getNeuralNetworkName() {
160 return "Network Committe";
161 }
162
163 /***
164 * @param combinationFunction
165 * The combinationFunction to set.
166 *
167 * @uml.property name="combinationFunction"
168 */
169 public void setCombinationFunction(CommitteeFunction combinationFunction) {
170 this.combinationFunction = combinationFunction;
171 }
172
173 /***
174 * Utility class for multi-threaded training of neural networks.
175 *
176 * @author Luis Martí (marti at uh dot cu)
177 */
178
179 public class TrainRunner implements Runnable {
180
181 /***
182 *
183 * @uml.property name="network"
184 * @uml.associationEnd multiplicity="(0 1)"
185 */
186 private NeuralNetwork network;
187
188 /***
189 *
190 * @uml.property name="iops"
191 * @uml.associationEnd multiplicity="(0 -1)"
192 */
193 private InputOutputPattern[] iops;
194
195 /***
196 *
197 */
198 public TrainRunner(NeuralNetwork network, InputOutputPattern[] iops) {
199 this.network = network;
200 this.iops = iops;
201 }
202
203 /***
204 * @see java.lang.Runnable#run()
205 */
206 public void run() {
207 network.train(iops);
208 }
209 }
210
211
212
213
214
215
216 public int getInputSize() {
217
218 return 0;
219 }
220
221
222
223
224
225
226 public int getOutputSize() {
227
228 return 0;
229 }
230
231
232
233
234
235
236 public void setup(NeuralNetworkConfig config) throws ConfigurationException {
237 NetworkCommitteeConfig conf = (NetworkCommitteeConfig) config;
238
239 try {
240 BeanUtils.copyProperties(this, conf);
241 } catch (IllegalAccessException e1) {
242 throw new YawnRuntimeException(e1);
243 } catch (InvocationTargetException e1) {
244 throw new YawnRuntimeException(e1);
245 }
246
247 try {
248 setCombinationFunction((CommitteeFunction) conf.getCombinationFunctionClass()
249 .newInstance());
250 } catch (InstantiationException e) {
251 throw new ConfigurationException(e);
252 } catch (IllegalAccessException e) {
253 throw new ConfigurationException(e);
254 }
255
256 committee.clear();
257
258 for (Iterator i = conf.getCommitteeElements().iterator(); i.hasNext();) {
259 CommitteeElement ce = (CommitteeElement) i.next();
260 NeuralNetworkConfig elementConfig = ce.getNetworkConfig();
261
262 if (elementConfig.getEnvironment() == null) {
263 elementConfig.setEnvironment(conf.getEnvironment());
264 }
265
266 for (int j = 0; j < ce.getAmount(); j++) {
267
268 addCommitteeMembers(ce.getNetworkConfig().configuredNetworkFactory());
269 }
270 }
271 }
272
273 /***
274 *
275 * @see yawn.nn.NeuralNetwork#yieldConfiguration()
276 */
277 public NeuralNetworkConfig yieldConfiguration() {
278 NetworkCommitteeConfig res = new NetworkCommitteeConfig();
279
280 try {
281 BeanUtils.copyProperties(res, this);
282 } catch (IllegalAccessException e1) {
283 throw new YawnRuntimeException(e1);
284 } catch (InvocationTargetException e1) {
285 throw new YawnRuntimeException(e1);
286 }
287
288 res.setCombinationFunctionClassName(getCombinationFunction().getClass());
289
290 for (int i = 0; i < committee.size(); i++) {
291 NeuralNetworkConfig currentMember = ((NeuralNetwork) committee.get(i))
292 .yieldConfiguration();
293
294 for (Iterator iter = res.getCommitteeElements().iterator(); iter.hasNext();) {
295 CommitteeElement cur = (CommitteeElement) iter.next();
296 if (EqualsBuilder.reflectionEquals(cur.getNetworkConfig(), currentMember)) {
297 cur.setAmount(cur.getAmount() + 1);
298 currentMember = null;
299 break;
300 }
301 }
302
303 if (currentMember != null) {
304 CommitteeElement ce = new CommitteeElement();
305 ce.setNetworkConfig(currentMember);
306 ce.setAmount(1);
307 res.addCommitteeMember(ce);
308 }
309 }
310
311 return res;
312 }
313
314 /***
315 * @return Returns the serialProcessing.
316 *
317 * @uml.property name="serialProcessing"
318 */
319 public boolean isSerialProcessing() {
320 return serialProcessing;
321 }
322
323 /***
324 * @param serialProcessing
325 * The serialProcessing to set.
326 *
327 * @uml.property name="serialProcessing"
328 */
329 public void setSerialProcessing(boolean serialProcessing) {
330 this.serialProcessing = serialProcessing;
331 }
332
333 }