View Javadoc

1   /***
2    * 
3    */
4   package yawn.optim.genetic;
5   
6   import java.lang.reflect.InvocationTargetException;
7   import java.util.Arrays;
8   import java.util.Comparator;
9   import java.util.List;
10  
11  import org.apache.commons.beanutils.BeanUtils;
12  import org.apache.commons.collections.KeyValue;
13  import org.apache.commons.collections.keyvalue.DefaultKeyValue;
14  import org.apache.commons.logging.Log;
15  import org.apache.commons.logging.LogFactory;
16  import org.jgap.Chromosome;
17  import org.jgap.Configuration;
18  import org.jgap.DeltaFitnessEvaluator;
19  import org.jgap.Genotype;
20  import org.jgap.InvalidConfigurationException;
21  import org.jgap.Population;
22  import org.jgap.impl.DefaultConfiguration;
23  import org.jgap.impl.WeightedRouletteSelector;
24  
25  import yawn.YawnRuntimeException;
26  import yawn.config.ConfigurationException;
27  import yawn.config.NeuralNetworkConfig;
28  import yawn.nn.NeuralNetwork;
29  import yawn.nn.committee.NetworkCommittee;
30  import yawn.nn.committee.functions.Popularity;
31  import yawn.optim.OptimizableModel;
32  import yawn.optim.ParamerFitter;
33  import yawn.util.InputOutputPattern;
34  import yawn.util.Pattern;
35  
36  /***
37   * This is yawn.optim.GeneticParameterFitter, part of the yawn project.
38   * 
39   * <p>$Id: GeneticParameterFitter.java,v 1.6 2005/05/09 11:04:56 supermarti Exp $</p>
40   * 
41   * @author Luis Mart&iacute; (luis dot marti at uc3m dot es)
42   * @version $Revision: 1.6 $
43   */
44  public class GeneticParameterFitter extends ParamerFitter {
45  
46  	private static final Log log = LogFactory
47  			.getLog(GeneticParameterFitter.class);
48  
49  	/***
50  	 * 
51  	 */
52  	private static final long serialVersionUID = 3833462907550709300L;
53  
54  	/***
55  	 * 
56  	 * @uml.property name="fittedModel"
57  	 * @uml.associationEnd multiplicity="(0 1)"
58  	 */
59  	protected NeuralNetwork fittedModel;
60  
61  	/***
62  	 * 
63  	 * @uml.property name="modeltoBeFittedConfig"
64  	 * @uml.associationEnd multiplicity="(0 1)"
65  	 */
66  	protected NeuralNetworkConfig modeltoBeFittedConfig;
67  
68  	/***
69  	 * 
70  	 * @uml.property name="maxGenerations" 
71  	 */
72  	protected int maxGenerations;
73  
74  	/***
75  	 * 
76  	 * @uml.property name="populationSize" 
77  	 */
78  	protected int populationSize;
79  
80  	/***
81  	 * 
82  	 * @uml.property name="preserveFittestIndividual" 
83  	 */
84  	protected boolean preserveFittestIndividual;
85  
86  
87  	/***
88  	 * 
89  	 */
90  	public GeneticParameterFitter() {
91  		super();
92  	}
93  
94  	/***
95  	 * @return Returns the fittedModel.
96  	 * 
97  	 * @uml.property name="fittedModel"
98  	 */
99  	public NeuralNetwork getFittedModel() {
100 		return fittedModel;
101 	}
102 
103 	/***
104 	 * @return Returns the modeltoBeFittedConfig.
105 	 * 
106 	 * @uml.property name="modeltoBeFittedConfig"
107 	 */
108 	public NeuralNetworkConfig getModeltoBeFittedConfig() {
109 		return modeltoBeFittedConfig;
110 	}
111 
112 
113 	/*
114 	 * (non-Javadoc)
115 	 * 
116 	 * @see yawn.nn.NeuralNetwork#getInputSize()
117 	 */
118 	public int getInputSize() {
119 		return modeltoBeFittedConfig.getEnvironment().inputSize();
120 	}
121 
122 	/***
123 	 * @return Returns the maxGenerations.
124 	 * 
125 	 * @uml.property name="maxGenerations"
126 	 */
127 	public int getMaxGenerations() {
128 		return maxGenerations;
129 	}
130 
131 
132 	/*
133 	 * (non-Javadoc)
134 	 * 
135 	 * @see yawn.nn.NeuralNetwork#getNeuralNetworkName()
136 	 */
137 	public String getNeuralNetworkName() {
138 		// TODO Auto-generated method stub
139 		return "Genetic Algorithm Fitter of Model Parameters";
140 	}
141 
142 	/*
143 	 * (non-Javadoc)
144 	 * 
145 	 * @see yawn.nn.NeuralNetwork#getOutputSize()
146 	 */
147 	public int getOutputSize() {
148 		return modeltoBeFittedConfig.getEnvironment().outputSize();
149 	}
150 
151 	/***
152 	 * @return Returns the populationSize.
153 	 * 
154 	 * @uml.property name="populationSize"
155 	 */
156 	public int getPopulationSize() {
157 		return populationSize;
158 	}
159 
160 	/***
161 	 * @return Returns the preserveFittestIndividual.
162 	 * 
163 	 * @uml.property name="preserveFittestIndividual"
164 	 */
165 	public boolean isPreserveFittestIndividual() {
166 		return preserveFittestIndividual;
167 	}
168 
169 
170 	/*
171 	 * (non-Javadoc)
172 	 * 
173 	 * @see yawn.nn.NeuralNetwork#oneLearningStep(yawn.util.Pattern,
174 	 *      yawn.util.Pattern)
175 	 */
176 	public void oneLearningStep(Pattern input, Pattern output) {
177 		// TODO Auto-generated method stub
178 
179 	}
180 
181 	/*
182 	 * (non-Javadoc)
183 	 * 
184 	 * @see yawn.nn.NeuralNetwork#predict(yawn.util.Pattern)
185 	 */
186 	public Pattern predict(Pattern input) {
187 		return fittedModel.predict(input);
188 	}
189 
190 	/***
191 	 * @param fittedModel
192 	 *            The fittedModel to set.
193 	 * 
194 	 * @uml.property name="fittedModel"
195 	 */
196 	public void setFittedModel(NeuralNetwork fittedModel) {
197 		this.fittedModel = fittedModel;
198 	}
199 
200 	/***
201 	 * @param initialModelConfig
202 	 *            The initialModelConfig to set.
203 	 * 
204 	 * @uml.property name="modeltoBeFittedConfig"
205 	 */
206 	public void setModeltoBeFittedConfig(NeuralNetworkConfig initialModelConfig) {
207 		this.modeltoBeFittedConfig = initialModelConfig;
208 	}
209 
210 	/***
211 	 * @param maxGenerations
212 	 *            The maxGenerations to set.
213 	 * 
214 	 * @uml.property name="maxGenerations"
215 	 */
216 	public void setMaxGenerations(int maxGenerations) {
217 		this.maxGenerations = maxGenerations;
218 	}
219 
220 	/***
221 	 * @param populationSize
222 	 *            The populationSize to set.
223 	 * 
224 	 * @uml.property name="populationSize"
225 	 */
226 	public void setPopulationSize(int populationSize) {
227 		this.populationSize = populationSize;
228 	}
229 
230 	/***
231 	 * @param preserveFittestIndividual
232 	 *            The preserveFittestIndividual to set.
233 	 * 
234 	 * @uml.property name="preserveFittestIndividual"
235 	 */
236 	public void setPreserveFittestIndividual(boolean preserveFittestIndividual) {
237 		this.preserveFittestIndividual = preserveFittestIndividual;
238 	}
239 
240 	/*
241 	 * (non-Javadoc)
242 	 * 
243 	 * @see yawn.nn.NeuralNetwork#setup(yawn.config.NeuralNetworkConfig)
244 	 */
245 	public void setup(NeuralNetworkConfig config) throws ConfigurationException {
246 		try {
247 			BeanUtils.copyProperties(this, config);
248 		} catch (IllegalAccessException e) {
249 			throw new ConfigurationException(e);
250 		} catch (InvocationTargetException e) {
251 			throw new ConfigurationException(e);
252 		}
253 
254 		NeuralNetwork dummyNet = modeltoBeFittedConfig
255 				.configuredNetworkFactory();
256 
257 		if (!(dummyNet instanceof OptimizableModel)) {
258 			throw new ConfigurationException(
259 					"The model to be optimized does not implements yawn.optim.OptimizableModel.");
260 		}
261 	}
262 
263 	/*
264 	 * (non-Javadoc)
265 	 * 
266 	 * @see yawn.nn.NeuralNetwork#train(yawn.util.InputOutputPattern[])
267 	 */
268 	public void train(InputOutputPattern[] iop) {
269 
270 		Configuration conf = new DefaultConfiguration();
271 
272 		NeuralNetwork dummyNet = modeltoBeFittedConfig
273 				.configuredNetworkFactory();
274 
275 		if (!(dummyNet instanceof OptimizableModel)) {
276 
277 			throw new YawnRuntimeException(
278 					"The model to be optimazed does not implements yawn.optim.OptimizableModel.");
279 		}
280 
281 		JGapAdapter adapter = (JGapAdapter) ((OptimizableModel) dummyNet)
282 				.getAdapterInstance();
283 
284 		int trainSetSize = (int) Math.round(0.8 * iop.length);
285 
286 		InputOutputPattern[] trainSet = new InputOutputPattern[trainSetSize];
287 		InputOutputPattern[] testSet = new InputOutputPattern[iop.length
288 				- trainSetSize];
289 		InputOutputPattern[] aset = InputOutputPattern.randomOrderList(iop);
290 
291 		System.arraycopy(aset, 0, trainSet, 0, trainSetSize);
292 		System.arraycopy(aset, trainSetSize, testSet, 0, testSet.length);
293 
294 		adapter.setTrainingSet(trainSet);
295 		adapter.setEvalSet(testSet);
296 
297 		Genotype genotype = null;
298 
299 		try {
300 			conf.setFitnessFunction(adapter);
301 			conf.setSampleChromosome(adapter.getSampleChromosome());
302 			conf.setPopulationSize(populationSize);
303 			conf.setPreservFittestIndividual(preserveFittestIndividual);
304 			conf.setFitnessEvaluator(new DeltaFitnessEvaluator());
305 
306 			conf.addNaturalSelector(new WeightedRouletteSelector(), true);
307 			genotype = Genotype.randomInitialGenotype(conf);
308 		} catch (InvalidConfigurationException e) {
309 			log.debug(e);
310 			throw new YawnRuntimeException(e);
311 		}
312 
313 		for (int i = 0; i < maxGenerations; i++) {
314 			log.debug("Starting evolving generation " + i + " of "
315 					+ maxGenerations + "...");
316 			genotype.evolve();
317 		}
318 
319 		log.debug("Training fittest models ...");
320 
321 		NetworkCommittee fit = new NetworkCommittee(new Popularity());
322 		NeuralNetwork[] nets = getTopOfPopulation(adapter, genotype
323 				.getPopulation(), (int)Math.round( populationSize * 0.2) + 1);
324 
325 		for (int i = 0; i < nets.length; i++) {
326 			fit.addCommitteeMembers(nets[i]);
327 		}
328 
329 		fittedModel = fit;
330 		fittedModel.train(iop);
331 	}
332 
333 	/*
334 	 * (non-Javadoc)
335 	 * 
336 	 * @see yawn.nn.NeuralNetwork#yieldConfiguration()
337 	 */
338 	public NeuralNetworkConfig yieldConfiguration() {
339 		NeuralNetworkConfig res = new GeneticParameterFitterConfig();
340 		try {
341 			BeanUtils.copyProperties(res, this);
342 		} catch (IllegalAccessException e) {
343 			throw new YawnRuntimeException(e);
344 		} catch (InvocationTargetException e) {
345 			throw new YawnRuntimeException(e);
346 		}
347 		return res;
348 	}
349 
350 	private NeuralNetwork[] getTopOfPopulation(JGapAdapter adapter,
351 			Population population, int amount) {
352 		List chromosomes = population.getChromosomes();
353 
354 		// HashMap map = new HashMap();
355 		//
356 		// for (Iterator i = chromosomes.iterator(); i.hasNext();) {
357 		// Chromosome chromo = (Chromosome) i.next();
358 		// double fitness = adapter.evaluate(chromo);
359 		// map.put(new Double(fitness), chromo);
360 		// }
361 		DefaultKeyValue[] keyVals = new DefaultKeyValue[chromosomes.size()];
362 
363 		for (int i = 0; i < chromosomes.size(); i++) {
364 			Chromosome chromo = (Chromosome) chromosomes.get(i);
365 			double fitness = adapter.evaluate(chromo);
366 			keyVals[i] = new DefaultKeyValue(new Double(fitness), chromo);
367 		}
368 
369 		// Double[] keysList = (Double[]) map.keySet().toArray(
370 		// new Double[map.size()]);
371 		Arrays.sort(keyVals, new KeyValueComparator());
372 
373 		NeuralNetwork[] res = new NeuralNetwork[Math
374 				.min(keyVals.length, amount)];
375 
376 		for (int i = 0; i < res.length; i++) {
377 			log.debug("Selected chromosome with eval index: " + keyVals[i]);
378 			res[i] = adapter.buildNeuralNetworkFromChromosome((Chromosome)keyVals[i].getValue());
379 		}
380 
381 		return res;
382 	}
383 
384 	private class KeyValueComparator implements Comparator {
385 
386 		public int compare(Object arg0, Object arg1) {
387 			double a = ((Double) ((KeyValue) arg0).getKey()).doubleValue();
388 			double b = ((Double) ((KeyValue) arg1).getKey()).doubleValue();
389 			if (a == b) {
390 				return 0;
391 			}
392 			if (a > b) {
393 				return 1;
394 			}
395 			return -1;
396 		}
397 
398 	}
399 }