1
2
3
4
5
6
7 package yawn.envs.synthetic;
8
9 import java.util.ArrayList;
10
11 import org.apache.commons.logging.Log;
12 import org.apache.commons.logging.LogFactory;
13
14 import yawn.config.ValidationException;
15 import yawn.envs.Environment;
16 import yawn.envs.EnvironmentException;
17 import yawn.util.ErrorUtils;
18 import yawn.util.InputOutputPattern;
19 import yawn.util.Pattern;
20
/**
 * An abstract class that represents a <i>synthetic environment</i>. A
 * synthetic environment produces its data by randomly sampling an a priori
 * known function. Environments of this kind are mostly useful for automated
 * testing.
 *
 * <p>$Id: SyntheticDataEnvironment.java,v 1.6 2005/05/09 11:04:57 supermarti Exp $</p>
 *
 * @author Luis Martí (luis dot marti at uc3m dot es)
 * @version $Revision: 1.6 $
 */
32 public abstract class SyntheticDataEnvironment extends Environment {
33
34 protected static final Log log = LogFactory.getLog(CircleInTheSquareEnvironment.class);
35
36 /***
37 *
38 * @uml.property name="testSetSize"
39 */
40 public int getTestSetSize() {
41 return this.testSetSize;
42 }
43
44 /***
45 *
46 * @uml.property name="testSetSize"
47 */
48 public void setTestSetSize(int testSetSize) {
49 this.testSetSize = testSetSize;
50 }
51
52 /***
53 *
54 * @uml.property name="trainSetSize"
55 */
56 public int getTrainSetSize() {
57 return this.trainSetSize;
58 }
59
60 /***
61 *
62 * @uml.property name="trainSetSize"
63 */
64 public void setTrainSetSize(int trainSetSize) {
65 this.trainSetSize = trainSetSize;
66 }
67
68 /***
69 *
70 * @uml.property name="trainSetSize"
71 */
72 protected int trainSetSize;
73
74 /***
75 *
76 * @uml.property name="testSetSize"
77 */
78 protected int testSetSize;
79
80 /***
81 *
82 * @uml.property name="numberOfSystemRuns"
83 */
84 protected int numberOfSystemRuns = 1;
85
86 /***
87 *
88 * @uml.property name="trainSetsList"
89 * @uml.associationEnd multiplicity="(0 -1)" elementType="[Lyawn.util.InputOutputPattern;"
90 */
91 protected ArrayList trainSetsList;
92
93 /***
94 *
95 * @uml.property name="testSetsList"
96 * @uml.associationEnd multiplicity="(0 -1)" elementType="[Lyawn.util.InputOutputPattern;"
97 */
98 protected ArrayList testSetsList;
99
100 /***
101 *
102 * @uml.property name="numberOfSystemRuns"
103 */
104 public int getNumberOfSystemRuns() {
105 return numberOfSystemRuns;
106 }
107
108 /***
109 *
110 * @uml.property name="numberOfSystemRuns"
111 */
112 public void setNumberOfSystemRuns(int numberOfRuns) {
113 numberOfSystemRuns = numberOfRuns;
114 }
115
116 public void init() {
117 trainSetsList = new ArrayList();
118 testSetsList = new ArrayList();
119 for (int i = 0; i < numberOfSystemRuns; i++) {
120 trainSetsList.add(generateSet(getTrainSetSize()));
121 testSetsList.add(generateSet(getTestSetSize()));
122 }
123 }
124
125 public InputOutputPattern[] getTrainingDataset(int runNumber) throws EnvironmentException {
126 InputOutputPattern[] iops = null;
127 try {
128 iops = (InputOutputPattern[]) trainSetsList.get(runNumber);
129 } catch (NullPointerException e) {
130 init();
131 return getTrainingDataset(runNumber);
132 }
133 return iops;
134 }
135
136 protected InputOutputPattern[] generateSet(int setSize) {
137 InputOutputPattern[] res = new InputOutputPattern[setSize];
138
139 for (int i = 0; i < setSize; i++) {
140 InputOutputPattern iop = new InputOutputPattern(2, 1);
141 Pattern input = generateRandomInput();
142 iop.input = input;
143 iop.output = synthetizeOutput(input);
144 res[i] = iop;
145 }
146
147 return res;
148 }
149
150 protected abstract Pattern generateRandomInput();
151
152 protected abstract Pattern synthetizeOutput(Pattern input);
153
154 public Pattern[] getTestDatasetInputs(int runNumber) throws EnvironmentException {
155 InputOutputPattern[] iops = null;
156 try {
157 iops = (InputOutputPattern[]) testSetsList.get(runNumber);
158 } catch (NullPointerException e) {
159 init();
160 return getTestDatasetInputs(runNumber);
161 }
162
163 Pattern[] res = new Pattern[iops.length];
164
165 for (int i = 0; i < iops.length; i++) {
166 res[i] = iops[i].input;
167 }
168
169 return res;
170 }
171
172 public InputOutputPattern[] getTestDataset(int runNumber) throws EnvironmentException {
173 InputOutputPattern[] iops = null;
174 try {
175 iops = (InputOutputPattern[]) testSetsList.get(runNumber);
176 } catch (NullPointerException e) {
177 init();
178 return getTestDataset(runNumber);
179 }
180
181 return iops;
182 }
183
184 public Pattern[] getTestExpectedOutputsSet(int runNumber) throws EnvironmentException {
185 InputOutputPattern[] iops = null;
186 try {
187 iops = (InputOutputPattern[]) testSetsList.get(runNumber);
188 } catch (NullPointerException e) {
189 init();
190 return getTestExpectedOutputsSet(runNumber);
191 }
192
193 Pattern[] res = new Pattern[iops.length];
194
195 for (int i = 0; i < iops.length; i++) {
196 res[i] = iops[i].output;
197 }
198
199 return res;
200 }
201
202 public InputOutputPattern[] getFullTestDataset(int runNumber) {
203 InputOutputPattern[] iops = null;
204 try {
205 iops = (InputOutputPattern[]) testSetsList.get(runNumber);
206 } catch (NullPointerException e) {
207 init();
208 return getFullTestDataset(runNumber);
209 }
210 return iops;
211 }
212
213 public void writeResults(Pattern[] results, int runNumber) throws EnvironmentException {
214 Pattern[] res = getTestExpectedOutputsSet(runNumber);
215 log.info("Mean square test set error for run number " + runNumber + ": "
216 + ErrorUtils.meanSquaredError(res, results) + ".");
217 }
218
219 public void validate() throws ValidationException {
220 if ((trainSetSize < 0) || (testSetSize < 0)) {
221 throw new ValidationException("Wrong set sizes");
222 }
223 }
224
225
226
227
228
229
230 public int inputSize() {
231 if (trainSetsList == null) {
232 init();
233 }
234 return ((InputOutputPattern[]) trainSetsList.get(0))[0].input.size();
235 }
236
237
238
239
240
241
242 public int outputSize() {
243 if (trainSetsList == null) {
244 init();
245 }
246 return ((InputOutputPattern[]) trainSetsList.get(0))[0].output.size();
247 }
248
249 }