View Javadoc

1   package yawn.nn.appart;
2   
3   import java.lang.reflect.InvocationTargetException;
4   
5   import org.apache.commons.beanutils.BeanUtils;
6   import org.apache.commons.logging.Log;
7   import org.apache.commons.logging.LogFactory;
8   
9   import yawn.YawnRuntimeException;
10  import yawn.config.ConfigurationException;
11  import yawn.config.NeuralNetworkConfig;
12  import yawn.nn.InputLayer;
13  import yawn.nn.NeuralNetwork;
14  import yawn.nn.NeuralNode;
15  import yawn.nn.OutputLayer;
16  import yawn.util.InputOutputPattern;
17  import yawn.util.Pattern;
18  
19  /***
20   * <p>
21   * This class implements the AppART neural network. See:
22   * </p>
23   * <p>
24   * Mart&iacute;, L., Policriti, A. & Garc&iacute;a, L. (2003). Hybrid Adaptive
25   * Resonance Theory Neural Networks for Universal Function Approximation. In
26   * Abraham, A. and Jain, L. (eds.), Innovations in Intelligent Systems and
27   * Applications, Studies in Fuzziness and Soft Computing Series. Heidelberg:
28   * Physica (Springer) Verlag.
29   * </p>
30   * 
31   * @link http://www.dimi.uniud.it/~marti/research.html
32   * 
33   * <p>$Id: AppArt.java,v 1.15 2005/05/09 11:04:55 supermarti Exp $</p>
34   * @author Luis Mart&iacute; (luis dot marti at uc3m dot es)
35   * @version $Revision: 1.15 $
36   */
37  
38  public class AppArt extends NeuralNetwork {
39  
40      /***
41       * 
42       */
43      private static final long serialVersionUID = 3257006553327415353L;
44  
45      private static final Log log = LogFactory.getLog(AppArt.class);
46  
47  	/***
48  	 * 
49  	 * @uml.property name="f2GainControl"
50  	 * @uml.associationEnd multiplicity="(1 1)"
51  	 */
52  	protected GainControlUnitOnMatching f2GainControl;
53  
54  	/***
55  	 * 
56  	 * @uml.property name="f2Layer"
57  	 * @uml.associationEnd multiplicity="(1 1)"
58  	 */
59  	protected RecognitionLayer f2Layer;
60  
61  	/***
62  	 * 
63  	 * @uml.property name="initialDeviations"
64  	 * @uml.associationEnd multiplicity="(0 1)"
65  	 */
66  	protected Pattern initialDeviations;
67  
68  	/***
69  	 * 
70  	 * @uml.property name="inputLayer"
71  	 * @uml.associationEnd multiplicity="(1 1)"
72  	 */
73  	protected InputLayer inputLayer;
74  
75  
76      int inputSize;
77      int outputSize;
78  
79  	/***
80  	 * 
81  	 * @uml.property name="learningRate" 
82  	 */
83  	protected double learningRate;
84  
85  	/***
86  	 * 
87  	 * @uml.property name="matchTrackingOneShot" 
88  	 */
89  	protected boolean matchTrackingOneShot;
90  
91  	/***
92  	 * 
93  	 * @uml.property name="maxEpochs" 
94  	 */
95  	protected long maxEpochs;
96  
97  
98      private double currentEpochErrorSum;
99  
100 	/***
101 	 * 
102 	 * @uml.property name="outputGainControl"
103 	 * @uml.associationEnd multiplicity="(1 1)"
104 	 */
105 	protected GainControlUnitOnOutput outputGainControl;
106 
107 	/***
108 	 * 
109 	 * @uml.property name="outputLayer"
110 	 * @uml.associationEnd multiplicity="(1 1)"
111 	 */
112 	protected OutputLayer outputLayer;
113 
114 	/***
115 	 * 
116 	 * @uml.property name="predictionError" 
117 	 */
118 	protected double predictionError;
119 
120 	/***
121 	 * 
122 	 * @uml.property name="predictionLayer"
123 	 * @uml.associationEnd multiplicity="(1 1)"
124 	 */
125 	protected PredictionLayer predictionLayer;
126 
127 	/***
128 	 * 
129 	 * @uml.property name="predictionLayerLearningRate" 
130 	 */
131 	protected double predictionLayerLearningRate;
132 
133 	/***
134 	 * 
135 	 * @uml.property name="testMatchVigilance" 
136 	 */
137 	protected double testMatchVigilance;
138 
139 	/***
140 	 * 
141 	 * @uml.property name="trainMatchVigilance" 
142 	 */
143 	protected double trainMatchVigilance;
144 
145 	/***
146 	 * 
147 	 * @uml.property name="useAbsoluteError" 
148 	 */
149 	protected boolean useAbsoluteError;
150 
151 
152     public AppArt() {
153         super();
154 
155         outputLayer = new OutputLayer(1);
156         predictionLayer = new PredictionLayer(1, 0, outputLayer);
157 
158         initialDeviations = null;
159 
160         f2Layer = new RecognitionLayer(predictionLayer, null, inputSize);
161         f2GainControl = new GainControlUnitOnMatching(-1, f2Layer);
162 
163         f2Layer.setGF2(f2GainControl);
164 
165         outputGainControl = new GainControlUnitOnOutput(-1);
166 
167         inputLayer = new InputLayer(f2Layer, inputSize);
168 
169         matchTrackingOneShot = false;
170         useAbsoluteError = false;
171     }
172 
173     protected void computeError(Pattern prediction, Pattern expected) {
174         if (useAbsoluteError)
175             outputGainControl.calculateAbsoluteError(prediction, expected);
176         else
177             outputGainControl.calculateRelativeError(prediction, expected);
178     }
179 
180     protected void doBackTrack() {
181         doMatchTracking();
182         f2Layer.reset();
183         predictionLayer.reset();
184         outputLayer.reset();
185     }
186 
187     protected void doMatchTracking() {
188         if (!isMatchTrackingOneShot()) {
189             f2GainControl.minimumActivationMatchTracking();
190         } else {
191             f2GainControl.oneShotMatchTracking();
192         }
193     }
194 
195 	/***
196 	 * @return Returns the f2Layer.
197 	 * 
198 	 * @uml.property name="f2Layer"
199 	 */
200 	public RecognitionLayer getF2Layer() {
201 		return f2Layer;
202 	}
203 
204 	/***
205 	 * @return Returns the initialDeviations.
206 	 * 
207 	 * @uml.property name="initialDeviations"
208 	 */
209 	public Pattern getInitialDeviations() {
210 		return initialDeviations;
211 	}
212 
213 
214     /*
215      * (non-Javadoc)
216      * 
217      * @see yawn.nn.NeuralNetwork#inputSize()
218      */
219     public int getInputSize() {
220         return inputLayer.size();
221     }
222 
223 	/***
224 	 * @return Returns the learningRate.
225 	 * 
226 	 * @uml.property name="learningRate"
227 	 */
228 	public double getLearningRate() {
229 		return learningRate;
230 	}
231 
232 	/***
233 	 * @return Returns the maxEpochs.
234 	 * 
235 	 * @uml.property name="maxEpochs"
236 	 */
237 	public long getMaxEpochs() {
238 		return maxEpochs;
239 	}
240 
241 
242     /***
243      * 
244      * @see yawn.nn.NeuralNetwork#getNeuralNetworkName()
245      */
246     public String getNeuralNetworkName() {
247         return "AppART";
248     }
249 
250     /*
251      * (non-Javadoc)
252      * 
253      * @see yawn.nn.NeuralNetwork#outputSize()
254      */
255     public int getOutputSize() {
256         return outputLayer.size();
257     }
258 
259 	/***
260 	 * @return Returns the predictionError.
261 	 * 
262 	 * @uml.property name="predictionError"
263 	 */
264 	public double getPredictionError() {
265 		return predictionError;
266 	}
267 
268 	/***
269 	 * @return Returns the predictionLayer.
270 	 * 
271 	 * @uml.property name="predictionLayer"
272 	 */
273 	public PredictionLayer getPredictionLayer() {
274 		return predictionLayer;
275 	}
276 
277 	/***
278 	 * @return Returns the predictionLayerLearningRate.
279 	 * 
280 	 * @uml.property name="predictionLayerLearningRate"
281 	 */
282 	public double getPredictionLayerLearningRate() {
283 		return this.predictionLayerLearningRate;
284 	}
285 
286 	/***
287 	 * @return Returns the testMatchVigilance.
288 	 * 
289 	 * @uml.property name="testMatchVigilance"
290 	 */
291 	public double getTestMatchVigilance() {
292 		return testMatchVigilance;
293 	}
294 
295 	/***
296 	 * @return Returns the trainMatchVigilance.
297 	 * 
298 	 * @uml.property name="trainMatchVigilance"
299 	 */
300 	public double getTrainMatchVigilance() {
301 		return trainMatchVigilance;
302 	}
303 
304 
305     /***
306      * 
307      */
308     protected void init() {
309         outputLayer = new OutputLayer(outputSize);
310 
311         predictionLayer = new PredictionLayer(outputSize, predictionLayerLearningRate, outputLayer);
312 
313         // initialDeviations = null;
314 
315         f2Layer = new RecognitionLayer(predictionLayer, null, inputSize);
316         f2GainControl = new GainControlUnitOnMatching(getTrainMatchVigilance(), f2Layer);
317 
318         f2Layer.setGF2(f2GainControl);
319 
320         outputGainControl = new GainControlUnitOnOutput(getPredictionError());
321 
322         inputLayer = new InputLayer(f2Layer, inputSize);
323 
324     }
325 
326 	/***
327 	 * @return Returns the matchTrackingOneShot.
328 	 * 
329 	 * @uml.property name="matchTrackingOneShot"
330 	 */
331 	public boolean isMatchTrackingOneShot() {
332 		return matchTrackingOneShot;
333 	}
334 
335 	/***
336 	 * @return Returns the useAbsoluteError.
337 	 * 
338 	 * @uml.property name="useAbsoluteError"
339 	 */
340 	public boolean isUseAbsoluteError() {
341 		return useAbsoluteError;
342 	}
343 
344 
345     protected void learn(Pattern pat) {
346         // isAdapting() checks are done at each method
347         f2Layer.learn();
348         predictionLayer.learn(pat);
349     }
350 
351     protected Pattern learnNewCategory(Pattern output) {
352         f2Layer.makeNewCategory();
353         int i = f2Layer.size() - 1;
354 
355         inputLayer.propagateToNode(i); // i - 1 is the last one
356 
357         NeuralNode[] nodes = f2Layer.getNodes();
358         // RadialBasisFunctionsNeuralNode[] nodes =
359         // (RadialBasisFunctionsNeuralNode[])buzz;
360 
361         ((RadialBasisFunctionsNeuralNode) (nodes[i])).learnNewClass(initialDeviations, i + 1);
362 
363         // If output == null it behaves properly
364         predictionLayer.updateWeightsStructure(output);
365         f2Layer.propagateToNextLayer();
366         predictionLayer.propagateToNextLayer();
367         Pattern p1 = outputLayer.output();
368         reset();
369         return p1;
370     }
371 
372     protected int numberOfF2Nodes() {
373         return f2Layer.size();
374     }
375 
376     public void oneLearningStep(Pattern input, Pattern output) {
377         setAdapting(true);
378         setInput(input);
379         Pattern prediction = propagate(output);
380 
381         currentEpochErrorSum += prediction.dist(output);
382     }
383 
384     public Pattern predict(Pattern input) {
385         setAdapting(false);
386         f2GainControl.setVigilanceParameter(getTestMatchVigilance());
387         setInput(input);
388         return new Pattern(propagate(null));
389     }
390 
391     /***
392      * Propagates the input already already presented to the network (by calling
393      * setInput())
394      * 
395      * @return the result of propagation
396      */
397     protected Pattern propagate(Pattern output) {
398 
399         Pattern p1;
400 
401         // If f2Layer has no nodes then an uncommitted node should be committed.
402         if (f2Layer.size() == 0) {
403             if (isAdapting()) {
404                 return learnNewCategory(output);
405             }
406             throw new YawnRuntimeException("No F2 nodes on the network");
407         }
408 
409         inputLayer.propagateToNextLayer();
410         f2Layer.matching();
411 
412         // If f2GainControl fires then the current <input, output> must be coded
413         // by committing a node.
414         if (f2GainControl.fires() && isAdapting() && (output != null)) {
415             return learnNewCategory(output);
416         }
417 
418         f2Layer.calculateNormalizedActivations();
419 
420         f2Layer.propagateToNextLayer();
421         predictionLayer.propagateToNextLayer();
422 
423         // If output is null, then we proceed with an unsupervised learning in
424         // F2 and return the prediction
425         if (output == null) {
426             f2Layer.learn();
427             p1 = outputLayer.output();
428             reset();
429             return p1;
430         }
431 
432         p1 = outputLayer.output();
433 
434         computeError(p1, output);
435 
436         // If Go does not fires then the prediction is sufficiently
437         // similar to the expected output.
438         if (!outputGainControl.fires()) {
439             learn(output);
440             reset();
441             return p1;
442         }
443 
444         // Go fired, so we outputGainControl through the match tracking process
445         // and start over with propagation
446         doBackTrack();
447         return propagate(output);
448     }
449 
450     protected void reset() {
451         inputLayer.reset();
452         f2Layer.reset();
453         f2GainControl.setVigilanceParameter(f2GainControl.getBaseVigilanceParameter());
454         predictionLayer.reset();
455         outputLayer.reset();
456     }
457 
458     public void setAdapting(boolean adapt) {
459         super.setAdapting(adapt);
460         f2Layer.setAdapting(adapt);
461         predictionLayer.setAdapting(adapt);
462         return;
463     }
464 
465 	/***
466 	 * @param gf2
467 	 *            the f2GainControl to set.
468 	 * 
469 	 * @uml.property name="f2GainControl"
470 	 */
471 	public void setF2GainControl(GainControlUnitOnMatching gf2) {
472 		this.f2GainControl = gf2;
473 	}
474 
475 	/***
476 	 * @param f2
477 	 *            The f2Layer to set.
478 	 * 
479 	 * @uml.property name="f2Layer"
480 	 */
481 	public void setF2Layer(RecognitionLayer f2) {
482 		this.f2Layer = f2;
483 	}
484 
485 	/***
486 	 * @param initialDeviations
487 	 *            The initialDeviations to set.
488 	 * 
489 	 * @uml.property name="initialDeviations"
490 	 */
491 	public void setInitialDeviations(Pattern initialDeviations) {
492 		this.initialDeviations = initialDeviations;
493 	}
494 
495 
496     public void setInput(Pattern input) {
497         inputLayer.setInput(input);
498     }
499 
500 	/***
501 	 * 
502 	 * @uml.property name="learningRate"
503 	 */
504 	public void setLearningRate(double learningRate) {
505 		this.learningRate = learningRate;
506 	}
507 
508 	/***
509 	 * @param matchTrackingOneShot
510 	 *            The matchTrackingOneShot to set.
511 	 * 
512 	 * @uml.property name="matchTrackingOneShot"
513 	 */
514 	public void setMatchTrackingOneShot(boolean matchTrackingOneShot) {
515 		this.matchTrackingOneShot = matchTrackingOneShot;
516 	}
517 
518 	/***
519 	 * @param maxEpochs
520 	 *            The maxEpochs to set.
521 	 * 
522 	 * @uml.property name="maxEpochs"
523 	 */
524 	public void setMaxEpochs(long maxEpochs) {
525 		this.maxEpochs = maxEpochs;
526 	}
527 
528 	/***
529 	 * @param predictionError
530 	 *            The predictionError to set.
531 	 * 
532 	 * @uml.property name="predictionError"
533 	 */
534 	public void setPredictionError(double predictionError) {
535 		this.predictionError = predictionError;
536 	}
537 
538 	/***
539 	 * @param predictionLayer
540 	 *            the predictionLayer to set
541 	 * 
542 	 * @uml.property name="predictionLayer"
543 	 */
544 	public void setPredictionLayer(PredictionLayer predictionLayer) {
545 		this.predictionLayer = predictionLayer;
546 	}
547 
548 	/***
549 	 * @param predictionLayerLearningRate
550 	 *            The predictionLayerLearningRate to set.
551 	 * 
552 	 * @uml.property name="predictionLayerLearningRate"
553 	 */
554 	public void setPredictionLayerLearningRate(
555 		double predictionLayerLearningRate) {
556 		this.predictionLayerLearningRate = predictionLayerLearningRate;
557 	}
558 
559 	/***
560 	 * @param testMatchVigilance
561 	 *            The testMatchVigilance to set.
562 	 * 
563 	 * @uml.property name="testMatchVigilance"
564 	 */
565 	public void setTestMatchVigilance(double testMatchVigilance) {
566 		this.testMatchVigilance = testMatchVigilance;
567 	}
568 
569 	/***
570 	 * @param trainMatchVigilance
571 	 *            The trainMatchVigilance to set.
572 	 * 
573 	 * @uml.property name="trainMatchVigilance"
574 	 */
575 	public void setTrainMatchVigilance(double trainMatchVigilance) {
576 		this.trainMatchVigilance = trainMatchVigilance;
577 	}
578 
579 
580     /*
581      * (non-Javadoc)
582      * 
583      * @see yawn.nn.NeuralNetwork#setup(yawn.nn.NeuralNetworkConfig)
584      */
585     public void setup(NeuralNetworkConfig config) throws ConfigurationException {
586         try {
587             BeanUtils.copyProperties(this, config);
588 
589             this.inputSize = config.getEnvironment().inputSize();
590             this.outputSize = config.getEnvironment().outputSize();
591             this.initialDeviations = ((AppArtConfig) config).getInitialDeviations();
592         } catch (IllegalAccessException e) {
593             throw new YawnRuntimeException(e);
594         } catch (InvocationTargetException e) {
595             throw new YawnRuntimeException(e);
596         }
597         init();
598     }
599 
600 	/***
601 	 * 
602 	 * @uml.property name="useAbsoluteError"
603 	 */
604 	public void setUseAbsoluteError(boolean error) {
605 		useAbsoluteError = error;
606 	}
607 
608 
609     public void setUseRelativeError(boolean error) {
610         useAbsoluteError = !error;
611     }
612 
613     /***
614      * 
615      */
616     public void train(InputOutputPattern[] iop) {
617         f2GainControl.setVigilanceParameter(getTrainMatchVigilance());
618 
619         for (int epoch = 0; epoch < getMaxEpochs(); epoch++) {
620             currentEpochErrorSum = 0;
621 
622             for (int i = 0; i < iop.length; i++) {
623                 oneLearningStep(iop[i].input, iop[i].output);
624             }
625 
626             double meanError = currentEpochErrorSum / iop.length;
627 
628             log.debug("Epoch: " + epoch + ", mse: " + meanError + ", f2Layer size: "
629                     + f2Layer.size() + ".");
630             if (meanError < desiredMeanSquaredError) {
631                 return;
632             }
633         }
634     }
635 
636 	/***
637 	 * 
638 	 * @uml.property name="desiredMeanSquaredError" 
639 	 */
640 	protected double desiredMeanSquaredError;
641 
642 
643     public void useMinimumActivationMatchTracking(boolean use) {
644         setMatchTrackingOneShot(!use);
645     }
646 
647     public void useOneShotMatchTracking(boolean use) {
648         setMatchTrackingOneShot(use);
649     }
650 
651     /***
652      * (non-Javadoc)
653      * 
654      * @see yawn.nn.NeuralNetwork#yieldConfiguration()
655      */
656     public NeuralNetworkConfig yieldConfiguration() {
657         AppArtConfig conf = new AppArtConfig();
658 
659         try {
660             BeanUtils.copyProperties(conf, this);
661         } catch (IllegalAccessException e) {
662             throw new YawnRuntimeException(e);
663         } catch (InvocationTargetException e) {
664             throw new YawnRuntimeException(e);
665         }
666 
667         return conf;
668     }
669 
670 	/***
671 	 * @return Returns the desiredMeanSquaredError.
672 	 * 
673 	 * @uml.property name="desiredMeanSquaredError"
674 	 */
675 	public double getDesiredMeanSquaredError() {
676 		return desiredMeanSquaredError;
677 	}
678 
679 	/***
680 	 * @param desiredMeanSquaredError
681 	 *            The desiredMeanSquaredError to set.
682 	 * 
683 	 * @uml.property name="desiredMeanSquaredError"
684 	 */
685 	public void setDesiredMeanSquaredError(double desiredMeanSquaredError) {
686 		this.desiredMeanSquaredError = desiredMeanSquaredError;
687 	}
688 
689 }