1 /***
2 *
3 */
4 package yawn.util;
5
6 import yawn.nn.NeuralNetwork;
7
8 /***
9 * This is yawn.util.DefaultStatisticsFacility, part of the yawn project.
10 *
11 * <p>$Id: DefaultStatisticsFacility.java,v 1.4 2005/04/20 18:55:19 supermarti Exp $</p>
12 *
13 * @author Luis Martí (luis dot marti at uc3m dot es)
14 * @version $Revision: 1.4 $
15 */
16 public class DefaultStatisticsFacility implements StatisticsFacility {
17
18 /***
19 * Calculates the prediction error as the mean squared error
20 *
21 * @see yawn.util.StatisticsFacility#computePredictionError(yawn.nn.NeuralNetwork,
22 * yawn.util.InputOutputPattern[])
23 */
24 public double computePredictionError(NeuralNetwork network, InputOutputPattern[] set) {
25 return meanSquaredError(set, getPredictions(network, set));
26 }
27
28 protected static void checkDimensions(Object[] desired, Object[] predictions) {
29 if (desired.length != predictions.length) {
30 throw new ArrayIndexOutOfBoundsException("Using arrays with different sizes.");
31 }
32 }
33
34 /***
35 *
36 * @param set the dataset to be used as test set
37 * @param predictions the values of the predictions made with the above set
38 * @return the mean squared error
39 */
40 public double meanSquaredError(InputOutputPattern[] set, Pattern[] predictions) {
41 double res = 0;
42 checkDimensions(set, predictions);
43 for (int i = 0; i < set.length; i++) {
44 res += set[i].output.dist(predictions[i]) / set[i].output.size();
45 }
46 return res / set.length;
47 }
48
49 protected Pattern[] getPredictions(NeuralNetwork net, InputOutputPattern[] set) {
50 Pattern[] res = new Pattern[set.length];
51
52 for (int i = 0; i < set.length; i++) {
53 res[i] = net.predict(set[i].input);
54 }
55 return res;
56 }
57 }