View Javadoc

1   package yawn.nn.fuzzyartmap;
2   
3   import java.lang.reflect.InvocationTargetException;
4   import java.util.Vector;
5   
6   import org.apache.commons.beanutils.BeanUtils;
7   import org.apache.commons.logging.Log;
8   import org.apache.commons.logging.LogFactory;
9   
10  import yawn.YawnRuntimeException;
11  import yawn.config.ConfigurationException;
12  import yawn.config.NeuralNetworkConfig;
13  import yawn.nn.NeuralNetwork;
14  import yawn.nn.fuzzyartmap.fuzzyart.FuzzyArt;
15  import yawn.optim.OptimizableModel;
16  import yawn.optim.OptimizationAdapter;
17  import yawn.util.InputOutputPattern;
18  import yawn.util.Pattern;
19  
20  /***
21   * Implements the Fuzzy ARTMAP neural network, as described in
22   * <p>
23   * Carpenter, G. A., Grossberg, S., Markuzon, N., Reynolds, J. H. &amp; Rosen,
24   * D. B. (1992). Fuzzy ARTMAP: A neural network architecture for incremental
25   * supervised learning of analog multidimensional maps. <em>IEEE Transactions on
26   * Neural Networks</em>,
27   * 3(5):698-713.
28   * </p>
29   * 
30   * This code is based on the implementations of Relu Patrascu (rpatrasc at cs
31   * dot uwaterloo dot ca) and John Reynold (refined and maintained by Ah-Hwee Tan
32   * (atan at park dot bu dot edu)).
33   * 
34   * <p>$Id: FuzzyArtMap.java,v 1.13 2005/05/09 11:04:56 supermarti Exp $</p>
35   * 
36   * @author Luis Mart&iacute; (luis dot marti at uc3m dot es)
37   * @version $Revision: 1.13 $
38   */
39  
40  public class FuzzyArtMap extends NeuralNetwork implements OptimizableModel {
41  
42      private static final int ARRAY_INCREMENT = 10;
43  
44      private static final Log log = LogFactory.getLog(FuzzyArtMap.class);
45  
46      /***
47       * 
48       */
49      private static final long serialVersionUID = 3258135747491018292L;
50  
51  	/***
52  	 * 
53  	 * @uml.property name="alphaArtA" 
54  	 */
55  	protected double alphaArtA;
56  
57  	/***
58  	 * 
59  	 * @uml.property name="alphaArtB" 
60  	 */
61  	protected double alphaArtB;
62  
63  	/***
64  	 * 
65  	 * @uml.property name="artA"
66  	 * @uml.associationEnd multiplicity="(0 1)"
67  	 */
68  	protected FuzzyArt artA;
69  
70  	/***
71  	 * 
72  	 * @uml.property name="artB"
73  	 * @uml.associationEnd multiplicity="(0 1)"
74  	 */
75  	protected FuzzyArt artB;
76  
77  	/***
78  	 * 
79  	 * @uml.property name="betaArtA" 
80  	 */
81  	protected double betaArtA;
82  
83  	/***
84  	 * 
85  	 * @uml.property name="betaArtB" 
86  	 */
87  	protected double betaArtB;
88  
89  	/***
90  	 * 
91  	 * @uml.property name="currentConfig"
92  	 * @uml.associationEnd multiplicity="(0 1)" inverse="network:yawn.nn.fuzzyartmap.FuzzyArtMapConfig"
93  	 */
94  	private FuzzyArtMapConfig currentConfig;
95  
96  	/***
97  	 * small amount added when doing match tracking
98  	 * 
99  	 * @uml.property name="epsilon" 
100 	 */
101 	protected double epsilon;
102 
103 	/***
104 	 * 
105 	 * @uml.property name="inputSize" 
106 	 */
107 	protected int inputSize;
108 
109 
110     private int maps[];
111 
112 	/***
113 	 * max difference between desired output and prediction to be counted as
114 	 * error
115 	 * 
116 	 * @uml.property name="matchError" 
117 	 */
118 	protected double matchError;
119 
120 	/***
121 	 * max number of passes through the training set
122 	 * 
123 	 * @uml.property name="maxEpochs" 
124 	 */
125 	protected int maxEpochs;
126 
127 
128     private double[] maxInputs;
129 
130     private double[] maxOutputs;
131 
132     private double[] minInputs;
133 
134     private double[] minOutputs;
135 
136 	/***
137 	 * 
138 	 * @uml.property name="numberOfMismatches" 
139 	 */
140 	private int numberOfMismatches;
141 
142 	/***
143 	 * 
144 	 * @uml.property name="outputSize" 
145 	 */
146 	protected int outputSize;
147 
148 	/***
149 	 * 
150 	 * @uml.property name="useComplementCoding" 
151 	 */
152 	protected boolean useComplementCoding;
153 
154 	/***
155 	 * 
156 	 * @uml.property name="vigilanceArtA" 
157 	 */
158 	protected double vigilanceArtA;
159 
160 	/***
161 	 * 
162 	 * @uml.property name="vigilanceArtB" 
163 	 */
164 	protected double vigilanceArtB;
165 
166 
167     public FuzzyArtMap() {
168     }
169 
170     /***
171      * @param alphaArtA
172      * @param alphaArtB
173      * @param betaArtA
174      * @param betaArtB
175      * @param epsilon
176      * @param matchError
177      * @param maxEpochs
178      * @param useComplementCoding
179      * @param vigilanceArtA
180      * @param vigilanceArtB
181      */
182     FuzzyArtMap(double alphaArtA, double alphaArtB, double betaArtA, double betaArtB,
183             double epsilon, double matchError, int maxEpochs, boolean useComplementCoding,
184             double vigilanceArtA, double vigilanceArtB) {
185         super();
186         this.alphaArtA = alphaArtA;
187         this.alphaArtB = alphaArtB;
188         this.betaArtA = betaArtA;
189         this.betaArtB = betaArtB;
190         this.epsilon = epsilon;
191         this.matchError = matchError;
192         this.maxEpochs = maxEpochs;
193         this.useComplementCoding = useComplementCoding;
194         this.vigilanceArtA = vigilanceArtA;
195         this.vigilanceArtB = vigilanceArtB;
196     }
197 
198     /***
199      * Activates ARTb before ARTa, this kind of activation avoids the ``match
200      * tracking anomaly'' reported in
201      * 
202      * <p>
203      * Bartfai, G. (1996)On the Match Tracking Anomaly of the ARTMAP Neural
204      * Network. Neural Networks, 2 (9): 295-308.
205      * </p>
206      * 
207      * @param input
208      *            the ARTa input pattern
209      * @param output
210      *            the ARTb input pattern
211      * @return network prediction
212      */
213     protected Pattern bThenAActivation(Pattern input, Pattern output) {
214 
215         artB.setInputPattern(output.asDoubleArray());
216         artB.activate();
217 
218         if (isAdapting()) {
219             artB.learn();
220             if (artB.hasIncreased()) {
221                 increaseSize();
222                 artB.setIncreasedFlag(false);
223             }
224         }
225 
226         // activate artA
227         artA.vigilance = getVigilanceArtA();
228         artA.setInputPattern(input.asDoubleArray());
229 
230         // for(numberOfMismatches = 0; !artA.currentUncommitted &&
231         // artA.vigilance <= 1.0; numberOfMismatches++)
232 
233         int i = 0;
234         do {
235 
236             artA.activate();
237 
238             if (artA.currentUncommitted) {
239                 // no committed node in ARTa became active
240                 if (isAdapting()) {
241                     artA.learn();
242                     if (artA.hasIncreased) {
243                         increaseSize();
244                         artA.hasIncreased = false;
245                     }
246                 }
247             } else {
248                 // ARTa has a winning node
249 
250                 if (maps[artA.winner] == -1) {
251                     // ARTa winning node does not makes a prediction
252                     // so, we establish the A->B relation
253                     maps[artA.winner] = artB.winner;
254                 } else {
255                     // ARTa winner makes a prediction,
256                     // we must check if it is correct
257                     if (maps[artA.winner] != artB.winner) {
258                         // mismatch!
259                         if (artA.vigilance >= 1) {
260                             // this is a perfect mismatch, further match
261                             // tracking won't help
262                             maps[artA.winner] = artB.winner;
263                             // break;
264                         }
265                         // match tracking
266                         artA.vigilance = Math.min(1, artA.getMatchCritetionValue() + epsilon);
267                     }
268                 }
269             }
270         } while (maps[artA.winner] != artB.winner);
271 
272         if (isAdapting()) {
273             artA.learn();
274             if (artA.hasIncreased) {
275                 increaseSize();
276                 artA.hasIncreased = false;
277             }
278         }
279 
280         artA.vigilance = getVigilanceArtA();
281         double[] res = artB.getWeightVector(maps[artA.winner]);
282         if (artB.complementCoding) {
283             res = artB.unComplementCode(res);
284         }
285 
286         return new Pattern(res);
287     }
288 
289     public Pattern deScaleZeroOneOutput(Pattern pat) {
290         int len = pat.size();
291         if (artB.complementCoding)
292             len = len / 2;
293 
294         Pattern res = new Pattern(len);
295 
296         for (int i = 0; i < len; i++) {
297             res.setComponent(pat.getComponent(i) * (maxOutputs[i] - minOutputs[i]) + minOutputs[i],
298                     i);
299         }
300         return res;
301     }
302 
303     /*
304      * (non-Javadoc)
305      * 
306      * @see yawn.optim.OptimizableModel#getAdapterInstance()
307      */
308     public OptimizationAdapter getAdapterInstance() {
309         return new FuzzyArtMapJGapAdapter((FuzzyArtMapConfig) this.yieldConfiguration());
310     }
311 
312 	/***
313 	 * @return Returns the alphaArtA.
314 	 * 
315 	 * @uml.property name="alphaArtA"
316 	 */
317 	public double getAlphaArtA() {
318 		return this.alphaArtA;
319 	}
320 
321 	/***
322 	 * @return Returns the alphaArtB.
323 	 * 
324 	 * @uml.property name="alphaArtB"
325 	 */
326 	public double getAlphaArtB() {
327 		return this.alphaArtB;
328 	}
329 
330 	/***
331 	 * @return Returns the betaArtA.
332 	 * 
333 	 * @uml.property name="betaArtA"
334 	 */
335 	public double getBetaArtA() {
336 		return this.betaArtA;
337 	}
338 
339 	/***
340 	 * @return Returns the betaArtB.
341 	 * 
342 	 * @uml.property name="betaArtB"
343 	 */
344 	public double getBetaArtB() {
345 		return this.betaArtB;
346 	}
347 
348 	/***
349 	 * @return Returns the epsilon.
350 	 * 
351 	 * @uml.property name="epsilon"
352 	 */
353 	public double getEpsilon() {
354 		return this.epsilon;
355 	}
356 
357 	/***
358 	 * 
359 	 * @uml.property name="inputSize"
360 	 */
361 	/*
362 	 * (non-Javadoc)
363 	 * 
364 	 * @see yawn.nn.NeuralNetwork#inputSize()
365 	 */
366 	public int getInputSize() {
367 		return inputSize;
368 	}
369 
370 	/***
371 	 * @return Returns the matchError.
372 	 * 
373 	 * @uml.property name="matchError"
374 	 */
375 	public double getMatchError() {
376 		return this.matchError;
377 	}
378 
379 	/***
380 	 * @return Returns the maxEpochs.
381 	 * 
382 	 * @uml.property name="maxEpochs"
383 	 */
384 	public int getMaxEpochs() {
385 		return this.maxEpochs;
386 	}
387 
388 
389     /***
390      * 
391      * @see yawn.nn.NeuralNetwork#getNeuralNetworkName()
392      */
393     public String getNeuralNetworkName() {
394         return "Fuzzy ARTMAP";
395     }
396 
397     public int getNumberOfArtACategories() {
398         return artA.numberUsedNodes;
399     }
400 
401     public int getNumberOfArtBCategories() {
402         return artB.numberUsedNodes;
403     }
404 
405 	/***
406 	 * @return Returns the numberOfMismatches.
407 	 * 
408 	 * @uml.property name="numberOfMismatches"
409 	 */
410 	public int getNumberOfMismatches() {
411 		return this.numberOfMismatches;
412 	}
413 
414 	/***
415 	 * 
416 	 * @uml.property name="outputSize"
417 	 */
418 	/*
419 	 * (non-Javadoc)
420 	 * 
421 	 * @see yawn.nn.NeuralNetwork#outputSize()
422 	 */
423 	public int getOutputSize() {
424 		return outputSize;
425 	}
426 
427 	/***
428 	 * @return Returns the vigilanceArtA.
429 	 * 
430 	 * @uml.property name="vigilanceArtA"
431 	 */
432 	public double getVigilanceArtA() {
433 		return this.vigilanceArtA;
434 	}
435 
436 	/***
437 	 * @return Returns the vigilanceArtB.
438 	 * 
439 	 * @uml.property name="vigilanceArtB"
440 	 */
441 	public double getVigilanceArtB() {
442 		return this.vigilanceArtB;
443 	}
444 
445 
446     protected void increaseSize() {
447         int i;
448 
449         int newX[] = new int[maps.length + ARRAY_INCREMENT];
450         System.arraycopy(maps, 0, newX, 0, maps.length);
451         for (i = maps.length; i < newX.length; i++)
452             newX[i] = -1;
453         maps = newX;
454     }
455 
456     protected void init() {
457 
458         artA = new FuzzyArt(alphaArtA, betaArtA, vigilanceArtA, getInputSize(), maxEpochs,
459                 useComplementCoding);
460         artB = new FuzzyArt(alphaArtB, betaArtB, vigilanceArtB, getOutputSize(), maxEpochs,
461                 useComplementCoding);
462 
463         numberOfMismatches = 0;
464         maps = new int[ARRAY_INCREMENT];
465 
466         for (int i = 0; i < ARRAY_INCREMENT; i++)
467             maps[i] = -1;
468 
469     }
470 
471 	/***
472 	 * @return Returns the useComplementCoding.
473 	 * 
474 	 * @uml.property name="useComplementCoding"
475 	 */
476 	public boolean isUseComplementCoding() {
477 		return this.useComplementCoding;
478 	}
479 
480 
481     public void oneLearningStep(Pattern input, Pattern output) {
482         bThenAActivation(input, output);
483     }
484 
485     public Pattern predict(Pattern input) {
486         return deScaleZeroOneOutput(new Pattern(testPattern(scaleZeroOneInput(input)
487                 .asDoubleArray())));
488     }
489 
490     public Vector propagate(double inData[][]) {
491         Vector ret = new Vector();
492         for (int i = 0; i < inData.length; i++) {
493             ret.add(testPattern(inData[i]));
494         }
495         return ret;
496     }
497 
498     public Pattern scaleZeroOneInput(Pattern pat) {
499         Pattern res = new Pattern(pat.size());
500 
501         for (int i = 0; i < pat.size(); i++) {
502             res.setComponent((pat.getComponent(i) - minInputs[i]) / (maxInputs[i] - minInputs[i]),
503                     i);
504         }
505         return res;
506     }
507 
508     protected InputOutputPattern[] scaleZeroOneTrainingDataSet(InputOutputPattern[] iops) {
509         minInputs = new double[iops[0].input.size()];
510         minOutputs = new double[iops[0].output.size()];
511         maxInputs = new double[iops[0].input.size()];
512         maxOutputs = new double[iops[0].output.size()];
513 
514         for (int j = 0; j < minInputs.length; j++) {
515             minInputs[j] = Double.POSITIVE_INFINITY;
516             maxInputs[j] = Double.NEGATIVE_INFINITY;
517         }
518 
519         for (int k = 0; k < minOutputs.length; k++) {
520             minOutputs[k] = Double.POSITIVE_INFINITY;
521             maxOutputs[k] = Double.NEGATIVE_INFINITY;
522         }
523 
524         InputOutputPattern[] res = new InputOutputPattern[iops.length];
525 
526         // det mins and maxs
527         for (int i = 0; i < iops.length; i++) {
528             double[] currentInput = iops[i].input.asDoubleArray();
529             double[] currentOutput = iops[i].output.asDoubleArray();
530 
531             for (int j = 0; j < currentInput.length; j++) {
532                 if (minInputs[j] > currentInput[j]) {
533                     minInputs[j] = currentInput[j];
534                 }
535                 if (maxInputs[j] < currentInput[j]) {
536                     maxInputs[j] = currentInput[j];
537                 }
538             }
539 
540             for (int k = 0; k < currentOutput.length; k++) {
541                 if (minOutputs[k] > currentOutput[k]) {
542                     minOutputs[k] = currentOutput[k];
543                 }
544                 if (maxOutputs[k] < currentOutput[k]) {
545                     maxOutputs[k] = currentOutput[k];
546                 }
547             }
548         }
549 
550         // scale
551         for (int i = 0; i < iops.length; i++) {
552             res[i] = new InputOutputPattern(iops[i].input.size(), iops[i].output.size());
553 
554             double[] currentInput = iops[i].input.asDoubleArray();
555             double[] currentOutput = iops[i].output.asDoubleArray();
556 
557             for (int j = 0; j < currentInput.length; j++) {
558                 currentInput[j] = (currentInput[j] - minInputs[j]) / (maxInputs[j] - minInputs[j]);
559             }
560 
561             res[i].input = new Pattern(currentInput);
562 
563             for (int k = 0; k < currentOutput.length; k++) {
564                 currentOutput[k] = (currentOutput[k] - minOutputs[k])
565                         / (maxOutputs[k] - minOutputs[k]);
566             }
567             res[i].output = new Pattern(currentOutput);
568         }
569         return res;
570     }
571 
572 	/***
573 	 * @param alphaArtA
574 	 *            The alphaArtA to set.
575 	 * 
576 	 * @uml.property name="alphaArtA"
577 	 */
578 	public void setAlphaArtA(double alphaArtA) {
579 		this.alphaArtA = alphaArtA;
580 	}
581 
582 	/***
583 	 * @param alphaArtB
584 	 *            The alphaArtB to set.
585 	 * 
586 	 * @uml.property name="alphaArtB"
587 	 */
588 	public void setAlphaArtB(double alphaArtB) {
589 		this.alphaArtB = alphaArtB;
590 	}
591 
592 
593     public void setBaseArtAVigilance(double value) {
594         artA.vigilance = value;
595     }
596 
597 	/***
598 	 * @param betaArtA
599 	 *            The betaArtA to set.
600 	 * 
601 	 * @uml.property name="betaArtA"
602 	 */
603 	public void setBetaArtA(double betaArtA) {
604 		this.betaArtA = betaArtA;
605 	}
606 
607 	/***
608 	 * @param betaArtB
609 	 *            The betaArtB to set.
610 	 * 
611 	 * @uml.property name="betaArtB"
612 	 */
613 	public void setBetaArtB(double betaArtB) {
614 		this.betaArtB = betaArtB;
615 	}
616 
617 	/***
618 	 * @param epsilon
619 	 *            The epsilon to set.
620 	 * 
621 	 * @uml.property name="epsilon"
622 	 */
623 	public void setEpsilon(double epsilon) {
624 		this.epsilon = epsilon;
625 	}
626 
627 	/***
628 	 * @param matchError
629 	 *            The matchError to set.
630 	 * 
631 	 * @uml.property name="matchError"
632 	 */
633 	public void setMatchError(double matchError) {
634 		this.matchError = matchError;
635 	}
636 
637 	/***
638 	 * @param maxEpochs
639 	 *            The maxEpochs to set.
640 	 * 
641 	 * @uml.property name="maxEpochs"
642 	 */
643 	public void setMaxEpochs(int maxEpochs) {
644 		this.maxEpochs = maxEpochs;
645 	}
646 
647 	/***
648 	 * @param numberOfMismatches
649 	 *            The numberOfMismatches to set.
650 	 * 
651 	 * @uml.property name="numberOfMismatches"
652 	 */
653 	public void setNumberOfMismatches(int numberOfMismatches) {
654 		this.numberOfMismatches = numberOfMismatches;
655 	}
656 
657 
658     /*
659      * (non-Javadoc)
660      * 
661      * @see yawn.nn.NeuralNetwork#setup(yawn.nn.NeuralNetworkConfig)
662      */
663     public void setup(NeuralNetworkConfig config) throws ConfigurationException {
664         try {
665             BeanUtils.copyProperties(this, config);
666         } catch (IllegalAccessException e1) {
667             throw new YawnRuntimeException(e1);
668         } catch (InvocationTargetException e1) {
669             throw new YawnRuntimeException(e1);
670         }
671 
672         inputSize = config.getEnvironment().inputSize();
673         outputSize = config.getEnvironment().outputSize();
674         init();
675         currentConfig = (FuzzyArtMapConfig) config;
676     }
677 
678 	/***
679 	 * @param useComplementCoding
680 	 *            The useComplementCoding to set.
681 	 * 
682 	 * @uml.property name="useComplementCoding"
683 	 */
684 	public void setUseComplementCoding(boolean useComplementCoding) {
685 		this.useComplementCoding = useComplementCoding;
686 	}
687 
688 	/***
689 	 * @param vigilanceArtA
690 	 *            The vigilanceArtA to set.
691 	 * 
692 	 * @uml.property name="vigilanceArtA"
693 	 */
694 	public void setVigilanceArtA(double vigilanceArtA) {
695 		this.vigilanceArtA = vigilanceArtA;
696 	}
697 
698 	/***
699 	 * @param vigilanceArtB
700 	 *            The vigilanceArtB to set.
701 	 * 
702 	 * @uml.property name="vigilanceArtB"
703 	 */
704 	public void setVigilanceArtB(double vigilanceArtB) {
705 		this.vigilanceArtB = vigilanceArtB;
706 	}
707 
708     public void setVigilanceB(double value) {
709         artB.vigilance = value;
710     }
711 
712     public double[] testPattern(double inPattern[]) {
713         boolean adapt = isAdapting();
714 
715         setAdapting(false);
716 
717         double oldVigilance = artA.vigilance;
718         double result[] = null;
719 
720         artA.vigilance = 1.0;
721         artA.setInputPattern(inPattern);
722         do {
723             artA.vigilance -= 0.001;
724             artA.activate();
725             if (!artA.currentUncommitted)
726                 result = artB.getWeightVector(maps[artA.winner]);
727         } while (artA.vigilance >= 0 && artA.currentUncommitted && result == null);
728 
729         if (result == null) {
730             result = new double[artB.getNumberOfInputs()];
731             for (int i = 0; i < result.length; i++) {
732                 result[i] = Double.NaN;
733             }
734         }
735 
736         // restore things
737         artA.vigilance = oldVigilance;
738         setAdapting(adapt);
739 
740         return result;
741     }
742 
743     public void train(InputOutputPattern[] ori) {
744         init();
745         boolean done = false;
746         boolean adapt = isAdapting();
747 
748         setAdapting(true);
749 
750         InputOutputPattern[] iop = scaleZeroOneTrainingDataSet(ori);
751 
752         for (int i = 0; (i < maxEpochs) && !done; i++) {
753             // System.out.print("\n);
754             double epochError = 0;
755             numberOfMismatches = 0;
756 
757             for (int j = 0; j < iop.length; j++) {
758                 Pattern res = bThenAActivation(iop[j].input, iop[j].output);
759                 if (res.dist(iop[j].output) > matchError) {
760                     numberOfMismatches++;
761                 }
762                 epochError += res.dist(iop[j].output);
763             }
764             // trainPattern(inPatterns[j], outPatterns[j]);
765             done = !(numberOfMismatches > 0);
766             log.debug("Epoch: " + i + "; mismatches: " + numberOfMismatches
767                     + "; input (artA) cats: " + artA.numberUsedNodes + "; output (artB) cats:"
768                     + artB.numberUsedNodes + "; mean error: " + epochError / iop.length + ".");
769         }
770 
771         setAdapting(adapt);
772     }
773 
774     /*
775      * (non-Javadoc)
776      * 
777      * @see yawn.nn.NeuralNetwork#yieldConfiguration()
778      */
779     public NeuralNetworkConfig yieldConfiguration() {
780         FuzzyArtMapConfig conf = new FuzzyArtMapConfig();
781 
782         try {
783             if (currentConfig != null) {
784                 BeanUtils.copyProperties(conf, currentConfig);
785             }
786             BeanUtils.copyProperties(conf, this);
787         } catch (IllegalAccessException e1) {
788             throw new YawnRuntimeException(e1);
789         } catch (InvocationTargetException e1) {
790             throw new YawnRuntimeException(e1);
791         }
792 
793         return conf;
794     }
795 }