View Javadoc

1   /***
2    * 
3    */
4   package yawn.util;
5   
6   import yawn.nn.NeuralNetwork;
7   
8   /***
9    * This is yawn.util.DefaultStatisticsFacility, part of the yawn project.
10   * 
11   * <p>$Id: DefaultStatisticsFacility.java,v 1.4 2005/04/20 18:55:19 supermarti Exp $</p>
12   * 
13   * @author Luis Mart&iacute; (luis dot marti at uc3m dot es)
14   * @version $Revision: 1.4 $
15   */
16  public class DefaultStatisticsFacility implements StatisticsFacility {
17  
18      /***
19       * Calculates the prediction error as the mean squared error
20       * 
21       * @see yawn.util.StatisticsFacility#computePredictionError(yawn.nn.NeuralNetwork,
22       *      yawn.util.InputOutputPattern[])
23       */
24      public double computePredictionError(NeuralNetwork network, InputOutputPattern[] set) {
25          return meanSquaredError(set, getPredictions(network, set));
26      }
27  
28      protected static void checkDimensions(Object[] desired, Object[] predictions) {
29          if (desired.length != predictions.length) {
30              throw new ArrayIndexOutOfBoundsException("Using arrays with different sizes.");
31          }
32      }
33  
34      /***
35       * 
36       * @param set the dataset to be used as test set
37       * @param predictions the values of the predictions made with the above set
38       * @return the mean squared error
39       */
40      public double meanSquaredError(InputOutputPattern[] set, Pattern[] predictions) {
41          double res = 0;
42          checkDimensions(set, predictions);
43          for (int i = 0; i < set.length; i++) {
44              res += set[i].output.dist(predictions[i]) / set[i].output.size();
45          }
46          return res / set.length;
47      }
48  
49      protected Pattern[] getPredictions(NeuralNetwork net, InputOutputPattern[] set) {
50          Pattern[] res = new Pattern[set.length];
51  
52          for (int i = 0; i < set.length; i++) {
53              res[i] = net.predict(set[i].input);
54          }
55          return res;
56      }
57  }