1
2
3
4
5
6
7 package yawn.nn;
8
9 import org.apache.commons.logging.Log;
10 import org.apache.commons.logging.LogFactory;
11
12 import yawn.envs.EnvironmentException;
13 import yawn.envs.synthetic.CircleInTheSquareEnvironment;
14 import yawn.envs.synthetic.SyntheticDataEnvironment;
15
16 /***
17 * An abstract class that implements the common functionality a neural networks.
18 *
19 * <p>$Id: AbstractCircleInTheSquareNetworkTest.java,v 1.9 2005/04/20 18:55:12 supermarti Exp $</p>
20 *
21 * @author Luis Martí (luis dot marti at uc3m dot es)
22 * @version $Revision: 1.9 $
23 */
24 public abstract class AbstractCircleInTheSquareNetworkTest extends AbstractNeuralNetworkTest {
25
26 private static final Log log = LogFactory.getLog(AbstractCircleInTheSquareNetworkTest.class);
27
28 public static final double TARGETED_PREDICTION_MSE = 0.08;
29
30 public static final int TRAIN_SET_SIZE = 1000;
31
32 public static final int TEST_SET_SIZE = 200;
33
34 protected SyntheticDataEnvironment createEnvironment() {
35 CircleInTheSquareEnvironment env = new CircleInTheSquareEnvironment();
36 env.setTrainSetSize(TRAIN_SET_SIZE);
37 env.setTestSetSize(TEST_SET_SIZE);
38 env.setNumberOfSystemRuns(1);
39 return env;
40 }
41
42 public void testTrain() {
43 NeuralNetwork net = conf.configuredNetworkFactory();
44 net.train(trainSet);
45
46 double error = Double.NaN;
47 try {
48 error = net.getStatisticsFacility().computePredictionError(net, env.getTestDataset(0));
49 } catch (EnvironmentException e) {
50 fail();
51 }
52
53 log.info("Test set MSE: " + error);
54 assertTrue("Prediction error `" + error + "´ should be less than the required mark `"
55 + TARGETED_PREDICTION_MSE + "´.", error < TARGETED_PREDICTION_MSE);
56 }
57 }