View Javadoc

1   /*
2    * Copyright (C) 2004 Christian Siefkes <christian@siefkes.net>.
3    * Development of this software is supported by the German Research Society,
4    * Berlin-Brandenburg Graduate School in Distributed Information Systems
5    * (DFG grant no. GRK 316).
6    *
7    * This library is free software; you can redistribute it and/or
8    * modify it under the terms of the GNU Lesser General Public
9    * License as published by the Free Software Foundation; either
10   * version 2.1 of the License, or (at your option) any later version.
11   *
12   * This library is distributed in the hope that it will be useful,
13   * but WITHOUT ANY WARRANTY; without even the implied warranty of
14   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15   * Lesser General Public License for more details.
16   *
17   * You should have received a copy of the GNU Lesser General Public
18   * License along with this library; if not, visit
19   * http://www.gnu.org/licenses/lgpl.html or write to the Free Software
20   * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
21   */
22  package de.fu_berlin.ties.extract;
23  
24  import java.io.File;
25  import java.io.IOException;
26  import java.io.Reader;
27  import java.io.Writer;
28  import java.util.ArrayList;
29  import java.util.HashSet;
30  import java.util.Iterator;
31  import java.util.LinkedList;
32  import java.util.List;
33  import java.util.NoSuchElementException;
34  import java.util.Set;
35  import java.util.SortedMap;
36  import java.util.TreeMap;
37  
38  import org.apache.commons.lang.ArrayUtils;
39  import org.apache.commons.lang.builder.ToStringBuilder;
40  import org.apache.commons.math.stat.univariate.moment.StandardDeviation;
41  import org.dom4j.Document;
42  import org.dom4j.DocumentException;
43  
44  import de.fu_berlin.ties.Closeable;
45  import de.fu_berlin.ties.ContextMap;
46  import de.fu_berlin.ties.ParsingException;
47  import de.fu_berlin.ties.ProcessingException;
48  import de.fu_berlin.ties.TextProcessor;
49  import de.fu_berlin.ties.TiesConfiguration;
50  import de.fu_berlin.ties.classify.Prediction;
51  import de.fu_berlin.ties.eval.AccuracyView;
52  import de.fu_berlin.ties.eval.FMetrics;
53  import de.fu_berlin.ties.eval.FMetricsView;
54  import de.fu_berlin.ties.eval.FeatureCountView;
55  import de.fu_berlin.ties.eval.MultiFMetrics;
56  import de.fu_berlin.ties.eval.MultiFMetricsView;
57  import de.fu_berlin.ties.io.FieldContainer;
58  import de.fu_berlin.ties.io.FieldMap;
59  import de.fu_berlin.ties.io.IOUtils;
60  import de.fu_berlin.ties.util.Util;
61  import de.fu_berlin.ties.xml.dom.DOMUtils;
62  
63  /***
 * Trains an extractor and evaluates extraction quality. Processes shuffle
 * files (as generated by {@link de.fu_berlin.ties.eval.ShuffleGenerator}),
 * which contain the files to use for training and evaluation.
67   * For each of these files, a corresponding answer key (*.ans) must exist.
68   *
69   * <p>Instances of this class are not thread-safe.
70   *
71   * @author Christian Siefkes
72   * @version $Revision: 1.67 $, $Date: 2004/11/17 09:17:04 $, $Author: siefkes $
73   */
74  public class TrainEval extends TextProcessor implements Closeable {
75  
76      /***
77       * An inner class wrapping the results of a
78       * {@linkplain TrainEval#trainAndEval(String[], File, File, String, Writer)
79       * training + evaluation run}.
80       */
81      public static class Results {
82  
83          /***
84           * The map from TUNE iterations to
85           * {@link EvaluatedExtractionContainer}s containing the evaluated
86           * extractions (after the corresponding TUNE iteration).
87           */
88          private final SortedMap<Integer, EvaluatedExtractionContainer> evaluated
89              = new TreeMap<Integer, EvaluatedExtractionContainer>();
90  
91          /***
92           * A read-only view on the feature count statistics collected during
93           * training.
94           */
95          private FeatureCountView trainFeatureCV;
96  
97          /***
98           * A read-only view on the feature count statistics collected during
99           * extraction.
100          */
101         private FeatureCountView extractFeatureCV;
102 
103         /***
104          * Creates a new instance.
105          */
106         public Results() {
107         }
108 
109         /***
110          * Adds an evaluated extraction container.
111          *
112          * @param iteration the number of the TUNE iteration to add
113          * @param evaluatedExtractions the container to add
114          * @throws IllegalArgumentException if there already is a container
115          * stored for <code>iteration</code>
116          */
117         private void addEvaluated(final int iteration,
118                 final EvaluatedExtractionContainer evaluatedExtractions) 
119         throws IllegalArgumentException {
120             final EvaluatedExtractionContainer oldValue =
121                 evaluated.put(Integer.valueOf(iteration), evaluatedExtractions);
122 
123             if (oldValue != null) {
124                 throw new IllegalArgumentException("Cannot store two "
125                         + "extraction containers for same iteration ("
126                         + iteration + "): " + oldValue + "; "
127                         + evaluatedExtractions);
128             }
129         }
130 
131         /***
132          * Returns one of the stored evaluated extraction containers.
133          *
134          * @param iteration the number of the TUNE iteration to return;
135          * must be contained in {@link #iterations()}, otherwise
136          * <code>null</code> will be returned
137          * @return the specified container
138          */
139         public EvaluatedExtractionContainer getEvaluated(
140                 final Integer iteration) {
141             return evaluated.get(iteration);
142         }
143 
144         /***
145          * Returns one of the stored evaluated extraction containers.
146          *
147          * @param iteration the number index of the TUNE iterations to return;
148          * must be contained in {@link #iterations()}, otherwise
149          * <code>null</code> will be returned
150          * @return the specified container
151          */
152         public EvaluatedExtractionContainer getEvaluated(final int iteration) {
153             return getEvaluated(Integer.valueOf(iteration));
154         }
155 
156         /***
157          * Returns a read-only view on the feature count statistics collected
158          * during training.
159          * @return the stored attribute
160          */
161         public FeatureCountView getExtractFeatureCV() {
162             return extractFeatureCV;
163         }
164 
165         /***
166          * Returns a read-only view on the feature count statistics collected
167          * during extraction.
168          * @return the stored attribute
169          */
170         public FeatureCountView getTrainFeatureCV() {
171             return trainFeatureCV;
172         }
173 
174         /***
175          * Returns an iterator over the TUNE iterations for which extraction
176          * results are stored.
177          *
178          * @return an iterator over the TUNE iterations
179          */
180         public Iterator<Integer> iterations() {
181             return evaluated.keySet().iterator();
182         }
183 
184         /***
185          * Stores a read-only view on the feature count statistics collected
186          * during training.
187          *
188          * @param extractFeatures a read-only view on the feature count
189          * statistics collected during extraction
190          */
191         private void setExtractFeatureCV(
192                 final FeatureCountView extractFeatures) {
193             extractFeatureCV = extractFeatures;
194         }
195 
196         /***
197          * Stores a read-only view on the feature count statistics collected
198          * during extraction.
199          *
200          * @param trainFeatures a read-only view on the feature count statistics
201          * collected during training
202          */
203         private void setTrainFeatureCV(
204                 final FeatureCountView trainFeatures) {
205             trainFeatureCV = trainFeatures;
206         }
207 
208         /***
209          * Returns a string representation of this object.
210          *
211          * @return a textual representation
212          */
213         public String toString() {
214             return new ToStringBuilder(this)
215                 .append("training feature count", trainFeatureCV)
216                 .append("extraction feature count", extractFeatureCV)
217                 .append("size", evaluated.size())
218                 .toString();
219         }
220 
221     }
222 
223     /***
224      * Configuration key: The percentage of a corpus to use for training.
225      */
226     public static final String CONFIG_TRAIN_SPLIT = "eval.train-split";
227 
228     /***
229      * Configuration key: The percentage of a corpus to use for testing
230      * (evaluation).
231      */
232     public static final String CONFIG_TEST_SPLIT = "eval.test-split";
233 
234     /***
235      * Configuration key: If <code>true</code>, a fully incremental setup is
236      * used where the trainer is trained on each document after the extractor
237      * processed it.
238      */
239     public static final String CONFIG_FEEDBACK = "eval.feedback";
240 
241     /***
242      * Configuration key: The maximum number of iterations used for TUNE
243      * (train until no error) training; if 1, training is incremental.
244      */
245     public static final String CONFIG_TUNE = "train.tune";
246 
247     /***
248      * Configuration key: TUNE training is stopped if the training accuracy
249      * didn't improve for the specified number of iterations.
250      */
251     public static final String CONFIG_TUNE_STOP = "train.tune.stop";
252 
253     /***
254      * Configuration key: The maximum number of iterations used for TUNE
255      * training the sentence classifier; if 0 or negative, the value of
256      * {@link #CONFIG_TUNE} is used.
257      */
258     public static final String CONFIG_SENTENCE_TUNE = "sent.tune";
259 
260     /***
261      * Configuration key: Whether to measure results after each TUNE iteration
262      * or only at the end of training.
263      */
264     public static final String CONFIG_TUNE_EACH = "eval.tune.each";
265 
266     /***
267      * Configuration key: The training iteration after which to evaluate
268      * results for the first time if {@link #CONFIG_TUNE_EACH} is enabled.
269      */
270     public static final String CONFIG_TUNE_SINCE = "eval.tune.since";
271 
272     /***
273      * Serialization key for the number of the iteration (when TUNE training).
274      */
275     public static final String KEY_ITERATION = "Iteration";
276 
277     /***
278      * Serialization key for the number of the run.
279      */
280     public static final String KEY_RUN = "Run";
281 
282     /***
283      * Serialization key for the type (either "Train" or "Eval").
284      */
285     public static final String KEY_TYPE = "Type";
286 
287     /***
288      * Serialization value for the "Train" type.
289      */
290     public static final String TYPE_TRAIN = "Train";
291 
292     /***
293      * Serialization value for the "Eval" type.
294      */
295     public static final String TYPE_EVAL = "Eval";
296 
297     /***
298      * Results are measured after each 10th evaluated document if
299      * {@link #feedback} is given.
300      */
301     private static final int MEASURE_FEEDBACK = 10;
302 
303     /***
304      * The percentage of a corpus to use for training.
305      */
306     private final float trainSplit;
307 
308     /***
309      * The percentage of a corpus to use for testign; if <code>-1</code>,
310      * all remaining documents are used.
311      */
312     private final float testSplit;
313 
314     /***
315      * If <code>true</code>, a fully incremental setup is used where the
316      * trainer is trained on each document after the extractor processed it.
317      */
318     private final boolean feedback;
319 
320     /***
321      * The maximum number of iterations used for TUNE (train until no error)
322      * training; if 1, training is incremental.
323      */
324     private final int tuneIterations;
325 
326     /***
327      * TUNE training is stopped if the training accuracy didn't improve for the
328      * specified number of iterations.
329      */
330     private final int tuneStop;
331 
332     /***
333      * The maximum number of iterations used for TUNE training the sentence
334      * classifier; if 0 or negative, the value of {@link #tuneIterations} is
335      * used.
336      */
337     private final int sentenceIterations;
338 
339     /***
340      * Whether to measure results after each TUNE iteration or only at the
341      * end of training.
342      */
343     private final boolean tuneEach;
344 
345     /***
346      * The training iteration after which to evaluate results for the first
347      * time if {@link #tuneEach} is enabled.
348      */
349     private final int tuneSince;
350 
351     /***
352      * A set of iterations after which to evaluate TUNE training in addition to
353      * the last one; ignored if {@link #tuneEach} is <code>true</code>.
354      */
355     private final Set<Integer> tuneEvaluations = new HashSet<Integer>();
356 
357     /***
358      * The last TUNE iteration that was actually used for training in any
359      * batch of processed files.
360      */
361     private int lastUsedTuneIteration = -1;
362 
363     /***
364      * An array of multi-metrics calculating summaries esp. standard deviations
365      * over all runs. Each array element corresponds to one TUNE operation.
366      */
367     private final MultiFMetrics[] averages;
368 
369     /***
370      * An array of multi-metrics incrementally calculating summaries after
371      * evaluating each 10th document if {@link #feedback} is given.
372      */
373     private final List<MultiFMetrics> feedbackAverages;
374 
375     /***
376      * Calculate feature count statistics over all runs.
377      */
378     private final FieldContainer featureCounts =
379         FieldContainer.createFieldContainer();
380 
381     /***
382      * Stores precision and recall values for sentence filtering, if used.
383      */
384     private FieldContainer sentenceMetricsStore;
385 
386     /***
387      * The number of the current run.
388      */
389     private int runNo = 0;
390 
391     /***
392      * Creates a new instance, using a default extension and the
393      * {@linkplain TiesConfiguration#CONF standard configuration}.
394      *
395      * @throws IllegalArgumentException if the configured values are outside the
396      * allowed ranges
397      * @throws ClassCastException if the configured numeric values cannot be
398      * parsed
399      * @throws NoSuchElementException if one of the required values is
400      * missing from the configuration
401      */
402     public TrainEval() throws IllegalArgumentException, ClassCastException,
403     NoSuchElementException {
404         this("metrics");
405     }
406 
407     /***
408      * Creates a new instance, using the {@linkplain TiesConfiguration#CONF
409      * standard configuration}.
410      *
411      * @param outExt the extension to use for output files
412      * @throws IllegalArgumentException if the configured values are outside the
413      * allowed ranges
414      * @throws ClassCastException if the configured numeric values cannot be
415      * parsed
416      * @throws NoSuchElementException if one of the required values is
417      * missing from the configuration
418      */
419     public TrainEval(final String outExt) throws IllegalArgumentException,
420             ClassCastException, NoSuchElementException {
421         this(outExt, TiesConfiguration.CONF);
422     }
423 
424     /***
425      * Creates a new instance.
426      *
427      * @param outExt the extension to use for output files
428      * @param config used to configure this instance
429      * @throws IllegalArgumentException if the configured values are outside the
430      * allowed ranges
431      * @throws ClassCastException if the configured numeric values cannot be
432      * parsed
433      * @throws NoSuchElementException if one of the required values is
434      * missing from the configuration
435      */
436     public TrainEval(final String outExt, final TiesConfiguration config)
437             throws IllegalArgumentException, ClassCastException,
438             NoSuchElementException {
439         // file names are matched case-sensitive
440         this(outExt, config.getFloat(CONFIG_TRAIN_SPLIT),
441                 config.getFloat(CONFIG_TEST_SPLIT),
442                 config.getInt(CONFIG_TUNE),
443                 config.getInt(CONFIG_TUNE_STOP),
444                 config.getBoolean(CONFIG_TUNE_EACH),
445                 config.getInt(CONFIG_TUNE_SINCE),
446                 config.getList("eval.tune.list"),
447                 config.getInt(CONFIG_SENTENCE_TUNE),
448                 config.getBoolean(CONFIG_FEEDBACK), config);
449     }
450 
451     /***
452      * Creates a new instance.
453      *
454      * @param outExt the extension to use for output files
455      * @param trainingSplit the percentage of a corpus to use for training
456      * @param testingSplit the percentage of a corpus to use for testing
457      * (evaluation); if <code>-1</code>, all remaining documents (1 -
458      * <code>trainingSplit</code>) are used
459      * @param tuneRuns the maximum number of iterations used for TUNE
460      * (train until no error) training; if 1, training is incremental
461      * @param tuneStopAfter TUNE training is stopped if the training accuracy
462      * didn't improve for the specified number of iterations.
463      * @param measureEachTUNE whether to measure results after each TUNE
464      * iteration or only at the end of training
465      * @param startMeasureTUNE he training iteration after which to evaluate
466      * results for the first time if <code>measureEachTUNE</code> is enabled
467      * (ignored otherwise)
468      * @param sentenceTUNE the maximum number of iterations used for TUNE
469      * training the sentence classifier (if used); if 0 or negative, the value
470      * of <code>tuneRuns</code> is used
471      * @param tuneEvalList A list of Integers or int Strings specifying
472      * iterations after which to evaluate TUNE training in addition to the last
473      * one; ignored if <code>measureEachTUNE</code> is <code>true</code>
474      * @param giveFeedback if <code>true</code>, a fully incremental setup is
475      * used where the trainer is trained on each document after the extractor
476      * processed it; it's not allowed to set this both this and
477      * <code>measureEachTUNE</code> to <code>true</code> when training for
478      * several <code>tuneRuns</code> because that would mean to evaluate on the
479      * training set
480      * @param config used to configure superclasses, trainer, and extractor;
481      * if <code>null</code>, the {@linkplain TiesConfiguration#CONF standard
482      * configuration} is used
483      * @throws IllegalArgumentException if <code>trainingSplit</code> is not
484      * a percentage (larger than 1 or smaller than 0) or if
485      * <code>tuneRuns</code> is non-positive
486      */
487     public TrainEval(final String outExt, final float trainingSplit,
488             final float testingSplit, final int tuneRuns,
489             final int tuneStopAfter, final boolean measureEachTUNE,
490             final int startMeasureTUNE, final List tuneEvalList,
491             final int sentenceTUNE, final boolean giveFeedback,
492             final TiesConfiguration config)
493     throws IllegalArgumentException {
494         super(outExt, config);
495 
496         // evaluate arguments prior to storing
497         if ((trainingSplit < 0.0) || (trainingSplit > 1.0)) {
498             throw new IllegalArgumentException(
499                 "Train split is not a percentage: " + trainingSplit);
500         }
501         if (testingSplit > 1.0) {
502             throw new IllegalArgumentException(
503                 "Test split is not a percentage: " + testingSplit);
504         }
505         if (tuneRuns < 1) {
506             throw new IllegalArgumentException(
507                 "Number of TUNE runs must be at least 1: " + tuneRuns);
508         }
509         if (tuneStopAfter < 1) {
510             throw new IllegalArgumentException(
511                 "TUNE stopping criterium must be at least 1: " + tuneStopAfter);
512         }
513 
514         if (giveFeedback && (tuneRuns > 1) && measureEachTUNE) {
515             throw new IllegalArgumentException("It's not allowed give feedback "
516                     + "when evaluating after each of several iteration because "
517                     + "that would mean to evaluate on the training set");
518         }
519 
520         trainSplit = trainingSplit;
521         testSplit = testingSplit;
522         tuneIterations = tuneRuns;
523         tuneStop = tuneStopAfter;
524         tuneEach = measureEachTUNE;
525         tuneSince = Math.max(startMeasureTUNE, 1); // must be 1 or higher
526         sentenceIterations = sentenceTUNE;
527         feedback = giveFeedback;
528         averages = new MultiFMetrics[tuneIterations];
529 
530         if (!tuneEach) { // otherwise we don't need this
531             for (final Iterator iter = tuneEvalList.iterator(); iter.hasNext();) {
532                 // store iterations after which to evaluate
533                 tuneEvaluations.add(Integer.valueOf(Util.asInt(iter.next())));
534             }            
535         }
536 
537         // init metrics to store incremental F1 values if feedback is given
538         if (feedback) {
539             feedbackAverages = new ArrayList<MultiFMetrics>();
540         } else {
541             feedbackAverages = null;
542         }
543 
544         // Dummy serialization to preload statistical implementations.
545         // This avoid instantiation error in close() method if the JAR file
546         // has changed in the meanwhile (which might happen if a nightly build
547         // takes place during a long test run).
548         final FieldContainer dummy = new FieldContainer();
549         final MultiFMetrics dummyMetrics = new MultiFMetrics(true);
550         dummyMetrics.storeEntries(dummy);
551         StandardDeviation dummyDev = new StandardDeviation();
552         dummyDev.clear();
553     }
554 
555     /***
556      * Check whether to disable sentnce training; called after each training
557      * iteration.
558      * 
559      * @param trainer the used trainer
560      * @param sentenceMetrics metrics calculated for sentence filtering, if
561      * used (<code>null</code> otherwise)
562      * @param iteration the current iteration
563      */
564     private void checkSentenceTraining(final Trainer trainer,
565             final FMetrics sentenceMetrics, final int iteration) {
566         if ((sentenceIterations > 0) && (sentenceIterations == iteration)) {
567             trainer.disableSentenceTraining();
568             Util.LOG.info("Disabled sentence training after " + iteration
569                     + " iterations");
570         }
571 
572         if (sentenceMetrics != null) {
573             updateSentenceMetricsStore(sentenceMetrics, true, iteration);
574         }
575     }
576 
577     /***
578      * {@inheritDoc}
579      */
580     public void close(final int errorCount)
581             throws IOException, ProcessingException {
582         // don't do anything if there were errors
583         if (errorCount <= 0) {
584             final File outDir = IOUtils.determineOutputDirectory(getConfig());
585 
586             // store feature counts + average metrics incl. standard deviations
587             storeValues(featureCounts, outDir, "FeatureCounts", null);
588 
589             final FieldContainer averagesStore =
590                 FieldContainer.createFieldContainer();
591             boolean storeIterations = (averages.length > 1);
592 
593             // don't store results for more TUNE iterations than actually used
594             final int lastToStore = Math.min(averages.length,
595                     lastUsedTuneIteration);
596 
597             for (int i = 1; i <= lastToStore; i++) {
598                 // Some might be null because of eval.tune.since or ...list.
599                 // If tuneEach is disabled, we only store iterations from the
600                 // tuneEvaluations list + the very last one.
601                 if (averages[i-1] != null && (tuneEach
602                         || tuneEvaluations.contains(Integer.valueOf(i))
603                         || i == lastToStore)) {
604                     // add iteration number if we TUNEd
605                     if (storeIterations) {
606                         averagesStore.backgroundMap().put(KEY_ITERATION,
607                                 Integer.valueOf(i));
608                     }
609 
610                     averages[i-1].storeEntries(averagesStore);
611                 }
612             }
613             storeValues(averagesStore, outDir, "All",
614                 MultiFMetrics.EXT_METRICS);
615 
616             // store incremental F1 if feedback is enabled
617             if (feedbackAverages != null) {
618                 final FieldContainer feedbackAveragesStore =
619                     FieldContainer.createFieldContainer();
620 
621                 for (int i = 0; i < feedbackAverages.size(); i++) {
622                     feedbackAverages.get(i).storeEntries(feedbackAveragesStore);
623                 }                
624                 storeValues(feedbackAveragesStore, outDir, "Feedback",
625                         MultiFMetrics.EXT_METRICS);
626             }
627         }
628     }
629 
630     /***
631      * {@inheritDoc}
632      */
633     protected void doProcess(final Reader reader, final Writer writer,
634             final ContextMap context) throws IOException, ProcessingException {
635         // increase run number + read list of files
636         runNo++;
637         final String[] files = IOUtils.readURIList(reader);
638 
639         if (files.length == 0) {
640             Util.LOG.info("No files to process");
641         } else {
642             final String baseName = IOUtils.getBaseName(
643                 new File((String) context.get(KEY_LOCAL_NAME)));
644             final File inDir = (File) context.get(KEY_DIRECTORY);
645             final File outDir = IOUtils.determineOutputDirectory(getConfig());
646             Results currentResults;
647             FieldMap currentFMap;
648 
649             MultiFMetricsView currentMetrics = null;
650 
651             //Util.LOG.info("Starting evaluation run " + runNo);
652             currentResults =
653                 trainAndEval(files, inDir, outDir, baseName, writer);
654 
655             // store feature counts for training + evaluation
656             final Integer runNumber = new Integer(runNo);
657             currentFMap = currentResults.getTrainFeatureCV().storeFields();
658             currentFMap.put(KEY_RUN, runNumber);
659             currentFMap.put(MultiFMetrics.KEY_TYPE, "training");
660             featureCounts.add(currentFMap);
661 
662             currentFMap =
663                 currentResults.getExtractFeatureCV().storeFields();
664             currentFMap.put(KEY_RUN, runNumber);
665             currentFMap.put(MultiFMetrics.KEY_TYPE, "extraction");
666             featureCounts.add(currentFMap);
667 
668             // update average metrics
669             final Iterator<Integer> iterIter = currentResults.iterations();
670             Integer iteration;
671             int it = 0;
672 
673             while (iterIter.hasNext()) {
674                 iteration = iterIter.next();
675                 it = iteration.intValue();
676                 currentMetrics =
677                     currentResults.getEvaluated(iteration).viewMetrics();
678                 if (averages[it-1] == null) {
679                     averages[it-1] = new MultiFMetrics(true);
680                 }
681                 averages[it-1].update(currentMetrics);
682             }
683 
684             // if there are additional averages it means that we stopped TUNEing
685             // early so the last result is valid for all later iterations too
686             for (it++; it <= averages.length; it++) {
687                 if (averages[it-1] == null) {
688                     averages[it-1] = new MultiFMetrics(true);
689                 }
690 
691                 // this will copy the last existing result which is what we want
692                 averages[it-1].update(currentMetrics);
693             }
694         }
695     }
696 
697     /***
698      * Returns the percentage of a corpus to use for testing (evaluation).
699      *
700      * @return the percentage to use for evaluation; if negative, all
701      * remaining documents (1 - {@link #getTrainSplit()}) are used for
702      * evaluation
703      */
704     public float getTestSplit() {
705         return testSplit;
706     }
707 
708     /***
709      * Returns the percentage of a corpus to use for training; the remaining
710      * documents (1-x) are used for evaluation.
711      *
712      * @return the percentage to use for training
713      */
714     public float getTrainSplit() {
715         return trainSplit;
716     }
717 
718     /***
719      * Creates and initializes a extractor to use for an evaluation run,
720      * re-using the components of the provided trainer. Subclasses can
721      * overwrite this method to provide a different extractor.
722      *
723      * @param trainer trainer whose components should be re-used
724      * @return the created extractor
725      */
726     protected Extractor initExtractor(final Trainer trainer) {
727         // we don't need an output extension
728         return new Extractor(null, trainer);
729     }
730 
731     /***
732      * Creates and initializes a trainer to use for an evaluation run,
733      * configured from the
734      * {@link de.fu_berlin.ties.ConfigurableProcessor#getConfig() stored
735      * configuration}. Subclasses can overwrite this method to provide a
736      * different trainer.
737      *
738      * @param runDirectory directory used to run the classifier
739      * @return the created trainer
740      * @throws ProcessingException if an error occurs during initialization
741      */
742     protected Trainer initTrainer(final File runDirectory)
743             throws ProcessingException {
744         // we don't need an output extension
745         final Trainer trainer = new Trainer(null, runDirectory, getConfig());
746         // ensure the prediction model is empty
747         trainer.reset();
748         return trainer;
749     }
750 
751     /***
752      * Returns a string representation of this object.
753      *
754      * @return a textual representation
755      */
756     public String toString() {
757         return new ToStringBuilder(this)
758             .appendSuper(super.toString())
759             .append("train split", trainSplit)
760             .append("test split", testSplit)
761             .append("tune iterations", tuneIterations)
762             .append("tune stops after", tuneStop)
763             .append("measure after each iteration", tuneEach)
764             .append("starting from", tuneSince)
765             .append("sentence iterations", sentenceIterations)
766             .append("feedback", feedback)
767             .toString();
768     }
769 
770     /***
771      * Chooses files to use for training and files to use for evaluation,
772      * depending on the configured settings.
773      *
774      * @param allFiles the array of file names to process
775      * @param trainFiles populated with the files to use for training, will
776      * be populated with the first <em>{@link #getTrainSplit()} *
777      * allFiles.length</em> files; must initially be empty
778      * @param evalFiles populated with the files to use for evaluation, will
779      * be populated from the next <em>{@link #getTestSplit()} *
780      * allFiles.length</em> remaining files (or all remaining files if test
781      * split is negative); must initially be empty
782      * @throws IllegalArgumentException if the lists aren't empty
783      */
784     private void selectFiles(final String[] allFiles,
785             final List<String> trainFiles, final List<String> evalFiles)
786             throws IllegalArgumentException {
787         // check arguments
788         if (!trainFiles.isEmpty() || !evalFiles.isEmpty()) {
789             throw new IllegalArgumentException(
790                 "Lists of train files and eval files must initially be empty");
791         }
792 
793         final int numTrainFiles = Math.round(trainSplit * allFiles.length);
794         final int filesToUse;
795 
796         if (testSplit < 0) {
797             // use all remaining files for evaluation
798             filesToUse = allFiles.length;
799         } else {
800             filesToUse = Math.min(allFiles.length,
801                 Math.round((trainSplit + testSplit) * allFiles.length));
802         }
803 
804         // add first numTrainFiles to trainFiles, rest to evalFiles
805         for (int i = 0; i < filesToUse; i++) {
806             if (i < numTrainFiles) {
807                 // add to training files
808                 trainFiles.add(allFiles[i]);
809             } else {
810                 // add to evaluation files
811                 evalFiles.add(allFiles[i]);
812             }
813         }
814 
815         Util.LOG.debug("Using " + filesToUse + " of " + allFiles.length
816             + " files: " + trainFiles.size() + " for training, "
817             + evalFiles.size() + " for evaluation");
818     }
819 
820     /***
821      * Helper method for serializing the calculated metrics.
822      *
823      * @param outDirectory directory used to do this run and store the results
824      * @param baseName the base name of the files to use for storing
825      * all extractions and training statistics
826      * @param iteration the number of the current iteration
827      * @param evaluated the container of evaluated extractions to tore
828      * @param sentenceMetrics metrics calculated for sentence filtering, if
829      * used (<code>null</code> otherwise)
830      * @throws IOException if an I/O error occurs while serializing
831      */
832     private void serializeExtractions(final File outDirectory,
833             final String baseName, final int iteration,
834             final EvaluatedExtractionContainer evaluated,
835             final FMetrics sentenceMetrics)
836     throws IOException {
837         // serialize + store evaluated extractions
838         final FieldContainer resultStorage =
839             FieldContainer.createFieldContainer();
840         evaluated.storeEntries(resultStorage);
841         final File evaluatedFile = storeValues(resultStorage,
842             outDirectory, baseName, Extractor.EXT_EXTRACTIONS);
843         Util.LOG.info("Stored results of training + evaluation run in "
844             + evaluatedFile);
845 
846         // update metrics of sentence filtering, if used
847         if (sentenceMetrics != null) {
848             updateSentenceMetricsStore(sentenceMetrics, false, iteration);
849         }
850     }
851 
852     /***
853      * Helper method for serializing the calculated metrics.
854      *
855      * @param outDirectory directory used to do this run and store the results
856      * @param baseName the base name of the files to use for storing
857      * all extractions and training statistics
858      * @param writer used to serialize the calculated metrics
859      * @param results contains the evaluated extractions
860      * @throws IOException if an I/O error occurs while serializing the metrics
861      */
862     private void serializeMetrics(final File outDirectory,
863             final String baseName, final Writer writer,
864             final Results results)
865     throws IOException {
866         // serialize metrics to output file
867         final FieldContainer metricsStore =
868             FieldContainer.createFieldContainer();
869         Integer iteration;
870         EvaluatedExtractionContainer evaluated;
871         final boolean storeIterations = (tuneIterations > 1);
872         Iterator<Integer> tuneIter = results.iterations();
873 
874         while (tuneIter.hasNext()) {
875             iteration = tuneIter.next();
876             evaluated = results.getEvaluated(iteration);
877 
878             // add iteration number, if there are several ones to store
879             if (storeIterations) {
880                 metricsStore.backgroundMap().put(KEY_ITERATION, iteration);
881             }
882 
883             evaluated.viewMetrics().storeEntries(metricsStore);
884         }
885 
886         metricsStore.store(writer);
887         writer.flush();
888 
889         // serialize sentence metrics
890         if (sentenceMetricsStore != null) {
891             storeValues(sentenceMetricsStore, outDirectory,
892                     baseName, "sent");
893         }
894     }
895 
896     /***
897      * Helper method for serializing the training accuracies.
898      *
899      * @param outDirectory directory used to do this run and store the results
900      * @param baseName the base name of the files to use for storing
901      * all extractions and training statistics
902      * @param accContainers a an of containers used to store the accuracy of
903      * each classifier, might be <code>null</code> if accuracies aren't measured
904      * @param globalAcc an array of accuracies, one for each classifier
905      * @throws IOException if an I/O error occurs while serializing
906      */
907     private void serializeTrainingMetrics(final File outDirectory,
908             final String baseName, final FieldContainer[] accContainers,
909             final AccuracyView[] globalAcc)
910     throws IOException {
911         if (accContainers != null) {
912             String name;
913 
914             for (int j = 0; j < globalAcc.length; j++) {
915                 if (globalAcc.length > 1) {
916                     // append prefix letter (a, b, c etc.)
917                     name = baseName + (char) ('a' + j);
918                 } else {
919                     // sufficient to use base name
920                     name = baseName;
921                 }
922                 storeValues(accContainers[j], outDirectory, name, "train");
923             }
924         }
925     }
926 
927     /***
928      * Stores the contents of a storable container.
929      *
930      * @param container the container to store
931      * @param directory the directory in which to store the data
932      * @param baseName the base name of the file using for storing the data
933      * @param extension the file extension to use; if <code>null</code>,
934      * the {@linkplain FieldContainer#recommendedExtension() recommended
935      * extension} is used
936      * @return the file used for storing the data
937      * @throws IOException if an I/O error occurs
938      */
939     private File storeValues(final FieldContainer container,
940             final File directory, final String baseName,
941             final String extension) throws IOException {
942 
943         // use default extension if none is specified
944         final String usedExtension = ((extension != null)
945                 ? extension : FieldContainer.recommendedExtension());
946         final File outFile = IOUtils.createOutFile(directory, baseName,
947             usedExtension);
948         final Writer writer = IOUtils.openWriter(outFile, getConfig());
949         container.store(writer);
950         writer.flush();
951         writer.close();
952         return outFile;
953     }
954 
955     /***
956      * Processes an array of files. For each file, a corresponding answer key
957      * (*.ans) must exist.
958      *
959      * @param files the array of file names to process (relative to the
960      * <code>inDirectory</code>)
961      * @param inDirectory directory containing the files to process
962      * @param outDirectory directory used to do this run and store the results
963      * @param baseName the base name of the files to use for storing
964      * all extractions and training statistics
965      * @param writer used to serialize the calculated metrics
966      * @return a wrapper of the results of this run
967      * @throws IOException if an I/O error occurs
968      * @throws ProcessingException if an error occurs during processing
969      */
970     public Results trainAndEval(final String[] files, final File inDirectory,
971             final File outDirectory, final String baseName,
972             final Writer writer) throws IOException, ProcessingException {
973         // make given directory default dir
974         IOUtils.setDefaultDirectory(outDirectory);
975 
976         // select files for evaluation + rest for training
977         final List<String> trainFiles = new LinkedList<String>();
978         final List<String> evalFiles = new LinkedList<String>();
979         selectFiles(files, trainFiles, evalFiles);
980 
981         final long trainStartTime = System.currentTimeMillis();
982         File currentFile;
983         Document currentDoc;
984         ExtractionContainer currentAnswers;
985 
986         // store token accuracies during TOE training
987         FieldContainer[] accContainers = null;
988         AccuracyView[] currentAcc;
989         AccuracyView[] globalAcc = null;
990         FieldMap accMap;
991         final boolean storeIterations = (tuneIterations > 1);
992 
993         // measure accuracies for TUNE training
994         double[] currentOverallAcc;
995         double[] lastOverallAcc = null;
996         boolean allAreOptimal;
997         boolean noneGotBetter;
998         boolean someGotWorse;
999         int noneGotBetterCounter = 0;
1000         boolean continueTraining = true;
1001         int evalFileNum, feedbackNum;
1002 
1003         // create new trainer
1004         final Trainer trainer = initTrainer(outDirectory);
1005         Extractor extractor = null;
1006 
1007         // measure sentence filtering, if used
1008         FMetrics sentenceMetrics = null;
1009         FMetricsView newSentenceMetrics;
1010 
1011         if (trainer.isSentenceFiltering()) {
1012             sentenceMetricsStore = FieldContainer.createFieldContainer();
1013         } else {
1014             sentenceMetricsStore = null;
1015         }
1016 
1017         // store evaluated extraction containers
1018         final Results results = new Results();
1019 
1020 /*        // pseudo-randonness generator if TUNEing
1021         final Random pseudoRandom = (tuneIterations > 1) ?
1022                 Util.reproducibleRandom() : null; */
1023 
1024         for (int i = 1; (i <= tuneIterations) && continueTraining; i++) {
1025             if (tuneIterations > 1) {
1026                 Util.LOG.debug("Starting TUNE iteration " + i + "/" 
1027                         + tuneIterations + " (will stop if no improvement for "
1028                         + tuneStop + " iterations)");
1029             }
1030 
1031 /* degrades results:
1032             // reshuffling train files in 2nd + following TUNE iterations
1033             if (i > 1) {
1034                 Collections.shuffle(trainFiles, pseudoRandom);
1035                 Util.LOG.debug("Pseudo-randomly reshuffled training files");
1036             } */
1037 
1038             final Iterator trainIter = trainFiles.iterator();
1039             trainer.resetGlobalAccuracy();
1040 
1041             // measure sentence filtering, if used
1042             if (trainer.isSentenceFiltering()) {
1043                 sentenceMetrics = new FMetrics();
1044             }
1045 
1046             //  run on trainFiles
1047             while (trainIter.hasNext()) {
1048                 currentFile = new File(inDirectory, (String) trainIter.next());
1049                 Util.LOG.debug("Starting to train " + currentFile);
1050 
1051                 try {
1052                     currentDoc =
1053                         DOMUtils.readDocument(currentFile, getConfig());
1054                 } catch (DocumentException de) {
1055                     // wrap exception
1056                     throw new ParsingException("Error while parsing "
1057                             + currentFile + ": " + de.toString(), de);
1058                 }
1059 
1060                 // read answer keys in DSV format (*.ans file must exist)
1061                 currentAnswers = AnswerBuilder.readCorrespondingAnswerKeys(
1062                     trainer.getTargetStructure(), currentFile, getConfig());
1063 
1064                 // train the trainer
1065                 currentAcc = trainer.train(currentDoc, currentAnswers);
1066 
1067                 // store global + local accuracy if in TOE mode
1068                 if (currentAcc != null) {
1069                     globalAcc = trainer.viewGlobalAccuracy();
1070 
1071                     // initialize containers once
1072                     if (accContainers == null) {
1073                         accContainers = new FieldContainer[globalAcc.length];
1074                         for (int j = 0; j < globalAcc.length; j++) {
1075                             accContainers[j]
1076                                 = FieldContainer.createFieldContainer();
1077                         }
1078                     }
1079 
1080                     // store each accuracy
1081                     for (int j = 0; j < globalAcc.length; j++) {
1082                         accMap = globalAcc[j].storeFields();
1083                         accMap.putAll(currentAcc[j].storeFields());
1084                         accMap.put(Prediction.KEY_SOURCE,
1085                             IOUtils.getBaseName(currentFile));
1086                         if (storeIterations) {
1087                             accMap.put(KEY_ITERATION, new Integer(i));
1088                         }
1089                         accContainers[j].add(accMap);
1090                     }
1091                 }
1092                 Util.LOG.info("Trained " + currentFile);
1093 
1094                 // measure sentence filtering, if used
1095                 if (trainer.isSentenceFiltering()) {
1096                     newSentenceMetrics = trainer.evaluateSentenceFiltering();
1097                     Util.LOG.debug("Evaluated sentence filtering for trained "
1098                             + "document: " + newSentenceMetrics);
1099                     sentenceMetrics.update(newSentenceMetrics);
1100                 }
1101             }
1102 
1103             // check whether to continue sentence training + update metrics
1104             checkSentenceTraining(trainer, sentenceMetrics, i);
1105 
1106             // stop TUNE if all accuracies did not or cannot longer increase
1107             if (globalAcc != null) {
1108                 currentOverallAcc = new double[globalAcc.length];
1109                 allAreOptimal = true;
1110                 noneGotBetter = true;
1111                 someGotWorse = false;
1112 
1113                 for (int j = 0; j < currentOverallAcc.length; j++) {
1114                     currentOverallAcc[j] = globalAcc[j].getAccuracy();
1115 
1116                     // test whether all accuracies are optimal
1117                     allAreOptimal = allAreOptimal
1118                         && (currentOverallAcc[j] >= 1.0);
1119 
1120                     // test whether all accuracies are better than the last ones
1121                     if (lastOverallAcc != null) {
1122                         noneGotBetter = noneGotBetter
1123                             && (currentOverallAcc[j] <= lastOverallAcc[j]);
1124                         someGotWorse = someGotWorse
1125                             || (currentOverallAcc[j] < lastOverallAcc[j]);
1126                     } else {
1127                         noneGotBetter = false;
1128                     }
1129                 }
1130 
1131                 if (noneGotBetter) {
1132                     noneGotBetterCounter++;
1133 
1134                     if (noneGotBetterCounter >= tuneStop) {
1135                         // reached stopping criterium for TUNE training
1136                         continueTraining = false;
1137                         Util.LOG.debug("Stopping TUNE training after " + i
1138                                 + " iterations because current accuracies ("
1139                                 + ArrayUtils.toString(currentOverallAcc)
1140                                 + ") aren't higher than last ones ("
1141                                 + ArrayUtils.toString(lastOverallAcc)
1142                                 + ") for the " + tuneStop + ". time");                        
1143                     } else if (someGotWorse) {
1144                         // stop TUNE training because accuracy degraded
1145                         continueTraining = false;
1146                         Util.LOG.debug("Stopping TUNE training after " + i
1147                                 + " iterations because current accuracies ("
1148                                 + ArrayUtils.toString(currentOverallAcc)
1149                                 + ") are lower than last ones ("
1150                                 + ArrayUtils.toString(lastOverallAcc) + ")");                        
1151                     }
1152                 }
1153 
1154                 if (allAreOptimal) {
1155                     continueTraining = false;
1156                     Util.LOG.debug("Stopping TUNE training after " + i
1157                             + " iterations because all accuracies are already "
1158                             + "optimal: "
1159                             + ArrayUtils.toString(currentOverallAcc));
1160                 }
1161 
1162                 lastOverallAcc = currentOverallAcc;
1163             }
1164 
1165             if ((tuneEach && (i >= tuneSince))
1166                     || !continueTraining
1167                     || (i == tuneIterations)
1168                     || tuneEvaluations.contains(Integer.valueOf(i))) {
1169                 // evaluate extractions
1170                 Util.LOG.info("Finished training using "
1171                     + ArrayUtils.toString(trainer.getClassifiers()) + "; "
1172                     + Util.showDuration(trainStartTime));
1173                 final long evalStartTime = System.currentTimeMillis();
1174                 evalFileNum = 0;
1175 
1176                 // create extractor (first time only)
1177                 if (extractor == null) {
1178                     extractor = initExtractor(trainer);
1179                 }
1180 
1181                 final Iterator evalIter = evalFiles.iterator();
1182                 final EvaluatedExtractionContainer evaluated =
1183                     new EvaluatedExtractionContainer(
1184                         extractor.getTargetStructure(), getConfig());
1185                 ExtractionContainer currentResults;
1186                 MultiFMetricsView interimMetrics;
1187                 String currentType;
1188                 Iterator typeIter;
1189 
1190                 // measure sentence filtering, if used
1191                 if (extractor.isSentenceFiltering()) {
1192                     sentenceMetrics = new FMetrics();
1193                  }
1194 
1195                 // run on evalFiles
1196                 while (evalIter.hasNext()) {
1197                     currentFile = new File(inDirectory,
1198                         (String) evalIter.next());
1199                     evalFileNum++;
1200                     Util.LOG.debug("Starting to extract and evaluate file #"
1201                             + evalFileNum + ": " + currentFile);
1202 
1203                     try {
1204                         currentDoc = DOMUtils.readDocument(currentFile,
1205                             getConfig());
1206                     } catch (DocumentException de) {
1207                         // wrap exception
1208                         throw new ParsingException(de);
1209                     }
1210 
1211                     // extract results
1212                     currentResults = extractor.extract(currentDoc);
1213 
1214                     // read answer key
1215                     currentAnswers = AnswerBuilder.readCorrespondingAnswerKeys(
1216                         trainer.getTargetStructure(), currentFile, getConfig());
1217 
1218                     // invoke the trainer if feedback should be given
1219                     if (feedback) {
1220                         currentAcc = trainer.train(currentDoc, currentAnswers);                        
1221                     }
1222                     
1223                     // measure sentence filtering, if used (we do this prior to
1224                     // evaluating the predicted extractions because
1225                     // evaluateBatch destructively modifies the answer keys)
1226                     if (extractor.isSentenceFiltering()) {
1227                         newSentenceMetrics =
1228                             extractor.evaluateSentenceFiltering(currentAnswers);
1229                         Util.LOG.debug("Evaluated sentence filtering for "
1230                                 + "current document: " + newSentenceMetrics);
1231                         sentenceMetrics.update(newSentenceMetrics);
1232                     }
1233 
1234                     // compare with results, storing base name as key
1235                     evaluated.evaluateBatch(currentResults, currentAnswers,
1236                         IOUtils.getBaseName(currentFile));
1237 
1238                     // log results
1239                     interimMetrics = evaluated.viewMetrics();
1240                     Util.LOG.info("Extracted and evaluated " + currentFile
1241                         + ", interim results: " + interimMetrics.viewAll());
1242                     typeIter = interimMetrics.types().iterator();
1243 
1244                     while (typeIter.hasNext()) {
1245                         currentType = (String) typeIter.next();
1246                         Util.LOG.debug("Interim results for " + currentType
1247                             + ": " + interimMetrics.view(currentType));
1248                     }
1249 
1250                     // store results after each 10th file (and after last file)
1251                     // if feedback gis iven
1252                     if (feedback && ((evalFileNum % MEASURE_FEEDBACK == 0)
1253                             || !evalIter.hasNext())) {
1254                         feedbackNum = new Double(Math.ceil((double) evalFileNum
1255                                 / MEASURE_FEEDBACK)).intValue();
1256 
1257                         // initialize new metrics if necessary (first run)
1258                         if (feedbackAverages.size() < feedbackNum) {
1259                             feedbackAverages.add(new MultiFMetrics(true));
1260                         }
1261 
1262                         // add current results
1263                         feedbackAverages.get(feedbackNum - 1).update(
1264                                 evaluated.viewMetrics());
1265                     }
1266                 }
1267 
1268                 serializeExtractions(outDirectory, baseName, i, evaluated,
1269                         sentenceMetrics);
1270                 results.addEvaluated(i, evaluated);
1271                 Util.LOG.info("Finished extraction and evaluation using "
1272                     + ArrayUtils.toString(extractor.getClassifiers()) + "; "
1273                     + Util.showDuration(evalStartTime));
1274             }
1275 
1276             // update max. number of iterations used in any batch if higher
1277             lastUsedTuneIteration = Math.max(lastUsedTuneIteration, i);
1278         }
1279 
1280         // serialize metrics + return extractions and feature counts
1281         serializeTrainingMetrics(outDirectory, baseName, accContainers,
1282                 globalAcc);
1283         serializeMetrics(outDirectory, baseName, writer, results);
1284         results.setTrainFeatureCV(trainer.viewFeatureCount());
1285         results.setExtractFeatureCV(extractor.viewFeatureCount());
1286         return results;
1287     }
1288 
1289     /***
1290      * Adds the results of a training or evaluation iteration to the statistics
1291      * calculated for sentence filtering.
1292      *
1293      * @param newMetrics the metrics to add
1294      * @param isTraining whether this is a training or an evaluation result
1295      * @param iteration the number of the iteration (counting from 1)
1296      */
1297     private void updateSentenceMetricsStore(final FMetricsView newMetrics,
1298             final boolean isTraining, final int iteration) {
1299         final FieldMap fields = newMetrics.storeFields();
1300 
1301         // type is either "Train" or "Eval"
1302         fields.put(KEY_TYPE, isTraining ? TYPE_TRAIN : TYPE_EVAL);
1303         fields.put(KEY_ITERATION, new Integer(iteration));
1304         sentenceMetricsStore.add(fields);
1305     }
1306 
1307 }