/*
 * Copyright (C) 2005-2006 Christian Siefkes <christian@siefkes.net>.
 * Development of this software is supported by the German Research Society,
 * Berlin-Brandenburg Graduate School in Distributed Information Systems
 * (DFG grant no. GRK 316).
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, visit
 * http://www.gnu.org/licenses/gpl.html or write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */
22  package de.fu_berlin.ties.classify;
23  
24  import java.util.HashSet;
25  import java.util.Iterator;
26  import java.util.List;
27  import java.util.Set;
28  
29  import org.apache.commons.lang.ArrayUtils;
30  import org.apache.commons.lang.builder.ToStringBuilder;
31  
32  import de.fu_berlin.ties.TiesConfiguration;
33  import de.fu_berlin.ties.util.Util;
34  
35  /***
36   * This class provides support for iterative training, also called TUNE
37   * (Train-until-no-errors) training.
38   *
39   * <p>Instances of this class are not thread-safe and must be synchronized
40   * externally, if required.
41   *
42   * @author Christian Siefkes
43   * @version $Revision: 1.9 $, $Date: 2006/10/21 16:03:55 $, $Author: siefkes $
44   */
45  public class Tuner {
46  
47      /***
48       * Configuration key: If given, the specified string is used to separate
49       * the training from the testing section of the corpus (e.g. "---") and
50       * the {@linkplain #getTrainSplit() train split} and
51       * {@linkplain #getTestSplit() test split} values are ignored.
52       */
53      public static final String CONFIG_SPLIT_SEPARATPR = "eval.split.separator";
54  
55      /***
56       * Configuration key: The percentage of a corpus to use for training.
57       */
58      public static final String CONFIG_TRAIN_SPLIT = "eval.train-split";
59  
60      /***
61       * Configuration key: The percentage of a corpus to use for testing
62       * (evaluation).
63       */
64      public static final String CONFIG_TEST_SPLIT = "eval.test-split";
65  
66      /***
67       * Configuration key: The maximum number of iterations used for TUNE
68       * (train until no error) training; if 1, training is incremental.
69       */
70      public static final String CONFIG_TUNE = "train.tune";
71  
72      /***
73       * Configuration key: TUNE training is stopped if the training accuracy
74       * didn't improve for the specified number of iterations.
75       */
76      public static final String CONFIG_TUNE_STOP = "train.tune.stop";
77  
78      /***
79       * Configuration key: Whether to measure results after each TUNE iteration
80       * or only at the end of training.
81       */
82      public static final String CONFIG_TUNE_EACH = "eval.tune.each";
83  
84      /***
85       * Configuration key: The training iteration after which to evaluate
86       * results for the first time if {@link #CONFIG_TUNE_EACH} is enabled.
87       */
88      public static final String CONFIG_TUNE_SINCE = "eval.tune.since";
89  
90  
91      /***
92       * If not <code>null</code>, the specified string is used to separate
93       * the training from the testing section of the corpus (e.g. "---") and
94       * the {@linkplain #getTrainSplit() train split} and
95       * {@linkplain #getTestSplit() test split} values are ignored.
96       */
97      private final String splitSeparator;
98  
99      /***
100      * The percentage of a corpus to use for training.
101      */
102     private final float trainSplit;
103 
104     /***
105      * The percentage of a corpus to use for testing; if <code>-1</code>,
106      * all remaining documents should be used.
107      */
108     private final float testSplit;
109 
110     /***
111      * The maximum number of iterations used for TUNE (train until no error)
112      * training; if 1, training is incremental. <em>Note that iterations should
113      * be indexed from 1 to X (this number) instead of from 0 to X-1 for
114      * compatibility with {@link #getTuneEvaluations()}.</em>
115      */
116     private final int tuneIterations;
117 
118     /***
119      * TUNE training is stopped if the training accuracy didn't improve for the
120      * specified number of iterations.
121      */
122     private final int tuneStop;
123 
124     /***
125      * Whether to measure results after each TUNE iteration or only at the
126      * end of training.
127      */
128     private final boolean tuneEach;
129 
130     /***
131      * The training iteration after which to evaluate results for the first
132      * time if {@link #tuneEach} is enabled.
133      */
134     private final int tuneSince;
135 
136     /***
137      * A set of iterations after which to evaluate TUNE training in addition to
138      * the last one; ignored if {@link #tuneEach} is <code>true</code>.
139      */
140     private final Set<Integer> tuneEvaluations = new HashSet<Integer>();
141 
142     /***
143      * Used to decide when to stop TUNE training.
144      */
145     private double[] lastAcc = null;
146 
147     /***
148      * Used to decide when to stop TUNE training.
149      */
150     private boolean allAreOptimal;
151 
152     /***
153      * Used to decide when to stop TUNE training.
154      */
155     private boolean noneGotBetter;
156 
157     /***
158      * Used to decide when to stop TUNE training.
159      */
160     private boolean someGotWorse;
161 
162     /***
163      * Used to decide when to stop TUNE training.
164      */
165     private int noneGotBetterCounter = 0;
166 
167 
168     /***
169      * Creates a new instance.
170      *
171      * @param config used to configure this instance
172      * @param suffix an optional suffix used to
173      * {@linkplain TiesConfiguration#adaptKey(String, String) adapt
174      * configuration keys}; might be <code>null</code>
175      */
176     public Tuner(final TiesConfiguration config, final String suffix) {
177         this(config.getFloat(config.adaptKey(CONFIG_TRAIN_SPLIT, suffix)),
178                 config.getFloat(config.adaptKey(CONFIG_TEST_SPLIT, suffix)),
179                 config.getString(config.adaptKey(
180                         CONFIG_SPLIT_SEPARATPR, suffix), null),
181                 config.getInt(config.adaptKey(CONFIG_TUNE, suffix)),
182                 config.getInt(config.adaptKey(CONFIG_TUNE_STOP, suffix)),
183                 config.getBoolean(config.adaptKey(CONFIG_TUNE_EACH, suffix)),
184                 config.getInt(config.adaptKey(CONFIG_TUNE_SINCE, suffix)),
185                 config.getList(config.adaptKey("eval.tune.list", suffix)));
186     }
187 
188     /***
189      * Creates a new instance.
190      *
191      * @param trainingSplit the percentage of a corpus to use for training
192      * @param testingSplit the percentage of a corpus to use for testing
193      * (evaluation); if <code>-1</code>, all remaining documents (1 -
194      * <code>trainingSplit</code>) are used
195      * @param splitSep if not <code>null</code>, the specified string is used
196      * to separate the training from the testing section of the corpus and
197      * the {@linkplain #getTrainSplit() train split} and
198      * {@linkplain #getTestSplit() test split} values are ignored.
199      * @param tuneRuns the maximum number of iterations used for TUNE
200      * (train until no error) training; if 1, training is incremental
201      * @param tuneStopAfter TUNE training is stopped if the training accuracy
202      * didn't improve for the specified number of iterations.
203      * @param measureEachTUNE whether to measure results after each TUNE
204      * iteration or only at the end of training
205      * @param startMeasureTUNE he training iteration after which to evaluate
206      * results for the first time if <code>measureEachTUNE</code> is enabled
207      * (ignored otherwise)
208      * @param tuneEvalList A list of Integers or int Strings specifying
209      * iterations after which to evaluate TUNE training in addition to the last
210      * one; ignored if <code>measureEachTUNE</code> is <code>true</code>
211      * @throws IllegalArgumentException if <code>trainingSplit</code> is not
212      * a percentage (larger than 1 or smaller than 0) or if
213      * <code>tuneRuns</code> is non-positive
214       */
215     public Tuner(final float trainingSplit, final float testingSplit,
216             final String splitSep, final int tuneRuns, final int tuneStopAfter,
217             final boolean measureEachTUNE, final int startMeasureTUNE,
218             final List tuneEvalList) throws IllegalArgumentException {
219         super();
220         // evaluate arguments prior to storing
221         if ((splitSep == null)
222                 && ((trainingSplit < 0.0) || (trainingSplit > 1.0))) {
223             throw new IllegalArgumentException(
224                 "Train split is not a percentage: " + trainingSplit);
225         }
226         if ((splitSep == null) && (testingSplit > 1.0)) {
227             throw new IllegalArgumentException(
228                 "Test split is not a percentage: " + testingSplit);
229         }
230         if (tuneRuns < 1) {
231             throw new IllegalArgumentException(
232                 "Number of TUNE runs must be at least 1: " + tuneRuns);
233         }
234         if (tuneStopAfter < 1) {
235             throw new IllegalArgumentException(
236                 "TUNE stopping criterium must be at least 1: " + tuneStopAfter);
237         }
238 
239         trainSplit = trainingSplit;
240         testSplit = testingSplit;
241         splitSeparator = splitSep;
242         tuneIterations = tuneRuns;
243         tuneStop = tuneStopAfter;
244         tuneEach = measureEachTUNE;
245         tuneSince = Math.max(startMeasureTUNE, 1); // must be 1 or higher
246 
247         if (!tuneEach) { // otherwise we don't need this
248             for (final Iterator iter = tuneEvalList.iterator();
249                     iter.hasNext();) {
250                 // store iterations after which to evaluate
251                 tuneEvaluations.add(Integer.valueOf(Util.asInt(iter.next())));
252             }
253         }
254 
255     }
256 
257 
258     /***
259      * Whether to continue TUNE training after finishing an iteration.
260      *
261      * @param currentAcc the list of accuracies for the just finished
262      * TUNE iteration
263      * @param i the number of the just finished TUNE iterations,
264      * <strong>counting starts with 1 not with 0</strong>
265      * @return whether to continue TUNE training
266      */
267     public boolean continueTraining(final double[] currentAcc, final int i) {
268         if (i >= getTuneIterations()) {
269             // TUNEd for maximum number of iterations
270             return false;
271         }
272 
273         boolean continueTraining = true;
274         allAreOptimal = true;
275         noneGotBetter = true;
276         someGotWorse = false;
277 
278         for (int j = 0; j < currentAcc.length; j++) {
279             // test whether all accuracies are optimal
280             allAreOptimal = allAreOptimal
281                 && (currentAcc[j] >= 1.0);
282 
283             // test whether all accuracies are better than the last ones
284             if (lastAcc != null) {
285                 noneGotBetter = noneGotBetter
286                     && (currentAcc[j] <= lastAcc[j]);
287                 someGotWorse = someGotWorse
288                     || (currentAcc[j] < lastAcc[j]);
289             } else {
290                 noneGotBetter = false;
291             }
292         }
293 
294         if (noneGotBetter) {
295             noneGotBetterCounter++;
296 
297             if (noneGotBetterCounter >= tuneStop) {
298                 // reached stopping criterium for TUNE training
299                 continueTraining = false;
300                 Util.LOG.debug("Stopping TUNE training after " + i
301                         + " iterations because current accuracies ("
302                         + ArrayUtils.toString(currentAcc)
303                         + ") aren't higher than last ones ("
304                         + ArrayUtils.toString(lastAcc)
305                         + ") for the " + tuneStop
306                         + ". time");
307             } else if (someGotWorse) {
308                 // stop TUNE training because accuracy degraded
309                 continueTraining = false;
310                 Util.LOG.debug("Stopping TUNE training after " + i
311                         + " iterations because current accuracies ("
312                         + ArrayUtils.toString(currentAcc)
313                         + ") are lower than last ones ("
314                         + ArrayUtils.toString(lastAcc) + ")");
315             }
316         }
317 
318         if (allAreOptimal) {
319             continueTraining = false;
320             Util.LOG.debug("Stopping TUNE training after " + i
321                     + " iterations because all accuracies are already "
322                     + "optimal: "
323                     + ArrayUtils.toString(currentAcc));
324         }
325 
326         lastAcc = currentAcc;
327         return continueTraining;
328     }
329 
330     /***
331      * Returns the percentage of a corpus to use for testing;
332      * if <code>-1</code>, all remaining documents should be used.
333      *
334      * @return the value of the attribute
335      */
336     public float getTestSplit() {
337         return testSplit;
338     }
339 
340     /***
341      * Returns the percentage of a corpus to use for training.
342      *
343      * @return the value of the attribute
344      */
345     public float getTrainSplit() {
346         return trainSplit;
347     }
348 
349     /***
350      * Returns the set of iterations after which to evaluate TUNE training in
351      * addition to the last one; should be ignored if {@link #isTuneEach()} is
352      * <code>true</code>.
353      *
354      * @return the value of the attribute
355      */
356     public Set<Integer> getTuneEvaluations() {
357         return tuneEvaluations;
358     }
359 
360     /***
361      * If not <code>null</code>, the returned string should be used to separate
362      * the training from the testing section of the corpus (e.g. "---") and
363      * the {@linkplain #getTrainSplit() train split} and
364      * {@linkplain #getTestSplit() test split} values should be ignored.
365      *
366      * @return the value of the attribute
367      */
368     public String getSplitSeparator() {
369         return splitSeparator;
370     }
371 
372     /***
373      * Returns the maximum number of iterations used for TUNE (train until no
374      * error) training; if 1, training is incremental.  <em>Note that iterations
375      * should be indexed from 1 to X (this number) instead of from 0 to X-1 for
376      * compatibility with {@link #getTuneEvaluations()}.</em>
377      *
378      * @return the value of the attribute
379      */
380     public int getTuneIterations() {
381         return tuneIterations;
382     }
383 
384     /***
385      * Returns the training iteration after which to evaluate results for the
386      * first time if {@link #isTuneEach()} is enabled.
387      *
388      * @return the value of the attribute
389      */
390     public int getTuneSince() {
391         return tuneSince;
392     }
393 
394     /***
395      * Returns the TUNE stopping criterion: TUNE training should be stopped if
396      * the training accuracy didn't improve for the specified number of
397      * iterations.
398      *
399      * @return the value of the attribute
400      */
401     public int getTuneStop() {
402         return tuneStop;
403     }
404 
405     /***
406      * Whether to measure results after each TUNE iteration or only at the
407      * end of training.
408      *
409      * @return the value of the attribute
410      */
411     public boolean isTuneEach() {
412         return tuneEach;
413     }
414 
415     /***
416      * Resets the state of this instance. This method must be called before
417      * starting to TUNE train a set of instances after finishing training
418      * another set.
419      */
420     public void reset() {
421         noneGotBetterCounter = 0;
422         lastAcc = null;
423     }
424 
425     /***
426      * Chooses files to use for training and files to use for evaluation,
427      * depending on the configured settings.
428      *
429      * @param allFiles the array of file names to process
430      * @param trainFiles populated with the files to use for training, will
431      * be populated with the first <em>{@link Tuner#getTrainSplit()} *
432      * allFiles.length</em> files; must initially be empty
433      * @param evalFiles populated with the files to use for evaluation, will
434      * be populated from the next <em>{@link Tuner#getTestSplit()} *
435      * allFiles.length</em> remaining files (or all remaining files if test
436      * split is negative); must initially be empty
437      * @throws IllegalArgumentException if the lists aren't empty
438      */
439     public void selectFiles(final String[] allFiles,
440             final List<String> trainFiles, final List<String> evalFiles)
441             throws IllegalArgumentException {
442         // check arguments
443         if (!trainFiles.isEmpty() || !evalFiles.isEmpty()) {
444             throw new IllegalArgumentException(
445                 "Lists of train files and eval files must initially be empty");
446         }
447 
448         final int numTrainFiles = Math.round(getTrainSplit() * allFiles.length);
449         final int filesToUse;
450 
451         if (splitSeparator != null) {
452             // use split separator
453             int i = 0;
454 
455             for (; (i < allFiles.length)
456                     && !splitSeparator.equals(allFiles[i]); i++) {
457                 // add files to train set until split separator or end of array
458                 trainFiles.add(allFiles[i]);
459             }
460 
461             // skip split separator
462             i++;
463 
464             for (; (i < allFiles.length)
465                     && !splitSeparator.equals(allFiles[i]); i++) {
466                 // add files to test set until split separator or end of array
467                 evalFiles.add(allFiles[i]);
468             }
469         } else {
470             // use train + test split
471             if (getTestSplit() < 0) {
472                 // use all remaining files for evaluation
473                 filesToUse = allFiles.length;
474             } else {
475                 filesToUse = Math.min(allFiles.length, Math.round(
476                         (getTrainSplit() + getTestSplit()) * allFiles.length));
477             }
478 
479             // add first numTrainFiles to trainFiles, rest to evalFiles
480             for (int i = 0; i < filesToUse; i++) {
481                 if (i < numTrainFiles) {
482                     // add to training files
483                     trainFiles.add(allFiles[i]);
484                 } else {
485                     // add to evaluation files
486                     evalFiles.add(allFiles[i]);
487                 }
488             }
489         }
490 
491         final int usedFiles = trainFiles.size() + evalFiles.size();
492         Util.LOG.debug("Using " + usedFiles + " of " + allFiles.length
493             + " files: " + trainFiles.size() + " for training, "
494             + evalFiles.size() + " for evaluation");
495     }
496 
497     /***
498      * Whether to evaluate results after this TUNE iteration.
499      *
500      * @param continueTraining the result returned by the preceding call to
501      * {@link #continueTraining(double[], int)}
502      * @param i the number of the just finished TUNE iterations,
503      * <strong>counting starts with 1 not with 0</strong>
504      * @return whether to evaluate
505      */
506     public boolean shouldEvaluate(final boolean continueTraining, final int i) {
507         return (isTuneEach() && (i >= getTuneSince()))
508                 || !continueTraining
509                 || (i == getTuneIterations())
510                 || getTuneEvaluations().contains(Integer.valueOf(i));
511     }
512 
513     /***
514      * Returns a string representation of this object.
515      *
516      * @return a textual representation
517      */
518     public String toString() {
519         final ToStringBuilder builder = new ToStringBuilder(this)
520             .appendSuper(super.toString());
521 
522         if (splitSeparator != null) {
523             builder.append("split separator", splitSeparator);
524         } else {
525             builder.append("train split", trainSplit)
526                 .append("test split", testSplit);
527         }
528 
529         builder.append("tune iterations", tuneIterations)
530             .append("tune stops after", tuneStop)
531             .append("measure after each iteration", tuneEach)
532             .append("starting from", tuneSince);
533         return builder.toString();
534     }
535 
536 }