1   /*
2    * AbstractCircleInTheSquareNetworkTest.java
3    * Part of the yawn project
4    * Created on 01-dic-2004 by marti.
5    *
6    */
7   package yawn.nn;
8   
9   import org.apache.commons.logging.Log;
10  import org.apache.commons.logging.LogFactory;
11  
12  import yawn.envs.EnvironmentException;
13  import yawn.envs.synthetic.CircleInTheSquareEnvironment;
14  import yawn.envs.synthetic.SyntheticDataEnvironment;
15  
16  /***
17   * An abstract class that implements the common functionality a neural networks.
18   * 
19   * <p>$Id: AbstractCircleInTheSquareNetworkTest.java,v 1.9 2005/04/20 18:55:12 supermarti Exp $</p>
20   * 
21   * @author Luis Mart&iacute; (luis dot marti at uc3m dot es)
22   * @version $Revision: 1.9 $
23   */
24  public abstract class AbstractCircleInTheSquareNetworkTest extends AbstractNeuralNetworkTest {
25  
26      private static final Log log = LogFactory.getLog(AbstractCircleInTheSquareNetworkTest.class);
27  
28      public static final double TARGETED_PREDICTION_MSE = 0.08;
29  
30      public static final int TRAIN_SET_SIZE = 1000;
31  
32      public static final int TEST_SET_SIZE = 200;
33  
34      protected SyntheticDataEnvironment createEnvironment() {
35          CircleInTheSquareEnvironment env = new CircleInTheSquareEnvironment();
36          env.setTrainSetSize(TRAIN_SET_SIZE);
37          env.setTestSetSize(TEST_SET_SIZE);
38          env.setNumberOfSystemRuns(1);
39          return env;
40      }
41  
42      public void testTrain() {
43          NeuralNetwork net = conf.configuredNetworkFactory();
44          net.train(trainSet);
45          
46          double error = Double.NaN;
47          try {
48              error = net.getStatisticsFacility().computePredictionError(net, env.getTestDataset(0));
49          } catch (EnvironmentException e) {
50              fail();
51          }
52  
53          log.info("Test set MSE: " + error);
54          assertTrue("Prediction error `" + error + "´ should be less than the required mark `"
55                  + TARGETED_PREDICTION_MSE + "´.", error < TARGETED_PREDICTION_MSE);
56      }
57  }