1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package de.fu_berlin.ties.extract;
23
24 import java.io.File;
25 import java.io.IOException;
26 import java.io.Reader;
27 import java.io.Writer;
28 import java.util.ArrayList;
29 import java.util.HashSet;
30 import java.util.Iterator;
31 import java.util.LinkedList;
32 import java.util.List;
33 import java.util.NoSuchElementException;
34 import java.util.Set;
35 import java.util.SortedMap;
36 import java.util.TreeMap;
37
38 import org.apache.commons.lang.ArrayUtils;
39 import org.apache.commons.lang.builder.ToStringBuilder;
40 import org.apache.commons.math.stat.univariate.moment.StandardDeviation;
41 import org.dom4j.Document;
42 import org.dom4j.DocumentException;
43
44 import de.fu_berlin.ties.Closeable;
45 import de.fu_berlin.ties.ContextMap;
46 import de.fu_berlin.ties.ParsingException;
47 import de.fu_berlin.ties.ProcessingException;
48 import de.fu_berlin.ties.TextProcessor;
49 import de.fu_berlin.ties.TiesConfiguration;
50 import de.fu_berlin.ties.classify.Prediction;
51 import de.fu_berlin.ties.eval.AccuracyView;
52 import de.fu_berlin.ties.eval.FMetrics;
53 import de.fu_berlin.ties.eval.FMetricsView;
54 import de.fu_berlin.ties.eval.FeatureCountView;
55 import de.fu_berlin.ties.eval.MultiFMetrics;
56 import de.fu_berlin.ties.eval.MultiFMetricsView;
57 import de.fu_berlin.ties.io.FieldContainer;
58 import de.fu_berlin.ties.io.FieldMap;
59 import de.fu_berlin.ties.io.IOUtils;
60 import de.fu_berlin.ties.util.Util;
61 import de.fu_berlin.ties.xml.dom.DOMUtils;
62
63 /***
64 * Trains an extractor and evaluates extraction quality. Processes shuffle
65 * files (as generated by {@link de.fu_berlin.ties.eval.ShuffleGenerator})
66 * that contain the files to use for training and evaluation.
67 * For each of these files, a corresponding answer key (*.ans) must exist.
68 *
69 * <p>Instances of this class are not thread-safe.
70 *
71 * @author Christian Siefkes
72 * @version $Revision: 1.67 $, $Date: 2004/11/17 09:17:04 $, $Author: siefkes $
73 */
74 public class TrainEval extends TextProcessor implements Closeable {
75
76 /***
77 * An inner class wrapping the results of a
78 * {@linkplain TrainEval#trainAndEval(String[], File, File, String, Writer)
79 * training + evaluation run}.
80 */
81 public static class Results {
82
83 /***
84 * The map from TUNE iterations to
85 * {@link EvaluatedExtractionContainer}s containing the evaluated
86 * extractions (after the corresponding TUNE iteration).
87 */
88 private final SortedMap<Integer, EvaluatedExtractionContainer> evaluated
89 = new TreeMap<Integer, EvaluatedExtractionContainer>();
90
91 /***
92 * A read-only view on the feature count statistics collected during
93 * training.
94 */
95 private FeatureCountView trainFeatureCV;
96
97 /***
98 * A read-only view on the feature count statistics collected during
99 * extraction.
100 */
101 private FeatureCountView extractFeatureCV;
102
103 /***
104 * Creates a new instance.
105 */
106 public Results() {
107 }
108
109 /***
110 * Adds an evaluated extraction container.
111 *
112 * @param iteration the number of the TUNE iteration to add
113 * @param evaluatedExtractions the container to add
114 * @throws IllegalArgumentException if there already is a container
115 * stored for <code>iteration</code>
116 */
117 private void addEvaluated(final int iteration,
118 final EvaluatedExtractionContainer evaluatedExtractions)
119 throws IllegalArgumentException {
120 final EvaluatedExtractionContainer oldValue =
121 evaluated.put(Integer.valueOf(iteration), evaluatedExtractions);
122
123 if (oldValue != null) {
124 throw new IllegalArgumentException("Cannot store two "
125 + "extraction containers for same iteration ("
126 + iteration + "): " + oldValue + "; "
127 + evaluatedExtractions);
128 }
129 }
130
131 /***
132 * Returns one of the stored evaluated extraction containers.
133 *
134 * @param iteration the number of the TUNE iteration to return;
135 * must be contained in {@link #iterations()}, otherwise
136 * <code>null</code> will be returned
137 * @return the specified container
138 */
139 public EvaluatedExtractionContainer getEvaluated(
140 final Integer iteration) {
141 return evaluated.get(iteration);
142 }
143
144 /***
145 * Returns one of the stored evaluated extraction containers.
146 *
147 * @param iteration the number index of the TUNE iterations to return;
148 * must be contained in {@link #iterations()}, otherwise
149 * <code>null</code> will be returned
150 * @return the specified container
151 */
152 public EvaluatedExtractionContainer getEvaluated(final int iteration) {
153 return getEvaluated(Integer.valueOf(iteration));
154 }
155
156 /***
157 * Returns a read-only view on the feature count statistics collected
158 * during training.
159 * @return the stored attribute
160 */
161 public FeatureCountView getExtractFeatureCV() {
162 return extractFeatureCV;
163 }
164
165 /***
166 * Returns a read-only view on the feature count statistics collected
167 * during extraction.
168 * @return the stored attribute
169 */
170 public FeatureCountView getTrainFeatureCV() {
171 return trainFeatureCV;
172 }
173
174 /***
175 * Returns an iterator over the TUNE iterations for which extraction
176 * results are stored.
177 *
178 * @return an iterator over the TUNE iterations
179 */
180 public Iterator<Integer> iterations() {
181 return evaluated.keySet().iterator();
182 }
183
184 /***
185 * Stores a read-only view on the feature count statistics collected
186 * during training.
187 *
188 * @param extractFeatures a read-only view on the feature count
189 * statistics collected during extraction
190 */
191 private void setExtractFeatureCV(
192 final FeatureCountView extractFeatures) {
193 extractFeatureCV = extractFeatures;
194 }
195
196 /***
197 * Stores a read-only view on the feature count statistics collected
198 * during extraction.
199 *
200 * @param trainFeatures a read-only view on the feature count statistics
201 * collected during training
202 */
203 private void setTrainFeatureCV(
204 final FeatureCountView trainFeatures) {
205 trainFeatureCV = trainFeatures;
206 }
207
208 /***
209 * Returns a string representation of this object.
210 *
211 * @return a textual representation
212 */
213 public String toString() {
214 return new ToStringBuilder(this)
215 .append("training feature count", trainFeatureCV)
216 .append("extraction feature count", extractFeatureCV)
217 .append("size", evaluated.size())
218 .toString();
219 }
220
221 }
222
223 /***
224 * Configuration key: The percentage of a corpus to use for training.
225 */
226 public static final String CONFIG_TRAIN_SPLIT = "eval.train-split";
227
228 /***
229 * Configuration key: The percentage of a corpus to use for testing
230 * (evaluation).
231 */
232 public static final String CONFIG_TEST_SPLIT = "eval.test-split";
233
234 /***
235 * Configuration key: If <code>true</code>, a fully incremental setup is
236 * used where the trainer is trained on each document after the extractor
237 * processed it.
238 */
239 public static final String CONFIG_FEEDBACK = "eval.feedback";
240
241 /***
242 * Configuration key: The maximum number of iterations used for TUNE
243 * (train until no error) training; if 1, training is incremental.
244 */
245 public static final String CONFIG_TUNE = "train.tune";
246
247 /***
248 * Configuration key: TUNE training is stopped if the training accuracy
249 * didn't improve for the specified number of iterations.
250 */
251 public static final String CONFIG_TUNE_STOP = "train.tune.stop";
252
253 /***
254 * Configuration key: The maximum number of iterations used for TUNE
255 * training the sentence classifier; if 0 or negative, the value of
256 * {@link #CONFIG_TUNE} is used.
257 */
258 public static final String CONFIG_SENTENCE_TUNE = "sent.tune";
259
260 /***
261 * Configuration key: Whether to measure results after each TUNE iteration
262 * or only at the end of training.
263 */
264 public static final String CONFIG_TUNE_EACH = "eval.tune.each";
265
266 /***
267 * Configuration key: The training iteration after which to evaluate
268 * results for the first time if {@link #CONFIG_TUNE_EACH} is enabled.
269 */
270 public static final String CONFIG_TUNE_SINCE = "eval.tune.since";
271
272 /***
273 * Serialization key for the number of the iteration (when TUNE training).
274 */
275 public static final String KEY_ITERATION = "Iteration";
276
277 /***
278 * Serialization key for the number of the run.
279 */
280 public static final String KEY_RUN = "Run";
281
282 /***
283 * Serialization key for the type (either "Train" or "Eval").
284 */
285 public static final String KEY_TYPE = "Type";
286
287 /***
288 * Serialization value for the "Train" type.
289 */
290 public static final String TYPE_TRAIN = "Train";
291
292 /***
293 * Serialization value for the "Eval" type.
294 */
295 public static final String TYPE_EVAL = "Eval";
296
297 /***
298 * Results are measured after each 10th evaluated document if
299 * {@link #feedback} is given.
300 */
301 private static final int MEASURE_FEEDBACK = 10;
302
303 /***
304 * The percentage of a corpus to use for training.
305 */
306 private final float trainSplit;
307
308 /***
309 * The percentage of a corpus to use for testing; if <code>-1</code>,
310 * all remaining documents are used.
311 */
312 private final float testSplit;
313
314 /***
315 * If <code>true</code>, a fully incremental setup is used where the
316 * trainer is trained on each document after the extractor processed it.
317 */
318 private final boolean feedback;
319
320 /***
321 * The maximum number of iterations used for TUNE (train until no error)
322 * training; if 1, training is incremental.
323 */
324 private final int tuneIterations;
325
326 /***
327 * TUNE training is stopped if the training accuracy didn't improve for the
328 * specified number of iterations.
329 */
330 private final int tuneStop;
331
332 /***
333 * The maximum number of iterations used for TUNE training the sentence
334 * classifier; if 0 or negative, the value of {@link #tuneIterations} is
335 * used.
336 */
337 private final int sentenceIterations;
338
339 /***
340 * Whether to measure results after each TUNE iteration or only at the
341 * end of training.
342 */
343 private final boolean tuneEach;
344
345 /***
346 * The training iteration after which to evaluate results for the first
347 * time if {@link #tuneEach} is enabled.
348 */
349 private final int tuneSince;
350
351 /***
352 * A set of iterations after which to evaluate TUNE training in addition to
353 * the last one; ignored if {@link #tuneEach} is <code>true</code>.
354 */
355 private final Set<Integer> tuneEvaluations = new HashSet<Integer>();
356
357 /***
358 * The last TUNE iteration that was actually used for training in any
359 * batch of processed files.
360 */
361 private int lastUsedTuneIteration = -1;
362
363 /***
364 * An array of multi-metrics calculating summaries esp. standard deviations
365 * over all runs. Each array element corresponds to one TUNE operation.
366 */
367 private final MultiFMetrics[] averages;
368
369 /***
370 * A list of multi-metrics incrementally calculating summaries after
371 * evaluating each 10th document if {@link #feedback} is given.
372 */
373 private final List<MultiFMetrics> feedbackAverages;
374
375 /***
376 * Calculate feature count statistics over all runs.
377 */
378 private final FieldContainer featureCounts =
379 FieldContainer.createFieldContainer();
380
381 /***
382 * Stores precision and recall values for sentence filtering, if used.
383 */
384 private FieldContainer sentenceMetricsStore;
385
386 /***
387 * The number of the current run.
388 */
389 private int runNo = 0;
390
/**
 * Creates a new instance that uses <code>"metrics"</code> as extension of
 * its output files and reads all settings from the
 * {@linkplain TiesConfiguration#CONF standard configuration}.
 *
 * @throws IllegalArgumentException if the configured values are outside the
 * allowed ranges
 * @throws ClassCastException if the configured numeric values cannot be
 * parsed
 * @throws NoSuchElementException if one of the required values is
 * missing from the configuration
 */
public TrainEval()
        throws IllegalArgumentException, ClassCastException,
        NoSuchElementException {
    // delegate to the extension-taking constructor with the default extension
    this("metrics");
}
406
/**
 * Creates a new instance that reads all settings from the
 * {@linkplain TiesConfiguration#CONF standard configuration}.
 *
 * @param outExt the extension to use for output files
 * @throws IllegalArgumentException if the configured values are outside the
 * allowed ranges
 * @throws ClassCastException if the configured numeric values cannot be
 * parsed
 * @throws NoSuchElementException if one of the required values is
 * missing from the configuration
 */
public TrainEval(final String outExt)
        throws IllegalArgumentException, ClassCastException,
        NoSuchElementException {
    // delegate, supplying the shared standard configuration
    this(outExt, TiesConfiguration.CONF);
}
423
/**
 * Creates a new instance, reading the split percentages, the TUNE settings
 * and the feedback flag from the given configuration and handing them to
 * the fully parameterized constructor.
 *
 * @param outExt the extension to use for output files
 * @param config used to configure this instance; all settings are read
 * from it
 * @throws IllegalArgumentException if the configured values are outside the
 * allowed ranges
 * @throws ClassCastException if the configured numeric values cannot be
 * parsed
 * @throws NoSuchElementException if one of the required values is
 * missing from the configuration
 */
public TrainEval(final String outExt, final TiesConfiguration config)
        throws IllegalArgumentException, ClassCastException,
        NoSuchElementException {
    // the argument order follows the parameter list of the target constructor
    this(
        outExt,
        config.getFloat(CONFIG_TRAIN_SPLIT),
        config.getFloat(CONFIG_TEST_SPLIT),
        config.getInt(CONFIG_TUNE),
        config.getInt(CONFIG_TUNE_STOP),
        config.getBoolean(CONFIG_TUNE_EACH),
        config.getInt(CONFIG_TUNE_SINCE),
        config.getList("eval.tune.list"),
        config.getInt(CONFIG_SENTENCE_TUNE),
        config.getBoolean(CONFIG_FEEDBACK),
        config);
}
450
/**
 * Creates a new instance from explicitly given settings.
 *
 * @param outExt the extension to use for output files
 * @param trainingSplit the percentage of a corpus to use for training
 * @param testingSplit the percentage of a corpus to use for testing
 * (evaluation); if <code>-1</code>, all remaining documents (1 -
 * <code>trainingSplit</code>) are used
 * @param tuneRuns the maximum number of iterations used for TUNE
 * (train until no error) training; if 1, training is incremental
 * @param tuneStopAfter TUNE training is stopped if the training accuracy
 * didn't improve for the specified number of iterations
 * @param measureEachTUNE whether to measure results after each TUNE
 * iteration or only at the end of training
 * @param startMeasureTUNE the training iteration after which to evaluate
 * results for the first time if <code>measureEachTUNE</code> is enabled
 * (ignored otherwise); values smaller than 1 are treated as 1
 * @param tuneEvalList a list of Integers or int Strings specifying
 * iterations after which to evaluate TUNE training in addition to the last
 * one; ignored if <code>measureEachTUNE</code> is <code>true</code>;
 * <code>null</code> is treated like an empty list
 * @param sentenceTUNE the maximum number of iterations used for TUNE
 * training the sentence classifier (if used); if 0 or negative, the value
 * of <code>tuneRuns</code> is used
 * @param giveFeedback if <code>true</code>, a fully incremental setup is
 * used where the trainer is trained on each document after the extractor
 * processed it; it's not allowed to set both this and
 * <code>measureEachTUNE</code> to <code>true</code> when training for
 * several <code>tuneRuns</code> because that would mean to evaluate on the
 * training set
 * @param config used to configure superclasses, trainer, and extractor;
 * if <code>null</code>, the {@linkplain TiesConfiguration#CONF standard
 * configuration} is used
 * @throws IllegalArgumentException if <code>trainingSplit</code> is not
 * a percentage (larger than 1 or smaller than 0), if
 * <code>testingSplit</code> is larger than 1, if <code>tuneRuns</code> or
 * <code>tuneStopAfter</code> is smaller than 1, or if feedback is combined
 * with measuring after each of several TUNE iterations
 */
public TrainEval(final String outExt, final float trainingSplit,
        final float testingSplit, final int tuneRuns,
        final int tuneStopAfter, final boolean measureEachTUNE,
        final int startMeasureTUNE, final List tuneEvalList,
        final int sentenceTUNE, final boolean giveFeedback,
        final TiesConfiguration config)
        throws IllegalArgumentException {
    super(outExt, config);

    // validate arguments
    if ((trainingSplit < 0.0) || (trainingSplit > 1.0)) {
        throw new IllegalArgumentException(
            "Train split is not a percentage: " + trainingSplit);
    }
    // negative test splits are legal: they select all remaining documents
    if (testingSplit > 1.0) {
        throw new IllegalArgumentException(
            "Test split is not a percentage: " + testingSplit);
    }
    if (tuneRuns < 1) {
        throw new IllegalArgumentException(
            "Number of TUNE runs must be at least 1: " + tuneRuns);
    }
    if (tuneStopAfter < 1) {
        throw new IllegalArgumentException(
            "TUNE stopping criterium must be at least 1: " + tuneStopAfter);
    }

    if (giveFeedback && (tuneRuns > 1) && measureEachTUNE) {
        throw new IllegalArgumentException("It's not allowed give feedback "
            + "when evaluating after each of several iteration because "
            + "that would mean to evaluate on the training set");
    }

    // store settings
    trainSplit = trainingSplit;
    testSplit = testingSplit;
    tuneIterations = tuneRuns;
    tuneStop = tuneStopAfter;
    tuneEach = measureEachTUNE;
    tuneSince = Math.max(startMeasureTUNE, 1);
    sentenceIterations = sentenceTUNE;
    feedback = giveFeedback;

    // one slot per TUNE iteration; slots are filled lazily while processing
    averages = new MultiFMetrics[tuneIterations];

    // the explicit evaluation list only matters if we don't evaluate after
    // each iteration anyway; a missing list means "no extra evaluations"
    if (!tuneEach && (tuneEvalList != null)) {
        for (final Object entry : tuneEvalList) {
            tuneEvaluations.add(Integer.valueOf(Util.asInt(entry)));
        }
    }

    // feedback averages are only collected in feedback mode
    feedbackAverages = feedback ? new ArrayList<MultiFMetrics>() : null;

    // NOTE(review): the objects below are created, exercised once and then
    // discarded; presumably this forces the involved classes to be loaded
    // and initialized up front -- confirm before removing
    final FieldContainer dummy = new FieldContainer();
    final MultiFMetrics dummyMetrics = new MultiFMetrics(true);
    dummyMetrics.storeEntries(dummy);
    StandardDeviation dummyDev = new StandardDeviation();
    dummyDev.clear();
}
554
555 /***
556 * Check whether to disable sentnce training; called after each training
557 * iteration.
558 *
559 * @param trainer the used trainer
560 * @param sentenceMetrics metrics calculated for sentence filtering, if
561 * used (<code>null</code> otherwise)
562 * @param iteration the current iteration
563 */
564 private void checkSentenceTraining(final Trainer trainer,
565 final FMetrics sentenceMetrics, final int iteration) {
566 if ((sentenceIterations > 0) && (sentenceIterations == iteration)) {
567 trainer.disableSentenceTraining();
568 Util.LOG.info("Disabled sentence training after " + iteration
569 + " iterations");
570 }
571
572 if (sentenceMetrics != null) {
573 updateSentenceMetricsStore(sentenceMetrics, true, iteration);
574 }
575 }
576
577 /***
578 * {@inheritDoc}
579 */
580 public void close(final int errorCount)
581 throws IOException, ProcessingException {
582
583 if (errorCount <= 0) {
584 final File outDir = IOUtils.determineOutputDirectory(getConfig());
585
586
587 storeValues(featureCounts, outDir, "FeatureCounts", null);
588
589 final FieldContainer averagesStore =
590 FieldContainer.createFieldContainer();
591 boolean storeIterations = (averages.length > 1);
592
593
594 final int lastToStore = Math.min(averages.length,
595 lastUsedTuneIteration);
596
597 for (int i = 1; i <= lastToStore; i++) {
598
599
600
601 if (averages[i-1] != null && (tuneEach
602 || tuneEvaluations.contains(Integer.valueOf(i))
603 || i == lastToStore)) {
604
605 if (storeIterations) {
606 averagesStore.backgroundMap().put(KEY_ITERATION,
607 Integer.valueOf(i));
608 }
609
610 averages[i-1].storeEntries(averagesStore);
611 }
612 }
613 storeValues(averagesStore, outDir, "All",
614 MultiFMetrics.EXT_METRICS);
615
616
617 if (feedbackAverages != null) {
618 final FieldContainer feedbackAveragesStore =
619 FieldContainer.createFieldContainer();
620
621 for (int i = 0; i < feedbackAverages.size(); i++) {
622 feedbackAverages.get(i).storeEntries(feedbackAveragesStore);
623 }
624 storeValues(feedbackAveragesStore, outDir, "Feedback",
625 MultiFMetrics.EXT_METRICS);
626 }
627 }
628 }
629
630 /***
631 * {@inheritDoc}
632 */
633 protected void doProcess(final Reader reader, final Writer writer,
634 final ContextMap context) throws IOException, ProcessingException {
635
636 runNo++;
637 final String[] files = IOUtils.readURIList(reader);
638
639 if (files.length == 0) {
640 Util.LOG.info("No files to process");
641 } else {
642 final String baseName = IOUtils.getBaseName(
643 new File((String) context.get(KEY_LOCAL_NAME)));
644 final File inDir = (File) context.get(KEY_DIRECTORY);
645 final File outDir = IOUtils.determineOutputDirectory(getConfig());
646 Results currentResults;
647 FieldMap currentFMap;
648
649 MultiFMetricsView currentMetrics = null;
650
651
652 currentResults =
653 trainAndEval(files, inDir, outDir, baseName, writer);
654
655
656 final Integer runNumber = new Integer(runNo);
657 currentFMap = currentResults.getTrainFeatureCV().storeFields();
658 currentFMap.put(KEY_RUN, runNumber);
659 currentFMap.put(MultiFMetrics.KEY_TYPE, "training");
660 featureCounts.add(currentFMap);
661
662 currentFMap =
663 currentResults.getExtractFeatureCV().storeFields();
664 currentFMap.put(KEY_RUN, runNumber);
665 currentFMap.put(MultiFMetrics.KEY_TYPE, "extraction");
666 featureCounts.add(currentFMap);
667
668
669 final Iterator<Integer> iterIter = currentResults.iterations();
670 Integer iteration;
671 int it = 0;
672
673 while (iterIter.hasNext()) {
674 iteration = iterIter.next();
675 it = iteration.intValue();
676 currentMetrics =
677 currentResults.getEvaluated(iteration).viewMetrics();
678 if (averages[it-1] == null) {
679 averages[it-1] = new MultiFMetrics(true);
680 }
681 averages[it-1].update(currentMetrics);
682 }
683
684
685
686 for (it++; it <= averages.length; it++) {
687 if (averages[it-1] == null) {
688 averages[it-1] = new MultiFMetrics(true);
689 }
690
691
692 averages[it-1].update(currentMetrics);
693 }
694 }
695 }
696
697 /***
698 * Returns the percentage of a corpus to use for testing (evaluation).
699 *
700 * @return the percentage to use for evaluation; if negative, all
701 * remaining documents (1 - {@link #getTrainSplit()}) are used for
702 * evaluation
703 */
704 public float getTestSplit() {
705 return testSplit;
706 }
707
708 /***
709 * Returns the percentage of a corpus to use for training; the remaining
710 * documents (1-x) are used for evaluation.
711 *
712 * @return the percentage to use for training
713 */
714 public float getTrainSplit() {
715 return trainSplit;
716 }
717
718 /***
719 * Creates and initializes a extractor to use for an evaluation run,
720 * re-using the components of the provided trainer. Subclasses can
721 * overwrite this method to provide a different extractor.
722 *
723 * @param trainer trainer whose components should be re-used
724 * @return the created extractor
725 */
726 protected Extractor initExtractor(final Trainer trainer) {
727
728 return new Extractor(null, trainer);
729 }
730
731 /***
732 * Creates and initializes a trainer to use for an evaluation run,
733 * configured from the
734 * {@link de.fu_berlin.ties.ConfigurableProcessor#getConfig() stored
735 * configuration}. Subclasses can overwrite this method to provide a
736 * different trainer.
737 *
738 * @param runDirectory directory used to run the classifier
739 * @return the created trainer
740 * @throws ProcessingException if an error occurs during initialization
741 */
742 protected Trainer initTrainer(final File runDirectory)
743 throws ProcessingException {
744
745 final Trainer trainer = new Trainer(null, runDirectory, getConfig());
746
747 trainer.reset();
748 return trainer;
749 }
750
751 /***
752 * Returns a string representation of this object.
753 *
754 * @return a textual representation
755 */
756 public String toString() {
757 return new ToStringBuilder(this)
758 .appendSuper(super.toString())
759 .append("train split", trainSplit)
760 .append("test split", testSplit)
761 .append("tune iterations", tuneIterations)
762 .append("tune stops after", tuneStop)
763 .append("measure after each iteration", tuneEach)
764 .append("starting from", tuneSince)
765 .append("sentence iterations", sentenceIterations)
766 .append("feedback", feedback)
767 .toString();
768 }
769
770 /***
771 * Chooses files to use for training and files to use for evaluation,
772 * depending on the configured settings.
773 *
774 * @param allFiles the array of file names to process
775 * @param trainFiles populated with the files to use for training, will
776 * be populated with the first <em>{@link #getTrainSplit()} *
777 * allFiles.length</em> files; must initially be empty
778 * @param evalFiles populated with the files to use for evaluation, will
779 * be populated from the next <em>{@link #getTestSplit()} *
780 * allFiles.length</em> remaining files (or all remaining files if test
781 * split is negative); must initially be empty
782 * @throws IllegalArgumentException if the lists aren't empty
783 */
784 private void selectFiles(final String[] allFiles,
785 final List<String> trainFiles, final List<String> evalFiles)
786 throws IllegalArgumentException {
787
788 if (!trainFiles.isEmpty() || !evalFiles.isEmpty()) {
789 throw new IllegalArgumentException(
790 "Lists of train files and eval files must initially be empty");
791 }
792
793 final int numTrainFiles = Math.round(trainSplit * allFiles.length);
794 final int filesToUse;
795
796 if (testSplit < 0) {
797
798 filesToUse = allFiles.length;
799 } else {
800 filesToUse = Math.min(allFiles.length,
801 Math.round((trainSplit + testSplit) * allFiles.length));
802 }
803
804
805 for (int i = 0; i < filesToUse; i++) {
806 if (i < numTrainFiles) {
807
808 trainFiles.add(allFiles[i]);
809 } else {
810
811 evalFiles.add(allFiles[i]);
812 }
813 }
814
815 Util.LOG.debug("Using " + filesToUse + " of " + allFiles.length
816 + " files: " + trainFiles.size() + " for training, "
817 + evalFiles.size() + " for evaluation");
818 }
819
820 /***
821 * Helper method for serializing the calculated metrics.
822 *
823 * @param outDirectory directory used to do this run and store the results
824 * @param baseName the base name of the files to use for storing
825 * all extractions and training statistics
826 * @param iteration the number of the current iteration
827 * @param evaluated the container of evaluated extractions to tore
828 * @param sentenceMetrics metrics calculated for sentence filtering, if
829 * used (<code>null</code> otherwise)
830 * @throws IOException if an I/O error occurs while serializing
831 */
832 private void serializeExtractions(final File outDirectory,
833 final String baseName, final int iteration,
834 final EvaluatedExtractionContainer evaluated,
835 final FMetrics sentenceMetrics)
836 throws IOException {
837
838 final FieldContainer resultStorage =
839 FieldContainer.createFieldContainer();
840 evaluated.storeEntries(resultStorage);
841 final File evaluatedFile = storeValues(resultStorage,
842 outDirectory, baseName, Extractor.EXT_EXTRACTIONS);
843 Util.LOG.info("Stored results of training + evaluation run in "
844 + evaluatedFile);
845
846
847 if (sentenceMetrics != null) {
848 updateSentenceMetricsStore(sentenceMetrics, false, iteration);
849 }
850 }
851
852 /***
853 * Helper method for serializing the calculated metrics.
854 *
855 * @param outDirectory directory used to do this run and store the results
856 * @param baseName the base name of the files to use for storing
857 * all extractions and training statistics
858 * @param writer used to serialize the calculated metrics
859 * @param results contains the evaluated extractions
860 * @throws IOException if an I/O error occurs while serializing the metrics
861 */
862 private void serializeMetrics(final File outDirectory,
863 final String baseName, final Writer writer,
864 final Results results)
865 throws IOException {
866
867 final FieldContainer metricsStore =
868 FieldContainer.createFieldContainer();
869 Integer iteration;
870 EvaluatedExtractionContainer evaluated;
871 final boolean storeIterations = (tuneIterations > 1);
872 Iterator<Integer> tuneIter = results.iterations();
873
874 while (tuneIter.hasNext()) {
875 iteration = tuneIter.next();
876 evaluated = results.getEvaluated(iteration);
877
878
879 if (storeIterations) {
880 metricsStore.backgroundMap().put(KEY_ITERATION, iteration);
881 }
882
883 evaluated.viewMetrics().storeEntries(metricsStore);
884 }
885
886 metricsStore.store(writer);
887 writer.flush();
888
889
890 if (sentenceMetricsStore != null) {
891 storeValues(sentenceMetricsStore, outDirectory,
892 baseName, "sent");
893 }
894 }
895
896 /***
897 * Helper method for serializing the training accuracies.
898 *
899 * @param outDirectory directory used to do this run and store the results
900 * @param baseName the base name of the files to use for storing
901 * all extractions and training statistics
902 * @param accContainers a an of containers used to store the accuracy of
903 * each classifier, might be <code>null</code> if accuracies aren't measured
904 * @param globalAcc an array of accuracies, one for each classifier
905 * @throws IOException if an I/O error occurs while serializing
906 */
907 private void serializeTrainingMetrics(final File outDirectory,
908 final String baseName, final FieldContainer[] accContainers,
909 final AccuracyView[] globalAcc)
910 throws IOException {
911 if (accContainers != null) {
912 String name;
913
914 for (int j = 0; j < globalAcc.length; j++) {
915 if (globalAcc.length > 1) {
916
917 name = baseName + (char) ('a' + j);
918 } else {
919
920 name = baseName;
921 }
922 storeValues(accContainers[j], outDirectory, name, "train");
923 }
924 }
925 }
926
927 /***
928 * Stores the contents of a storable container.
929 *
930 * @param container the container to store
931 * @param directory the directory in which to store the data
932 * @param baseName the base name of the file using for storing the data
933 * @param extension the file extension to use; if <code>null</code>,
934 * the {@linkplain FieldContainer#recommendedExtension() recommended
935 * extension} is used
936 * @return the file used for storing the data
937 * @throws IOException if an I/O error occurs
938 */
939 private File storeValues(final FieldContainer container,
940 final File directory, final String baseName,
941 final String extension) throws IOException {
942
943
944 final String usedExtension = ((extension != null)
945 ? extension : FieldContainer.recommendedExtension());
946 final File outFile = IOUtils.createOutFile(directory, baseName,
947 usedExtension);
948 final Writer writer = IOUtils.openWriter(outFile, getConfig());
949 container.store(writer);
950 writer.flush();
951 writer.close();
952 return outFile;
953 }
954
955 /***
956 * Processes an array of files. For each file, a corresponding answer key
957 * (*.ans) must exist.
958 *
959 * @param files the array of file names to process (relative to the
960 * <code>inDirectory</code>)
961 * @param inDirectory directory containing the files to process
962 * @param outDirectory directory used to do this run and store the results
963 * @param baseName the base name of the files to use for storing
964 * all extractions and training statistics
965 * @param writer used to serialize the calculated metrics
966 * @return a wrapper of the results of this run
967 * @throws IOException if an I/O error occurs
968 * @throws ProcessingException if an error occurs during processing
969 */
970 public Results trainAndEval(final String[] files, final File inDirectory,
971 final File outDirectory, final String baseName,
972 final Writer writer) throws IOException, ProcessingException {
973
974 IOUtils.setDefaultDirectory(outDirectory);
975
976
977 final List<String> trainFiles = new LinkedList<String>();
978 final List<String> evalFiles = new LinkedList<String>();
979 selectFiles(files, trainFiles, evalFiles);
980
981 final long trainStartTime = System.currentTimeMillis();
982 File currentFile;
983 Document currentDoc;
984 ExtractionContainer currentAnswers;
985
986
987 FieldContainer[] accContainers = null;
988 AccuracyView[] currentAcc;
989 AccuracyView[] globalAcc = null;
990 FieldMap accMap;
991 final boolean storeIterations = (tuneIterations > 1);
992
993
994 double[] currentOverallAcc;
995 double[] lastOverallAcc = null;
996 boolean allAreOptimal;
997 boolean noneGotBetter;
998 boolean someGotWorse;
999 int noneGotBetterCounter = 0;
1000 boolean continueTraining = true;
1001 int evalFileNum, feedbackNum;
1002
1003
1004 final Trainer trainer = initTrainer(outDirectory);
1005 Extractor extractor = null;
1006
1007
1008 FMetrics sentenceMetrics = null;
1009 FMetricsView newSentenceMetrics;
1010
1011 if (trainer.isSentenceFiltering()) {
1012 sentenceMetricsStore = FieldContainer.createFieldContainer();
1013 } else {
1014 sentenceMetricsStore = null;
1015 }
1016
1017
1018 final Results results = new Results();
1019
1020
1021
1022
1023
1024 for (int i = 1; (i <= tuneIterations) && continueTraining; i++) {
1025 if (tuneIterations > 1) {
1026 Util.LOG.debug("Starting TUNE iteration " + i + "/"
1027 + tuneIterations + " (will stop if no improvement for "
1028 + tuneStop + " iterations)");
1029 }
1030
1031
1032
1033
1034
1035
1036
1037
1038 final Iterator trainIter = trainFiles.iterator();
1039 trainer.resetGlobalAccuracy();
1040
1041
1042 if (trainer.isSentenceFiltering()) {
1043 sentenceMetrics = new FMetrics();
1044 }
1045
1046
1047 while (trainIter.hasNext()) {
1048 currentFile = new File(inDirectory, (String) trainIter.next());
1049 Util.LOG.debug("Starting to train " + currentFile);
1050
1051 try {
1052 currentDoc =
1053 DOMUtils.readDocument(currentFile, getConfig());
1054 } catch (DocumentException de) {
1055
1056 throw new ParsingException("Error while parsing "
1057 + currentFile + ": " + de.toString(), de);
1058 }
1059
1060
1061 currentAnswers = AnswerBuilder.readCorrespondingAnswerKeys(
1062 trainer.getTargetStructure(), currentFile, getConfig());
1063
1064
1065 currentAcc = trainer.train(currentDoc, currentAnswers);
1066
1067
1068 if (currentAcc != null) {
1069 globalAcc = trainer.viewGlobalAccuracy();
1070
1071
1072 if (accContainers == null) {
1073 accContainers = new FieldContainer[globalAcc.length];
1074 for (int j = 0; j < globalAcc.length; j++) {
1075 accContainers[j]
1076 = FieldContainer.createFieldContainer();
1077 }
1078 }
1079
1080
1081 for (int j = 0; j < globalAcc.length; j++) {
1082 accMap = globalAcc[j].storeFields();
1083 accMap.putAll(currentAcc[j].storeFields());
1084 accMap.put(Prediction.KEY_SOURCE,
1085 IOUtils.getBaseName(currentFile));
1086 if (storeIterations) {
1087 accMap.put(KEY_ITERATION, new Integer(i));
1088 }
1089 accContainers[j].add(accMap);
1090 }
1091 }
1092 Util.LOG.info("Trained " + currentFile);
1093
1094
1095 if (trainer.isSentenceFiltering()) {
1096 newSentenceMetrics = trainer.evaluateSentenceFiltering();
1097 Util.LOG.debug("Evaluated sentence filtering for trained "
1098 + "document: " + newSentenceMetrics);
1099 sentenceMetrics.update(newSentenceMetrics);
1100 }
1101 }
1102
1103
1104 checkSentenceTraining(trainer, sentenceMetrics, i);
1105
1106
1107 if (globalAcc != null) {
1108 currentOverallAcc = new double[globalAcc.length];
1109 allAreOptimal = true;
1110 noneGotBetter = true;
1111 someGotWorse = false;
1112
1113 for (int j = 0; j < currentOverallAcc.length; j++) {
1114 currentOverallAcc[j] = globalAcc[j].getAccuracy();
1115
1116
1117 allAreOptimal = allAreOptimal
1118 && (currentOverallAcc[j] >= 1.0);
1119
1120
1121 if (lastOverallAcc != null) {
1122 noneGotBetter = noneGotBetter
1123 && (currentOverallAcc[j] <= lastOverallAcc[j]);
1124 someGotWorse = someGotWorse
1125 || (currentOverallAcc[j] < lastOverallAcc[j]);
1126 } else {
1127 noneGotBetter = false;
1128 }
1129 }
1130
1131 if (noneGotBetter) {
1132 noneGotBetterCounter++;
1133
1134 if (noneGotBetterCounter >= tuneStop) {
1135
1136 continueTraining = false;
1137 Util.LOG.debug("Stopping TUNE training after " + i
1138 + " iterations because current accuracies ("
1139 + ArrayUtils.toString(currentOverallAcc)
1140 + ") aren't higher than last ones ("
1141 + ArrayUtils.toString(lastOverallAcc)
1142 + ") for the " + tuneStop + ". time");
1143 } else if (someGotWorse) {
1144
1145 continueTraining = false;
1146 Util.LOG.debug("Stopping TUNE training after " + i
1147 + " iterations because current accuracies ("
1148 + ArrayUtils.toString(currentOverallAcc)
1149 + ") are lower than last ones ("
1150 + ArrayUtils.toString(lastOverallAcc) + ")");
1151 }
1152 }
1153
1154 if (allAreOptimal) {
1155 continueTraining = false;
1156 Util.LOG.debug("Stopping TUNE training after " + i
1157 + " iterations because all accuracies are already "
1158 + "optimal: "
1159 + ArrayUtils.toString(currentOverallAcc));
1160 }
1161
1162 lastOverallAcc = currentOverallAcc;
1163 }
1164
1165 if ((tuneEach && (i >= tuneSince))
1166 || !continueTraining
1167 || (i == tuneIterations)
1168 || tuneEvaluations.contains(Integer.valueOf(i))) {
1169
1170 Util.LOG.info("Finished training using "
1171 + ArrayUtils.toString(trainer.getClassifiers()) + "; "
1172 + Util.showDuration(trainStartTime));
1173 final long evalStartTime = System.currentTimeMillis();
1174 evalFileNum = 0;
1175
1176
1177 if (extractor == null) {
1178 extractor = initExtractor(trainer);
1179 }
1180
1181 final Iterator evalIter = evalFiles.iterator();
1182 final EvaluatedExtractionContainer evaluated =
1183 new EvaluatedExtractionContainer(
1184 extractor.getTargetStructure(), getConfig());
1185 ExtractionContainer currentResults;
1186 MultiFMetricsView interimMetrics;
1187 String currentType;
1188 Iterator typeIter;
1189
1190
1191 if (extractor.isSentenceFiltering()) {
1192 sentenceMetrics = new FMetrics();
1193 }
1194
1195
1196 while (evalIter.hasNext()) {
1197 currentFile = new File(inDirectory,
1198 (String) evalIter.next());
1199 evalFileNum++;
1200 Util.LOG.debug("Starting to extract and evaluate file #"
1201 + evalFileNum + ": " + currentFile);
1202
1203 try {
1204 currentDoc = DOMUtils.readDocument(currentFile,
1205 getConfig());
1206 } catch (DocumentException de) {
1207
1208 throw new ParsingException(de);
1209 }
1210
1211
1212 currentResults = extractor.extract(currentDoc);
1213
1214
1215 currentAnswers = AnswerBuilder.readCorrespondingAnswerKeys(
1216 trainer.getTargetStructure(), currentFile, getConfig());
1217
1218
1219 if (feedback) {
1220 currentAcc = trainer.train(currentDoc, currentAnswers);
1221 }
1222
1223
1224
1225
1226 if (extractor.isSentenceFiltering()) {
1227 newSentenceMetrics =
1228 extractor.evaluateSentenceFiltering(currentAnswers);
1229 Util.LOG.debug("Evaluated sentence filtering for "
1230 + "current document: " + newSentenceMetrics);
1231 sentenceMetrics.update(newSentenceMetrics);
1232 }
1233
1234
1235 evaluated.evaluateBatch(currentResults, currentAnswers,
1236 IOUtils.getBaseName(currentFile));
1237
1238
1239 interimMetrics = evaluated.viewMetrics();
1240 Util.LOG.info("Extracted and evaluated " + currentFile
1241 + ", interim results: " + interimMetrics.viewAll());
1242 typeIter = interimMetrics.types().iterator();
1243
1244 while (typeIter.hasNext()) {
1245 currentType = (String) typeIter.next();
1246 Util.LOG.debug("Interim results for " + currentType
1247 + ": " + interimMetrics.view(currentType));
1248 }
1249
1250
1251
1252 if (feedback && ((evalFileNum % MEASURE_FEEDBACK == 0)
1253 || !evalIter.hasNext())) {
1254 feedbackNum = new Double(Math.ceil((double) evalFileNum
1255 / MEASURE_FEEDBACK)).intValue();
1256
1257
1258 if (feedbackAverages.size() < feedbackNum) {
1259 feedbackAverages.add(new MultiFMetrics(true));
1260 }
1261
1262
1263 feedbackAverages.get(feedbackNum - 1).update(
1264 evaluated.viewMetrics());
1265 }
1266 }
1267
1268 serializeExtractions(outDirectory, baseName, i, evaluated,
1269 sentenceMetrics);
1270 results.addEvaluated(i, evaluated);
1271 Util.LOG.info("Finished extraction and evaluation using "
1272 + ArrayUtils.toString(extractor.getClassifiers()) + "; "
1273 + Util.showDuration(evalStartTime));
1274 }
1275
1276
1277 lastUsedTuneIteration = Math.max(lastUsedTuneIteration, i);
1278 }
1279
1280
1281 serializeTrainingMetrics(outDirectory, baseName, accContainers,
1282 globalAcc);
1283 serializeMetrics(outDirectory, baseName, writer, results);
1284 results.setTrainFeatureCV(trainer.viewFeatureCount());
1285 results.setExtractFeatureCV(extractor.viewFeatureCount());
1286 return results;
1287 }
1288
1289 /***
1290 * Adds the results of a training or evaluation iteration to the statistics
1291 * calculated for sentence filtering.
1292 *
1293 * @param newMetrics the metrics to add
1294 * @param isTraining whether this is a training or an evaluation result
1295 * @param iteration the number of the iteration (counting from 1)
1296 */
1297 private void updateSentenceMetricsStore(final FMetricsView newMetrics,
1298 final boolean isTraining, final int iteration) {
1299 final FieldMap fields = newMetrics.storeFields();
1300
1301
1302 fields.put(KEY_TYPE, isTraining ? TYPE_TRAIN : TYPE_EVAL);
1303 fields.put(KEY_ITERATION, new Integer(iteration));
1304 sentenceMetricsStore.add(fields);
1305 }
1306
1307 }