View Javadoc

1   /*
2    * Created on 27-nov-2003
3    *
4    */
5   package yawn.nn.mlp;
6   
7   import java.lang.reflect.InvocationTargetException;
8   import java.util.ArrayList;
9   import java.util.Iterator;
10  import java.util.List;
11  import java.util.ListIterator;
12  
13  import org.apache.commons.beanutils.BeanUtils;
14  import org.apache.commons.logging.Log;
15  import org.apache.commons.logging.LogFactory;
16  
17  import yawn.YawnRuntimeException;
18  import yawn.config.ConfigurationException;
19  import yawn.config.NeuralNetworkConfig;
20  import yawn.nn.NeuralNetwork;
21  import yawn.util.InputOutputPattern;
22  import yawn.util.Pattern;
23  
24  /***
25   * A multi-layer perceptron with backpropagation of errors learning.
26   * 
27   * <p>$Id: MultiLayerPerceptron.java,v 1.10 2005/05/09 11:04:54 supermarti Exp $</p>
28   * 
29   * @author Luis Mart&iacute; (luis dot marti at uc3m dot es)
30   * @version $Revision: 1.10 $
31   */
32  public class MultiLayerPerceptron extends NeuralNetwork {
33  
34      /***
35       * 
36       */
37      private static final long serialVersionUID = 3617013065753900343L;
38  
39      private static final Log log = LogFactory.getLog(MultiLayerPerceptron.class);
40  
41      protected long currentEpoch;
42  
43      protected double errorSum;
44  
45  	/***
46  	 * 
47  	 * @uml.property name="layers"
48  	 * @uml.associationEnd multiplicity="(0 -1)" elementType="yawn.nn.mlp.MLPLayer"
49  	 */
50  	protected List layers;
51  
52  	/***
53  	 * 
54  	 * @uml.property name="learningRate" 
55  	 */
56  	protected double learningRate;
57  
58  	/***
59  	 * 
60  	 * @uml.property name="momentumRate" 
61  	 */
62  	protected double momentumRate;
63  
64  	/***
65  	 * 
66  	 * @uml.property name="maxEpochs" 
67  	 */
68  	protected long maxEpochs;
69  
70  	/***
71  	 * 
72  	 * @uml.property name="predictionError" 
73  	 */
74  	protected double predictionError;
75  
76  
77      /***
78       * 
79       */
80      public MultiLayerPerceptron() {
81          super();
82          layers = new ArrayList();
83      }
84  
85      /***
86       * @param index
87       *            index of insertion
88       * @param element
89       *            the layer to add
90       * @see java.util.List#add(int, Object)
91       */
92      public void addLayer(int index, MLPLayer element) {
93          layers.add(index, element);
94      }
95  
96      /***
97       * @param layer
98       *            The layer to add.
99       */
100     public void addLayer(MLPLayer layer) {
101         layers.add(layer);
102     }
103 
104     public MLPLayer getLayer(int i) {
105         return (MLPLayer) layers.get(i);
106     }
107 
108 	/***
109 	 * @return Returns the layers.
110 	 * 
111 	 * @uml.property name="layers"
112 	 */
113 	public List getLayers() {
114 		return layers;
115 	}
116 
117 	/***
118 	 * @return Returns the learningRate.
119 	 * 
120 	 * @uml.property name="learningRate"
121 	 */
122 	public double getLearningRate() {
123 		return learningRate;
124 	}
125 
126 	/***
127 	 * @return Returns the maxEpochs.
128 	 * 
129 	 * @uml.property name="maxEpochs"
130 	 */
131 	public long getMaxEpochs() {
132 		return maxEpochs;
133 	}
134 
135 
136     /***
137      * @see yawn.nn.NeuralNetwork#getNeuralNetworkName()
138      */
139     public String getNeuralNetworkName() {
140         return "Multi-Layer Perceptron";
141     }
142 
143 	/***
144 	 * @return Returns the predictionError.
145 	 * 
146 	 * @uml.property name="predictionError"
147 	 */
148 	public double getPredictionError() {
149 		return predictionError;
150 	}
151 
152 
153     /***
154      * 
155      */
156     protected void init() {
157         for (int i = layers.size() - 1; i > 0; i--) {
158             ((MLPLayer) layers.get(i)).connectWith((MLPLayer) layers.get(i - 1));
159         }
160     }
161 
162     /***
163      * @see yawn.nn.NeuralNetwork#getInputSize()
164      */
165     public int getInputSize() {
166         return ((MLPLayer) layers.get(0)).getInputSize();
167     }
168 
169     /***
170      * @return the number of layers in the network
171      */
172     public int layersCount() {
173         return layers.size();
174     }
175 
176     /***
177      * @return a java.util.Iterator of the layers list
178      * @see java.util.List#iterator()
179      */
180     public Iterator layersIterator() {
181         return layers.iterator();
182     }
183 
184     /***
185      * @return The ListIterator of the layers list.
186      * @see java.util.List#listIterator()
187      */
188     public ListIterator layersListIterator() {
189         return layers.listIterator();
190     }
191 
192     /***
193      * Implements a learning iteration as:
194      * <ul>
195      * <li>propagates the input storing the activation of all nodes;</li>
196      * <li>sums in <code>errorSum</code> the mean square error of the
197      * prediction;</li>
198      * <li>calculation of the deltas;</li>
199      * <li>backpropagation of errors, and;</li>
200      * <li>weights update.</li>
201      * </ul>
202      * 
203      * @param input
204      *            The input presented to the network.
205      * @param output
206      *            The expected output.
207      * @see yawn.nn.NeuralNetwork#oneLearningStep(yawn.util.Pattern,
208      *      yawn.util.Pattern)
209      */
210     public void oneLearningStep(Pattern input, Pattern output) {
211         Pattern prediction = predict(input);
212 
213         errorSum += prediction.dist(output) / prediction.size();
214 
215         Pattern[] deltas = new Pattern[layers.size()];
216 
217         deltas[deltas.length - 1] = ((MLPLayer) layers.get(layers.size() - 1))
218                 .calculateDeltasAsOutputLayer(output);
219 
220         for (int i = deltas.length - 2; i >= 0; i--) {
221             deltas[i] = ((MLPLayer) layers.get(i)).calculateDeltasAsHiddenLayer(deltas[i + 1]);
222         }
223 
224         for (int i = 0; i < layers.size(); i++) {
225             ((MLPLayer) layers.get(i)).adapt(deltas[i], learningRate, momentumRate);
226         }
227     }
228 
229     /***
230      * @see yawn.nn.NeuralNetwork#getOutputSize()
231      */
232     public int getOutputSize() {
233         return ((MLPLayer) layers.get(layers.size() - 1)).size();
234     }
235 
236     /***
237      * @see yawn.nn.NeuralNetwork#predict(yawn.util.Pattern)
238      */
239     public Pattern predict(Pattern input) {
240         ((MLPLayer) layers.get(0)).setInput(input);
241 
242         for (int i = 0; i < layers.size() - 1; i++) {
243             ((MLPLayer) layers.get(i)).propagateToNextLayer();
244         }
245 
246         return ((MLPLayer) layers.get(layers.size() - 1)).output();
247     }
248 
249     /***
250      * @param index
251      */
252     public void removeLayer(int index) {
253         layers.remove(index);
254     }
255 
256 	/***
257 	 * @param layers
258 	 *            The layers to set.
259 	 * 
260 	 * @uml.property name="layers"
261 	 */
262 	public void setLayers(List layers) {
263 		this.layers = layers;
264 	}
265 
266 	/***
267 	 * @param learningRate
268 	 *            The learningRate to set.
269 	 * 
270 	 * @uml.property name="learningRate"
271 	 */
272 	public void setLearningRate(double learningRate) {
273 		this.learningRate = learningRate;
274 	}
275 
276 	/***
277 	 * @param maxEpochs
278 	 *            The maxEpochs to set.
279 	 * 
280 	 * @uml.property name="maxEpochs"
281 	 */
282 	public void setMaxEpochs(long maxEpochs) {
283 		this.maxEpochs = maxEpochs;
284 	}
285 
286 	/***
287 	 * @param predictionError
288 	 *            The predictionError to set.
289 	 * 
290 	 * @uml.property name="predictionError"
291 	 */
292 	public void setPredictionError(double predictionError) {
293 		this.predictionError = predictionError;
294 	}
295 
296 
297     /***
298      * 
299      * @see yawn.nn.NeuralNetwork#train(yawn.util.InputOutputPattern[])
300      */
301     public void train(InputOutputPattern[] iop) {
302         errorSum = Double.MAX_VALUE;
303         currentEpoch = 0;
304 
305         while ((currentEpoch < maxEpochs) && (errorSum / iop.length > predictionError)) {
306             errorSum = 0;
307             for (int i = 0; i < iop.length; i++) {
308                 oneLearningStep(iop[i].input, iop[i].output);
309             }
310             currentEpoch++;
311             log.debug("Mean error at the end of epoch " + currentEpoch + ": "
312                     + (errorSum / iop.length));
313         }
314 
315     }
316 
317     /***
318      * @see yawn.nn.NeuralNetwork#setup(NeuralNetworkConfig)
319      */
320     public void setup(NeuralNetworkConfig config) throws ConfigurationException {
321         MultiLayerPerceptronConfig c = (MultiLayerPerceptronConfig) config;
322         this.config = c;
323 
324         layers.clear();
325 
326         // setup the network topology
327         LayerElement lc1 = (LayerElement) c.getLayerConfigs().get(0);
328         MLPLayer first;
329         try {
330             first = new MLPLayer(null, config.getEnvironment().inputSize(), lc1.getSize(), Class
331                     .forName(lc1.getNodesClassName()));
332         } catch (ClassNotFoundException e) {
333             throw new ConfigurationException(e);
334         }
335 
336         layers.add(first);
337 
338         for (int i = 1; i < c.getLayerConfigs().size(); i++) {
339             LayerElement lc = (LayerElement) c.getLayerConfigs().get(i);
340 
341             MLPLayer curr;
342             try {
343                 curr = new MLPLayer(null, ((MLPLayer) layers.get(i - 1)).size(), lc.getSize(),
344                         Class.forName(lc.getNodesClassName()));
345             } catch (ClassNotFoundException e1) {
346                 throw new ConfigurationException(e1);
347             }
348 
349             layers.add(curr);
350 
351             MLPLayer back = (MLPLayer) layers.get(i - 1);
352             back.connectWith(curr);
353         }
354 
355         // copy the rest of the properties
356         try {
357             BeanUtils.copyProperties(this, c);
358         } catch (IllegalAccessException e1) {
359             throw new YawnRuntimeException(e1);
360         } catch (InvocationTargetException e1) {
361             throw new YawnRuntimeException(e1);
362         }
363 
364     }
365 
366 	/***
367 	 * 
368 	 * @uml.property name="config"
369 	 * @uml.associationEnd multiplicity="(0 1)"
370 	 */
371 	protected MultiLayerPerceptronConfig config;
372 
373 
374     /***
375      * @see yawn.nn.NeuralNetwork#yieldConfiguration()
376      */
377     public NeuralNetworkConfig yieldConfiguration() {
378         return config;
379         /*
380          * MultiLayerPerceptronConfig conf = new MultiLayerPerceptronConfig();
381          *  // copy the properties
382          * 
383          * try { BeanUtils.copyProperties(conf, this); } catch
384          * (IllegalAccessException e1) { throw new YawnRuntimeException(e1); }
385          * catch (InvocationTargetException e1) { throw new
386          * YawnRuntimeException(e1); }
387          *  // translate the topology to a layers configuration ArrayList res =
388          * new ArrayList();
389          * 
390          * for (int i = 0; i < layers.size(); i++) { MLPLayer cur = getLayer(i);
391          * LayerElement lc = new LayerElement(); lc.setSize(cur.size());
392          * lc.setNodesClassName(cur.getNodes()[0].getClass().getName());
393          * res.add(lc); }
394          * 
395          * conf.setLayerConfigs(res);
396          * 
397          * return conf;
398          */
399     }
400 
401 	/***
402 	 * @return Returns the momentumRate.
403 	 * 
404 	 * @uml.property name="momentumRate"
405 	 */
406 	public double getMomentumRate() {
407 		return this.momentumRate;
408 	}
409 
410 	/***
411 	 * @param momentumRate
412 	 *            The momentumRate to set.
413 	 * 
414 	 * @uml.property name="momentumRate"
415 	 */
416 	public void setMomentumRate(double momentumRate) {
417 		this.momentumRate = momentumRate;
418 	}
419 
420     public MultiLayerPerceptron copy() {
421         MultiLayerPerceptron res = new MultiLayerPerceptron();
422         try {
423             BeanUtils.copyProperties(res, this);
424             return res;
425         } catch (IllegalAccessException e) {
426             if (log.isDebugEnabled()) {
427                 log.debug(e);
428             } else {
429                 log.error(e.getMessage());
430             }
431             throw new YawnRuntimeException(e);
432         } catch (InvocationTargetException e) {
433             if (log.isDebugEnabled()) {
434                 log.debug(e);
435             } else {
436                 log.error(e.getMessage());
437             }
438             throw new YawnRuntimeException(e);
439         }
440 
441     }
442 }