1 package yawn.nn.appart;
2
3 import java.lang.reflect.InvocationTargetException;
4
5 import org.apache.commons.beanutils.BeanUtils;
6 import org.apache.commons.logging.Log;
7 import org.apache.commons.logging.LogFactory;
8
9 import yawn.YawnRuntimeException;
10 import yawn.config.ConfigurationException;
11 import yawn.config.NeuralNetworkConfig;
12 import yawn.nn.InputLayer;
13 import yawn.nn.NeuralNetwork;
14 import yawn.nn.NeuralNode;
15 import yawn.nn.OutputLayer;
16 import yawn.util.InputOutputPattern;
17 import yawn.util.Pattern;
18
19 /***
20 * <p>
21 * This class implements the AppART neural network. See:
22 * </p>
23 * <p>
24 * Martí, L., Policriti, A. & García, L. (2003). Hybrid Adaptive
25 * Resonance Theory Neural Networks for Universal Function Approximation. In
26 * Abraham, A. and Jain, L. (eds.), Innovations in Intelligent Systems and
27 * Applications, Studies in Fuzziness and Soft Computing Series. Heidelberg:
28 * Physica (Springer) Verlag.
29 * </p>
30 *
31 * @link http://www.dimi.uniud.it/~marti/research.html
32 *
33 * <p>$Id: AppArt.java,v 1.15 2005/05/09 11:04:55 supermarti Exp $</p>
34 * @author Luis Martí (luis dot marti at uc3m dot es)
35 * @version $Revision: 1.15 $
36 */
37
38 public class AppArt extends NeuralNetwork {
39
40 /***
41 *
42 */
43 private static final long serialVersionUID = 3257006553327415353L;
44
45 private static final Log log = LogFactory.getLog(AppArt.class);
46
47 /***
48 *
49 * @uml.property name="f2GainControl"
50 * @uml.associationEnd multiplicity="(1 1)"
51 */
52 protected GainControlUnitOnMatching f2GainControl;
53
54 /***
55 *
56 * @uml.property name="f2Layer"
57 * @uml.associationEnd multiplicity="(1 1)"
58 */
59 protected RecognitionLayer f2Layer;
60
61 /***
62 *
63 * @uml.property name="initialDeviations"
64 * @uml.associationEnd multiplicity="(0 1)"
65 */
66 protected Pattern initialDeviations;
67
68 /***
69 *
70 * @uml.property name="inputLayer"
71 * @uml.associationEnd multiplicity="(1 1)"
72 */
73 protected InputLayer inputLayer;
74
75
76 int inputSize;
77 int outputSize;
78
79 /***
80 *
81 * @uml.property name="learningRate"
82 */
83 protected double learningRate;
84
85 /***
86 *
87 * @uml.property name="matchTrackingOneShot"
88 */
89 protected boolean matchTrackingOneShot;
90
91 /***
92 *
93 * @uml.property name="maxEpochs"
94 */
95 protected long maxEpochs;
96
97
98 private double currentEpochErrorSum;
99
100 /***
101 *
102 * @uml.property name="outputGainControl"
103 * @uml.associationEnd multiplicity="(1 1)"
104 */
105 protected GainControlUnitOnOutput outputGainControl;
106
107 /***
108 *
109 * @uml.property name="outputLayer"
110 * @uml.associationEnd multiplicity="(1 1)"
111 */
112 protected OutputLayer outputLayer;
113
114 /***
115 *
116 * @uml.property name="predictionError"
117 */
118 protected double predictionError;
119
120 /***
121 *
122 * @uml.property name="predictionLayer"
123 * @uml.associationEnd multiplicity="(1 1)"
124 */
125 protected PredictionLayer predictionLayer;
126
127 /***
128 *
129 * @uml.property name="predictionLayerLearningRate"
130 */
131 protected double predictionLayerLearningRate;
132
133 /***
134 *
135 * @uml.property name="testMatchVigilance"
136 */
137 protected double testMatchVigilance;
138
139 /***
140 *
141 * @uml.property name="trainMatchVigilance"
142 */
143 protected double trainMatchVigilance;
144
145 /***
146 *
147 * @uml.property name="useAbsoluteError"
148 */
149 protected boolean useAbsoluteError;
150
151
152 public AppArt() {
153 super();
154
155 outputLayer = new OutputLayer(1);
156 predictionLayer = new PredictionLayer(1, 0, outputLayer);
157
158 initialDeviations = null;
159
160 f2Layer = new RecognitionLayer(predictionLayer, null, inputSize);
161 f2GainControl = new GainControlUnitOnMatching(-1, f2Layer);
162
163 f2Layer.setGF2(f2GainControl);
164
165 outputGainControl = new GainControlUnitOnOutput(-1);
166
167 inputLayer = new InputLayer(f2Layer, inputSize);
168
169 matchTrackingOneShot = false;
170 useAbsoluteError = false;
171 }
172
173 protected void computeError(Pattern prediction, Pattern expected) {
174 if (useAbsoluteError)
175 outputGainControl.calculateAbsoluteError(prediction, expected);
176 else
177 outputGainControl.calculateRelativeError(prediction, expected);
178 }
179
180 protected void doBackTrack() {
181 doMatchTracking();
182 f2Layer.reset();
183 predictionLayer.reset();
184 outputLayer.reset();
185 }
186
187 protected void doMatchTracking() {
188 if (!isMatchTrackingOneShot()) {
189 f2GainControl.minimumActivationMatchTracking();
190 } else {
191 f2GainControl.oneShotMatchTracking();
192 }
193 }
194
195 /***
196 * @return Returns the f2Layer.
197 *
198 * @uml.property name="f2Layer"
199 */
200 public RecognitionLayer getF2Layer() {
201 return f2Layer;
202 }
203
204 /***
205 * @return Returns the initialDeviations.
206 *
207 * @uml.property name="initialDeviations"
208 */
209 public Pattern getInitialDeviations() {
210 return initialDeviations;
211 }
212
213
214
215
216
217
218
219 public int getInputSize() {
220 return inputLayer.size();
221 }
222
223 /***
224 * @return Returns the learningRate.
225 *
226 * @uml.property name="learningRate"
227 */
228 public double getLearningRate() {
229 return learningRate;
230 }
231
232 /***
233 * @return Returns the maxEpochs.
234 *
235 * @uml.property name="maxEpochs"
236 */
237 public long getMaxEpochs() {
238 return maxEpochs;
239 }
240
241
242 /***
243 *
244 * @see yawn.nn.NeuralNetwork#getNeuralNetworkName()
245 */
246 public String getNeuralNetworkName() {
247 return "AppART";
248 }
249
250
251
252
253
254
255 public int getOutputSize() {
256 return outputLayer.size();
257 }
258
259 /***
260 * @return Returns the predictionError.
261 *
262 * @uml.property name="predictionError"
263 */
264 public double getPredictionError() {
265 return predictionError;
266 }
267
268 /***
269 * @return Returns the predictionLayer.
270 *
271 * @uml.property name="predictionLayer"
272 */
273 public PredictionLayer getPredictionLayer() {
274 return predictionLayer;
275 }
276
277 /***
278 * @return Returns the predictionLayerLearningRate.
279 *
280 * @uml.property name="predictionLayerLearningRate"
281 */
282 public double getPredictionLayerLearningRate() {
283 return this.predictionLayerLearningRate;
284 }
285
286 /***
287 * @return Returns the testMatchVigilance.
288 *
289 * @uml.property name="testMatchVigilance"
290 */
291 public double getTestMatchVigilance() {
292 return testMatchVigilance;
293 }
294
295 /***
296 * @return Returns the trainMatchVigilance.
297 *
298 * @uml.property name="trainMatchVigilance"
299 */
300 public double getTrainMatchVigilance() {
301 return trainMatchVigilance;
302 }
303
304
305 /***
306 *
307 */
308 protected void init() {
309 outputLayer = new OutputLayer(outputSize);
310
311 predictionLayer = new PredictionLayer(outputSize, predictionLayerLearningRate, outputLayer);
312
313
314
315 f2Layer = new RecognitionLayer(predictionLayer, null, inputSize);
316 f2GainControl = new GainControlUnitOnMatching(getTrainMatchVigilance(), f2Layer);
317
318 f2Layer.setGF2(f2GainControl);
319
320 outputGainControl = new GainControlUnitOnOutput(getPredictionError());
321
322 inputLayer = new InputLayer(f2Layer, inputSize);
323
324 }
325
326 /***
327 * @return Returns the matchTrackingOneShot.
328 *
329 * @uml.property name="matchTrackingOneShot"
330 */
331 public boolean isMatchTrackingOneShot() {
332 return matchTrackingOneShot;
333 }
334
335 /***
336 * @return Returns the useAbsoluteError.
337 *
338 * @uml.property name="useAbsoluteError"
339 */
340 public boolean isUseAbsoluteError() {
341 return useAbsoluteError;
342 }
343
344
345 protected void learn(Pattern pat) {
346
347 f2Layer.learn();
348 predictionLayer.learn(pat);
349 }
350
351 protected Pattern learnNewCategory(Pattern output) {
352 f2Layer.makeNewCategory();
353 int i = f2Layer.size() - 1;
354
355 inputLayer.propagateToNode(i);
356
357 NeuralNode[] nodes = f2Layer.getNodes();
358
359
360
361 ((RadialBasisFunctionsNeuralNode) (nodes[i])).learnNewClass(initialDeviations, i + 1);
362
363
364 predictionLayer.updateWeightsStructure(output);
365 f2Layer.propagateToNextLayer();
366 predictionLayer.propagateToNextLayer();
367 Pattern p1 = outputLayer.output();
368 reset();
369 return p1;
370 }
371
372 protected int numberOfF2Nodes() {
373 return f2Layer.size();
374 }
375
376 public void oneLearningStep(Pattern input, Pattern output) {
377 setAdapting(true);
378 setInput(input);
379 Pattern prediction = propagate(output);
380
381 currentEpochErrorSum += prediction.dist(output);
382 }
383
384 public Pattern predict(Pattern input) {
385 setAdapting(false);
386 f2GainControl.setVigilanceParameter(getTestMatchVigilance());
387 setInput(input);
388 return new Pattern(propagate(null));
389 }
390
391 /***
392 * Propagates the input already already presented to the network (by calling
393 * setInput())
394 *
395 * @return the result of propagation
396 */
397 protected Pattern propagate(Pattern output) {
398
399 Pattern p1;
400
401
402 if (f2Layer.size() == 0) {
403 if (isAdapting()) {
404 return learnNewCategory(output);
405 }
406 throw new YawnRuntimeException("No F2 nodes on the network");
407 }
408
409 inputLayer.propagateToNextLayer();
410 f2Layer.matching();
411
412
413
414 if (f2GainControl.fires() && isAdapting() && (output != null)) {
415 return learnNewCategory(output);
416 }
417
418 f2Layer.calculateNormalizedActivations();
419
420 f2Layer.propagateToNextLayer();
421 predictionLayer.propagateToNextLayer();
422
423
424
425 if (output == null) {
426 f2Layer.learn();
427 p1 = outputLayer.output();
428 reset();
429 return p1;
430 }
431
432 p1 = outputLayer.output();
433
434 computeError(p1, output);
435
436
437
438 if (!outputGainControl.fires()) {
439 learn(output);
440 reset();
441 return p1;
442 }
443
444
445
446 doBackTrack();
447 return propagate(output);
448 }
449
450 protected void reset() {
451 inputLayer.reset();
452 f2Layer.reset();
453 f2GainControl.setVigilanceParameter(f2GainControl.getBaseVigilanceParameter());
454 predictionLayer.reset();
455 outputLayer.reset();
456 }
457
458 public void setAdapting(boolean adapt) {
459 super.setAdapting(adapt);
460 f2Layer.setAdapting(adapt);
461 predictionLayer.setAdapting(adapt);
462 return;
463 }
464
465 /***
466 * @param gf2
467 * the f2GainControl to set.
468 *
469 * @uml.property name="f2GainControl"
470 */
471 public void setF2GainControl(GainControlUnitOnMatching gf2) {
472 this.f2GainControl = gf2;
473 }
474
475 /***
476 * @param f2
477 * The f2Layer to set.
478 *
479 * @uml.property name="f2Layer"
480 */
481 public void setF2Layer(RecognitionLayer f2) {
482 this.f2Layer = f2;
483 }
484
485 /***
486 * @param initialDeviations
487 * The initialDeviations to set.
488 *
489 * @uml.property name="initialDeviations"
490 */
491 public void setInitialDeviations(Pattern initialDeviations) {
492 this.initialDeviations = initialDeviations;
493 }
494
495
496 public void setInput(Pattern input) {
497 inputLayer.setInput(input);
498 }
499
500 /***
501 *
502 * @uml.property name="learningRate"
503 */
504 public void setLearningRate(double learningRate) {
505 this.learningRate = learningRate;
506 }
507
508 /***
509 * @param matchTrackingOneShot
510 * The matchTrackingOneShot to set.
511 *
512 * @uml.property name="matchTrackingOneShot"
513 */
514 public void setMatchTrackingOneShot(boolean matchTrackingOneShot) {
515 this.matchTrackingOneShot = matchTrackingOneShot;
516 }
517
518 /***
519 * @param maxEpochs
520 * The maxEpochs to set.
521 *
522 * @uml.property name="maxEpochs"
523 */
524 public void setMaxEpochs(long maxEpochs) {
525 this.maxEpochs = maxEpochs;
526 }
527
528 /***
529 * @param predictionError
530 * The predictionError to set.
531 *
532 * @uml.property name="predictionError"
533 */
534 public void setPredictionError(double predictionError) {
535 this.predictionError = predictionError;
536 }
537
538 /***
539 * @param predictionLayer
540 * the predictionLayer to set
541 *
542 * @uml.property name="predictionLayer"
543 */
544 public void setPredictionLayer(PredictionLayer predictionLayer) {
545 this.predictionLayer = predictionLayer;
546 }
547
548 /***
549 * @param predictionLayerLearningRate
550 * The predictionLayerLearningRate to set.
551 *
552 * @uml.property name="predictionLayerLearningRate"
553 */
554 public void setPredictionLayerLearningRate(
555 double predictionLayerLearningRate) {
556 this.predictionLayerLearningRate = predictionLayerLearningRate;
557 }
558
559 /***
560 * @param testMatchVigilance
561 * The testMatchVigilance to set.
562 *
563 * @uml.property name="testMatchVigilance"
564 */
565 public void setTestMatchVigilance(double testMatchVigilance) {
566 this.testMatchVigilance = testMatchVigilance;
567 }
568
569 /***
570 * @param trainMatchVigilance
571 * The trainMatchVigilance to set.
572 *
573 * @uml.property name="trainMatchVigilance"
574 */
575 public void setTrainMatchVigilance(double trainMatchVigilance) {
576 this.trainMatchVigilance = trainMatchVigilance;
577 }
578
579
580
581
582
583
584
585 public void setup(NeuralNetworkConfig config) throws ConfigurationException {
586 try {
587 BeanUtils.copyProperties(this, config);
588
589 this.inputSize = config.getEnvironment().inputSize();
590 this.outputSize = config.getEnvironment().outputSize();
591 this.initialDeviations = ((AppArtConfig) config).getInitialDeviations();
592 } catch (IllegalAccessException e) {
593 throw new YawnRuntimeException(e);
594 } catch (InvocationTargetException e) {
595 throw new YawnRuntimeException(e);
596 }
597 init();
598 }
599
600 /***
601 *
602 * @uml.property name="useAbsoluteError"
603 */
604 public void setUseAbsoluteError(boolean error) {
605 useAbsoluteError = error;
606 }
607
608
609 public void setUseRelativeError(boolean error) {
610 useAbsoluteError = !error;
611 }
612
613 /***
614 *
615 */
616 public void train(InputOutputPattern[] iop) {
617 f2GainControl.setVigilanceParameter(getTrainMatchVigilance());
618
619 for (int epoch = 0; epoch < getMaxEpochs(); epoch++) {
620 currentEpochErrorSum = 0;
621
622 for (int i = 0; i < iop.length; i++) {
623 oneLearningStep(iop[i].input, iop[i].output);
624 }
625
626 double meanError = currentEpochErrorSum / iop.length;
627
628 log.debug("Epoch: " + epoch + ", mse: " + meanError + ", f2Layer size: "
629 + f2Layer.size() + ".");
630 if (meanError < desiredMeanSquaredError) {
631 return;
632 }
633 }
634 }
635
636 /***
637 *
638 * @uml.property name="desiredMeanSquaredError"
639 */
640 protected double desiredMeanSquaredError;
641
642
643 public void useMinimumActivationMatchTracking(boolean use) {
644 setMatchTrackingOneShot(!use);
645 }
646
647 public void useOneShotMatchTracking(boolean use) {
648 setMatchTrackingOneShot(use);
649 }
650
651 /***
652 * (non-Javadoc)
653 *
654 * @see yawn.nn.NeuralNetwork#yieldConfiguration()
655 */
656 public NeuralNetworkConfig yieldConfiguration() {
657 AppArtConfig conf = new AppArtConfig();
658
659 try {
660 BeanUtils.copyProperties(conf, this);
661 } catch (IllegalAccessException e) {
662 throw new YawnRuntimeException(e);
663 } catch (InvocationTargetException e) {
664 throw new YawnRuntimeException(e);
665 }
666
667 return conf;
668 }
669
670 /***
671 * @return Returns the desiredMeanSquaredError.
672 *
673 * @uml.property name="desiredMeanSquaredError"
674 */
675 public double getDesiredMeanSquaredError() {
676 return desiredMeanSquaredError;
677 }
678
679 /***
680 * @param desiredMeanSquaredError
681 * The desiredMeanSquaredError to set.
682 *
683 * @uml.property name="desiredMeanSquaredError"
684 */
685 public void setDesiredMeanSquaredError(double desiredMeanSquaredError) {
686 this.desiredMeanSquaredError = desiredMeanSquaredError;
687 }
688
689 }