1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package de.fu_berlin.ties.classify;
23
24 import java.util.HashSet;
25 import java.util.Iterator;
26 import java.util.List;
27 import java.util.Set;
28
29 import org.apache.commons.lang.ArrayUtils;
30 import org.apache.commons.lang.builder.ToStringBuilder;
31
32 import de.fu_berlin.ties.TiesConfiguration;
33 import de.fu_berlin.ties.util.Util;
34
35 /***
36 * This class provides support for iterative training, also called TUNE
37 * (Train-until-no-errors) training.
38 *
39 * <p>Instances of this class are not thread-safe and must be synchronized
40 * externally, if required.
41 *
42 * @author Christian Siefkes
43 * @version $Revision: 1.9 $, $Date: 2006/10/21 16:03:55 $, $Author: siefkes $
44 */
45 public class Tuner {
46
47 /***
48 * Configuration key: If given, the specified string is used to separate
49 * the training from the testing section of the corpus (e.g. "---") and
50 * the {@linkplain #getTrainSplit() train split} and
51 * {@linkplain #getTestSplit() test split} values are ignored.
52 */
53 public static final String CONFIG_SPLIT_SEPARATPR = "eval.split.separator";
54
55 /***
56 * Configuration key: The percentage of a corpus to use for training.
57 */
58 public static final String CONFIG_TRAIN_SPLIT = "eval.train-split";
59
60 /***
61 * Configuration key: The percentage of a corpus to use for testing
62 * (evaluation).
63 */
64 public static final String CONFIG_TEST_SPLIT = "eval.test-split";
65
66 /***
67 * Configuration key: The maximum number of iterations used for TUNE
68 * (train until no error) training; if 1, training is incremental.
69 */
70 public static final String CONFIG_TUNE = "train.tune";
71
72 /***
73 * Configuration key: TUNE training is stopped if the training accuracy
74 * didn't improve for the specified number of iterations.
75 */
76 public static final String CONFIG_TUNE_STOP = "train.tune.stop";
77
78 /***
79 * Configuration key: Whether to measure results after each TUNE iteration
80 * or only at the end of training.
81 */
82 public static final String CONFIG_TUNE_EACH = "eval.tune.each";
83
84 /***
85 * Configuration key: The training iteration after which to evaluate
86 * results for the first time if {@link #CONFIG_TUNE_EACH} is enabled.
87 */
88 public static final String CONFIG_TUNE_SINCE = "eval.tune.since";
89
90
91 /***
92 * If not <code>null</code>, the specified string is used to separate
93 * the training from the testing section of the corpus (e.g. "---") and
94 * the {@linkplain #getTrainSplit() train split} and
95 * {@linkplain #getTestSplit() test split} values are ignored.
96 */
97 private final String splitSeparator;
98
99 /***
100 * The percentage of a corpus to use for training.
101 */
102 private final float trainSplit;
103
104 /***
105 * The percentage of a corpus to use for testing; if <code>-1</code>,
106 * all remaining documents should be used.
107 */
108 private final float testSplit;
109
110 /***
111 * The maximum number of iterations used for TUNE (train until no error)
112 * training; if 1, training is incremental. <em>Note that iterations should
113 * be indexed from 1 to X (this number) instead of from 0 to X-1 for
114 * compatibility with {@link #getTuneEvaluations()}.</em>
115 */
116 private final int tuneIterations;
117
118 /***
119 * TUNE training is stopped if the training accuracy didn't improve for the
120 * specified number of iterations.
121 */
122 private final int tuneStop;
123
124 /***
125 * Whether to measure results after each TUNE iteration or only at the
126 * end of training.
127 */
128 private final boolean tuneEach;
129
130 /***
131 * The training iteration after which to evaluate results for the first
132 * time if {@link #tuneEach} is enabled.
133 */
134 private final int tuneSince;
135
136 /***
137 * A set of iterations after which to evaluate TUNE training in addition to
138 * the last one; ignored if {@link #tuneEach} is <code>true</code>.
139 */
140 private final Set<Integer> tuneEvaluations = new HashSet<Integer>();
141
142 /***
143 * Used to decide when to stop TUNE training.
144 */
145 private double[] lastAcc = null;
146
147 /***
148 * Used to decide when to stop TUNE training.
149 */
150 private boolean allAreOptimal;
151
152 /***
153 * Used to decide when to stop TUNE training.
154 */
155 private boolean noneGotBetter;
156
157 /***
158 * Used to decide when to stop TUNE training.
159 */
160 private boolean someGotWorse;
161
162 /***
163 * Used to decide when to stop TUNE training.
164 */
165 private int noneGotBetterCounter = 0;
166
167
168 /***
169 * Creates a new instance.
170 *
171 * @param config used to configure this instance
172 * @param suffix an optional suffix used to
173 * {@linkplain TiesConfiguration#adaptKey(String, String) adapt
174 * configuration keys}; might be <code>null</code>
175 */
176 public Tuner(final TiesConfiguration config, final String suffix) {
177 this(config.getFloat(config.adaptKey(CONFIG_TRAIN_SPLIT, suffix)),
178 config.getFloat(config.adaptKey(CONFIG_TEST_SPLIT, suffix)),
179 config.getString(config.adaptKey(
180 CONFIG_SPLIT_SEPARATPR, suffix), null),
181 config.getInt(config.adaptKey(CONFIG_TUNE, suffix)),
182 config.getInt(config.adaptKey(CONFIG_TUNE_STOP, suffix)),
183 config.getBoolean(config.adaptKey(CONFIG_TUNE_EACH, suffix)),
184 config.getInt(config.adaptKey(CONFIG_TUNE_SINCE, suffix)),
185 config.getList(config.adaptKey("eval.tune.list", suffix)));
186 }
187
188 /***
189 * Creates a new instance.
190 *
191 * @param trainingSplit the percentage of a corpus to use for training
192 * @param testingSplit the percentage of a corpus to use for testing
193 * (evaluation); if <code>-1</code>, all remaining documents (1 -
194 * <code>trainingSplit</code>) are used
195 * @param splitSep if not <code>null</code>, the specified string is used
196 * to separate the training from the testing section of the corpus and
197 * the {@linkplain #getTrainSplit() train split} and
198 * {@linkplain #getTestSplit() test split} values are ignored.
199 * @param tuneRuns the maximum number of iterations used for TUNE
200 * (train until no error) training; if 1, training is incremental
201 * @param tuneStopAfter TUNE training is stopped if the training accuracy
202 * didn't improve for the specified number of iterations.
203 * @param measureEachTUNE whether to measure results after each TUNE
204 * iteration or only at the end of training
205 * @param startMeasureTUNE he training iteration after which to evaluate
206 * results for the first time if <code>measureEachTUNE</code> is enabled
207 * (ignored otherwise)
208 * @param tuneEvalList A list of Integers or int Strings specifying
209 * iterations after which to evaluate TUNE training in addition to the last
210 * one; ignored if <code>measureEachTUNE</code> is <code>true</code>
211 * @throws IllegalArgumentException if <code>trainingSplit</code> is not
212 * a percentage (larger than 1 or smaller than 0) or if
213 * <code>tuneRuns</code> is non-positive
214 */
215 public Tuner(final float trainingSplit, final float testingSplit,
216 final String splitSep, final int tuneRuns, final int tuneStopAfter,
217 final boolean measureEachTUNE, final int startMeasureTUNE,
218 final List tuneEvalList) throws IllegalArgumentException {
219 super();
220
221 if ((splitSep == null)
222 && ((trainingSplit < 0.0) || (trainingSplit > 1.0))) {
223 throw new IllegalArgumentException(
224 "Train split is not a percentage: " + trainingSplit);
225 }
226 if ((splitSep == null) && (testingSplit > 1.0)) {
227 throw new IllegalArgumentException(
228 "Test split is not a percentage: " + testingSplit);
229 }
230 if (tuneRuns < 1) {
231 throw new IllegalArgumentException(
232 "Number of TUNE runs must be at least 1: " + tuneRuns);
233 }
234 if (tuneStopAfter < 1) {
235 throw new IllegalArgumentException(
236 "TUNE stopping criterium must be at least 1: " + tuneStopAfter);
237 }
238
239 trainSplit = trainingSplit;
240 testSplit = testingSplit;
241 splitSeparator = splitSep;
242 tuneIterations = tuneRuns;
243 tuneStop = tuneStopAfter;
244 tuneEach = measureEachTUNE;
245 tuneSince = Math.max(startMeasureTUNE, 1);
246
247 if (!tuneEach) {
248 for (final Iterator iter = tuneEvalList.iterator();
249 iter.hasNext();) {
250
251 tuneEvaluations.add(Integer.valueOf(Util.asInt(iter.next())));
252 }
253 }
254
255 }
256
257
258 /***
259 * Whether to continue TUNE training after finishing an iteration.
260 *
261 * @param currentAcc the list of accuracies for the just finished
262 * TUNE iteration
263 * @param i the number of the just finished TUNE iterations,
264 * <strong>counting starts with 1 not with 0</strong>
265 * @return whether to continue TUNE training
266 */
267 public boolean continueTraining(final double[] currentAcc, final int i) {
268 if (i >= getTuneIterations()) {
269
270 return false;
271 }
272
273 boolean continueTraining = true;
274 allAreOptimal = true;
275 noneGotBetter = true;
276 someGotWorse = false;
277
278 for (int j = 0; j < currentAcc.length; j++) {
279
280 allAreOptimal = allAreOptimal
281 && (currentAcc[j] >= 1.0);
282
283
284 if (lastAcc != null) {
285 noneGotBetter = noneGotBetter
286 && (currentAcc[j] <= lastAcc[j]);
287 someGotWorse = someGotWorse
288 || (currentAcc[j] < lastAcc[j]);
289 } else {
290 noneGotBetter = false;
291 }
292 }
293
294 if (noneGotBetter) {
295 noneGotBetterCounter++;
296
297 if (noneGotBetterCounter >= tuneStop) {
298
299 continueTraining = false;
300 Util.LOG.debug("Stopping TUNE training after " + i
301 + " iterations because current accuracies ("
302 + ArrayUtils.toString(currentAcc)
303 + ") aren't higher than last ones ("
304 + ArrayUtils.toString(lastAcc)
305 + ") for the " + tuneStop
306 + ". time");
307 } else if (someGotWorse) {
308
309 continueTraining = false;
310 Util.LOG.debug("Stopping TUNE training after " + i
311 + " iterations because current accuracies ("
312 + ArrayUtils.toString(currentAcc)
313 + ") are lower than last ones ("
314 + ArrayUtils.toString(lastAcc) + ")");
315 }
316 }
317
318 if (allAreOptimal) {
319 continueTraining = false;
320 Util.LOG.debug("Stopping TUNE training after " + i
321 + " iterations because all accuracies are already "
322 + "optimal: "
323 + ArrayUtils.toString(currentAcc));
324 }
325
326 lastAcc = currentAcc;
327 return continueTraining;
328 }
329
330 /***
331 * Returns the percentage of a corpus to use for testing;
332 * if <code>-1</code>, all remaining documents should be used.
333 *
334 * @return the value of the attribute
335 */
336 public float getTestSplit() {
337 return testSplit;
338 }
339
340 /***
341 * Returns the percentage of a corpus to use for training.
342 *
343 * @return the value of the attribute
344 */
345 public float getTrainSplit() {
346 return trainSplit;
347 }
348
349 /***
350 * Returns the set of iterations after which to evaluate TUNE training in
351 * addition to the last one; should be ignored if {@link #isTuneEach()} is
352 * <code>true</code>.
353 *
354 * @return the value of the attribute
355 */
356 public Set<Integer> getTuneEvaluations() {
357 return tuneEvaluations;
358 }
359
360 /***
361 * If not <code>null</code>, the returned string should be used to separate
362 * the training from the testing section of the corpus (e.g. "---") and
363 * the {@linkplain #getTrainSplit() train split} and
364 * {@linkplain #getTestSplit() test split} values should be ignored.
365 *
366 * @return the value of the attribute
367 */
368 public String getSplitSeparator() {
369 return splitSeparator;
370 }
371
372 /***
373 * Returns the maximum number of iterations used for TUNE (train until no
374 * error) training; if 1, training is incremental. <em>Note that iterations
375 * should be indexed from 1 to X (this number) instead of from 0 to X-1 for
376 * compatibility with {@link #getTuneEvaluations()}.</em>
377 *
378 * @return the value of the attribute
379 */
380 public int getTuneIterations() {
381 return tuneIterations;
382 }
383
384 /***
385 * Returns the training iteration after which to evaluate results for the
386 * first time if {@link #isTuneEach()} is enabled.
387 *
388 * @return the value of the attribute
389 */
390 public int getTuneSince() {
391 return tuneSince;
392 }
393
394 /***
395 * Returns the TUNE stopping criterion: TUNE training should be stopped if
396 * the training accuracy didn't improve for the specified number of
397 * iterations.
398 *
399 * @return the value of the attribute
400 */
401 public int getTuneStop() {
402 return tuneStop;
403 }
404
405 /***
406 * Whether to measure results after each TUNE iteration or only at the
407 * end of training.
408 *
409 * @return the value of the attribute
410 */
411 public boolean isTuneEach() {
412 return tuneEach;
413 }
414
415 /***
416 * Resets the state of this instance. This method must be called before
417 * starting to TUNE train a set of instances after finishing training
418 * another set.
419 */
420 public void reset() {
421 noneGotBetterCounter = 0;
422 lastAcc = null;
423 }
424
425 /***
426 * Chooses files to use for training and files to use for evaluation,
427 * depending on the configured settings.
428 *
429 * @param allFiles the array of file names to process
430 * @param trainFiles populated with the files to use for training, will
431 * be populated with the first <em>{@link Tuner#getTrainSplit()} *
432 * allFiles.length</em> files; must initially be empty
433 * @param evalFiles populated with the files to use for evaluation, will
434 * be populated from the next <em>{@link Tuner#getTestSplit()} *
435 * allFiles.length</em> remaining files (or all remaining files if test
436 * split is negative); must initially be empty
437 * @throws IllegalArgumentException if the lists aren't empty
438 */
439 public void selectFiles(final String[] allFiles,
440 final List<String> trainFiles, final List<String> evalFiles)
441 throws IllegalArgumentException {
442
443 if (!trainFiles.isEmpty() || !evalFiles.isEmpty()) {
444 throw new IllegalArgumentException(
445 "Lists of train files and eval files must initially be empty");
446 }
447
448 final int numTrainFiles = Math.round(getTrainSplit() * allFiles.length);
449 final int filesToUse;
450
451 if (splitSeparator != null) {
452
453 int i = 0;
454
455 for (; (i < allFiles.length)
456 && !splitSeparator.equals(allFiles[i]); i++) {
457
458 trainFiles.add(allFiles[i]);
459 }
460
461
462 i++;
463
464 for (; (i < allFiles.length)
465 && !splitSeparator.equals(allFiles[i]); i++) {
466
467 evalFiles.add(allFiles[i]);
468 }
469 } else {
470
471 if (getTestSplit() < 0) {
472
473 filesToUse = allFiles.length;
474 } else {
475 filesToUse = Math.min(allFiles.length, Math.round(
476 (getTrainSplit() + getTestSplit()) * allFiles.length));
477 }
478
479
480 for (int i = 0; i < filesToUse; i++) {
481 if (i < numTrainFiles) {
482
483 trainFiles.add(allFiles[i]);
484 } else {
485
486 evalFiles.add(allFiles[i]);
487 }
488 }
489 }
490
491 final int usedFiles = trainFiles.size() + evalFiles.size();
492 Util.LOG.debug("Using " + usedFiles + " of " + allFiles.length
493 + " files: " + trainFiles.size() + " for training, "
494 + evalFiles.size() + " for evaluation");
495 }
496
497 /***
498 * Whether to evaluate results after this TUNE iteration.
499 *
500 * @param continueTraining the result returned by the preceding call to
501 * {@link #continueTraining(double[], int)}
502 * @param i the number of the just finished TUNE iterations,
503 * <strong>counting starts with 1 not with 0</strong>
504 * @return whether to evaluate
505 */
506 public boolean shouldEvaluate(final boolean continueTraining, final int i) {
507 return (isTuneEach() && (i >= getTuneSince()))
508 || !continueTraining
509 || (i == getTuneIterations())
510 || getTuneEvaluations().contains(Integer.valueOf(i));
511 }
512
513 /***
514 * Returns a string representation of this object.
515 *
516 * @return a textual representation
517 */
518 public String toString() {
519 final ToStringBuilder builder = new ToStringBuilder(this)
520 .appendSuper(super.toString());
521
522 if (splitSeparator != null) {
523 builder.append("split separator", splitSeparator);
524 } else {
525 builder.append("train split", trainSplit)
526 .append("test split", testSplit);
527 }
528
529 builder.append("tune iterations", tuneIterations)
530 .append("tune stops after", tuneStop)
531 .append("measure after each iteration", tuneEach)
532 .append("starting from", tuneSince);
533 return builder.toString();
534 }
535
536 }