1 /***
2 *
3 */
4 package yawn.optim.genetic;
5
6 import java.lang.reflect.InvocationTargetException;
7 import java.util.Arrays;
8 import java.util.Comparator;
9 import java.util.List;
10
11 import org.apache.commons.beanutils.BeanUtils;
12 import org.apache.commons.collections.KeyValue;
13 import org.apache.commons.collections.keyvalue.DefaultKeyValue;
14 import org.apache.commons.logging.Log;
15 import org.apache.commons.logging.LogFactory;
16 import org.jgap.Chromosome;
17 import org.jgap.Configuration;
18 import org.jgap.DeltaFitnessEvaluator;
19 import org.jgap.Genotype;
20 import org.jgap.InvalidConfigurationException;
21 import org.jgap.Population;
22 import org.jgap.impl.DefaultConfiguration;
23 import org.jgap.impl.WeightedRouletteSelector;
24
25 import yawn.YawnRuntimeException;
26 import yawn.config.ConfigurationException;
27 import yawn.config.NeuralNetworkConfig;
28 import yawn.nn.NeuralNetwork;
29 import yawn.nn.committee.NetworkCommittee;
30 import yawn.nn.committee.functions.Popularity;
31 import yawn.optim.OptimizableModel;
32 import yawn.optim.ParamerFitter;
33 import yawn.util.InputOutputPattern;
34 import yawn.util.Pattern;
35
36 /***
37 * This is yawn.optim.GeneticParameterFitter, part of the yawn project.
38 *
39 * <p>$Id: GeneticParameterFitter.java,v 1.6 2005/05/09 11:04:56 supermarti Exp $</p>
40 *
41 * @author Luis Martí (luis dot marti at uc3m dot es)
42 * @version $Revision: 1.6 $
43 */
44 public class GeneticParameterFitter extends ParamerFitter {
45
46 private static final Log log = LogFactory
47 .getLog(GeneticParameterFitter.class);
48
49 /***
50 *
51 */
52 private static final long serialVersionUID = 3833462907550709300L;
53
54 /***
55 *
56 * @uml.property name="fittedModel"
57 * @uml.associationEnd multiplicity="(0 1)"
58 */
59 protected NeuralNetwork fittedModel;
60
61 /***
62 *
63 * @uml.property name="modeltoBeFittedConfig"
64 * @uml.associationEnd multiplicity="(0 1)"
65 */
66 protected NeuralNetworkConfig modeltoBeFittedConfig;
67
68 /***
69 *
70 * @uml.property name="maxGenerations"
71 */
72 protected int maxGenerations;
73
74 /***
75 *
76 * @uml.property name="populationSize"
77 */
78 protected int populationSize;
79
80 /***
81 *
82 * @uml.property name="preserveFittestIndividual"
83 */
84 protected boolean preserveFittestIndividual;
85
86
87 /***
88 *
89 */
90 public GeneticParameterFitter() {
91 super();
92 }
93
94 /***
95 * @return Returns the fittedModel.
96 *
97 * @uml.property name="fittedModel"
98 */
99 public NeuralNetwork getFittedModel() {
100 return fittedModel;
101 }
102
103 /***
104 * @return Returns the modeltoBeFittedConfig.
105 *
106 * @uml.property name="modeltoBeFittedConfig"
107 */
108 public NeuralNetworkConfig getModeltoBeFittedConfig() {
109 return modeltoBeFittedConfig;
110 }
111
112
113
114
115
116
117
118 public int getInputSize() {
119 return modeltoBeFittedConfig.getEnvironment().inputSize();
120 }
121
122 /***
123 * @return Returns the maxGenerations.
124 *
125 * @uml.property name="maxGenerations"
126 */
127 public int getMaxGenerations() {
128 return maxGenerations;
129 }
130
131
132
133
134
135
136
137 public String getNeuralNetworkName() {
138
139 return "Genetic Algorithm Fitter of Model Parameters";
140 }
141
142
143
144
145
146
147 public int getOutputSize() {
148 return modeltoBeFittedConfig.getEnvironment().outputSize();
149 }
150
151 /***
152 * @return Returns the populationSize.
153 *
154 * @uml.property name="populationSize"
155 */
156 public int getPopulationSize() {
157 return populationSize;
158 }
159
160 /***
161 * @return Returns the preserveFittestIndividual.
162 *
163 * @uml.property name="preserveFittestIndividual"
164 */
165 public boolean isPreserveFittestIndividual() {
166 return preserveFittestIndividual;
167 }
168
169
170
171
172
173
174
175
176 public void oneLearningStep(Pattern input, Pattern output) {
177
178
179 }
180
181
182
183
184
185
186 public Pattern predict(Pattern input) {
187 return fittedModel.predict(input);
188 }
189
190 /***
191 * @param fittedModel
192 * The fittedModel to set.
193 *
194 * @uml.property name="fittedModel"
195 */
196 public void setFittedModel(NeuralNetwork fittedModel) {
197 this.fittedModel = fittedModel;
198 }
199
200 /***
201 * @param initialModelConfig
202 * The initialModelConfig to set.
203 *
204 * @uml.property name="modeltoBeFittedConfig"
205 */
206 public void setModeltoBeFittedConfig(NeuralNetworkConfig initialModelConfig) {
207 this.modeltoBeFittedConfig = initialModelConfig;
208 }
209
210 /***
211 * @param maxGenerations
212 * The maxGenerations to set.
213 *
214 * @uml.property name="maxGenerations"
215 */
216 public void setMaxGenerations(int maxGenerations) {
217 this.maxGenerations = maxGenerations;
218 }
219
220 /***
221 * @param populationSize
222 * The populationSize to set.
223 *
224 * @uml.property name="populationSize"
225 */
226 public void setPopulationSize(int populationSize) {
227 this.populationSize = populationSize;
228 }
229
230 /***
231 * @param preserveFittestIndividual
232 * The preserveFittestIndividual to set.
233 *
234 * @uml.property name="preserveFittestIndividual"
235 */
236 public void setPreserveFittestIndividual(boolean preserveFittestIndividual) {
237 this.preserveFittestIndividual = preserveFittestIndividual;
238 }
239
240
241
242
243
244
245 public void setup(NeuralNetworkConfig config) throws ConfigurationException {
246 try {
247 BeanUtils.copyProperties(this, config);
248 } catch (IllegalAccessException e) {
249 throw new ConfigurationException(e);
250 } catch (InvocationTargetException e) {
251 throw new ConfigurationException(e);
252 }
253
254 NeuralNetwork dummyNet = modeltoBeFittedConfig
255 .configuredNetworkFactory();
256
257 if (!(dummyNet instanceof OptimizableModel)) {
258 throw new ConfigurationException(
259 "The model to be optimized does not implements yawn.optim.OptimizableModel.");
260 }
261 }
262
263
264
265
266
267
268 public void train(InputOutputPattern[] iop) {
269
270 Configuration conf = new DefaultConfiguration();
271
272 NeuralNetwork dummyNet = modeltoBeFittedConfig
273 .configuredNetworkFactory();
274
275 if (!(dummyNet instanceof OptimizableModel)) {
276
277 throw new YawnRuntimeException(
278 "The model to be optimazed does not implements yawn.optim.OptimizableModel.");
279 }
280
281 JGapAdapter adapter = (JGapAdapter) ((OptimizableModel) dummyNet)
282 .getAdapterInstance();
283
284 int trainSetSize = (int) Math.round(0.8 * iop.length);
285
286 InputOutputPattern[] trainSet = new InputOutputPattern[trainSetSize];
287 InputOutputPattern[] testSet = new InputOutputPattern[iop.length
288 - trainSetSize];
289 InputOutputPattern[] aset = InputOutputPattern.randomOrderList(iop);
290
291 System.arraycopy(aset, 0, trainSet, 0, trainSetSize);
292 System.arraycopy(aset, trainSetSize, testSet, 0, testSet.length);
293
294 adapter.setTrainingSet(trainSet);
295 adapter.setEvalSet(testSet);
296
297 Genotype genotype = null;
298
299 try {
300 conf.setFitnessFunction(adapter);
301 conf.setSampleChromosome(adapter.getSampleChromosome());
302 conf.setPopulationSize(populationSize);
303 conf.setPreservFittestIndividual(preserveFittestIndividual);
304 conf.setFitnessEvaluator(new DeltaFitnessEvaluator());
305
306 conf.addNaturalSelector(new WeightedRouletteSelector(), true);
307 genotype = Genotype.randomInitialGenotype(conf);
308 } catch (InvalidConfigurationException e) {
309 log.debug(e);
310 throw new YawnRuntimeException(e);
311 }
312
313 for (int i = 0; i < maxGenerations; i++) {
314 log.debug("Starting evolving generation " + i + " of "
315 + maxGenerations + "...");
316 genotype.evolve();
317 }
318
319 log.debug("Training fittest models ...");
320
321 NetworkCommittee fit = new NetworkCommittee(new Popularity());
322 NeuralNetwork[] nets = getTopOfPopulation(adapter, genotype
323 .getPopulation(), (int)Math.round( populationSize * 0.2) + 1);
324
325 for (int i = 0; i < nets.length; i++) {
326 fit.addCommitteeMembers(nets[i]);
327 }
328
329 fittedModel = fit;
330 fittedModel.train(iop);
331 }
332
333
334
335
336
337
338 public NeuralNetworkConfig yieldConfiguration() {
339 NeuralNetworkConfig res = new GeneticParameterFitterConfig();
340 try {
341 BeanUtils.copyProperties(res, this);
342 } catch (IllegalAccessException e) {
343 throw new YawnRuntimeException(e);
344 } catch (InvocationTargetException e) {
345 throw new YawnRuntimeException(e);
346 }
347 return res;
348 }
349
350 private NeuralNetwork[] getTopOfPopulation(JGapAdapter adapter,
351 Population population, int amount) {
352 List chromosomes = population.getChromosomes();
353
354
355
356
357
358
359
360
361 DefaultKeyValue[] keyVals = new DefaultKeyValue[chromosomes.size()];
362
363 for (int i = 0; i < chromosomes.size(); i++) {
364 Chromosome chromo = (Chromosome) chromosomes.get(i);
365 double fitness = adapter.evaluate(chromo);
366 keyVals[i] = new DefaultKeyValue(new Double(fitness), chromo);
367 }
368
369
370
371 Arrays.sort(keyVals, new KeyValueComparator());
372
373 NeuralNetwork[] res = new NeuralNetwork[Math
374 .min(keyVals.length, amount)];
375
376 for (int i = 0; i < res.length; i++) {
377 log.debug("Selected chromosome with eval index: " + keyVals[i]);
378 res[i] = adapter.buildNeuralNetworkFromChromosome((Chromosome)keyVals[i].getValue());
379 }
380
381 return res;
382 }
383
384 private class KeyValueComparator implements Comparator {
385
386 public int compare(Object arg0, Object arg1) {
387 double a = ((Double) ((KeyValue) arg0).getKey()).doubleValue();
388 double b = ((Double) ((KeyValue) arg1).getKey()).doubleValue();
389 if (a == b) {
390 return 0;
391 }
392 if (a > b) {
393 return 1;
394 }
395 return -1;
396 }
397
398 }
399 }