1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package de.fu_berlin.ties.extract;
23
24 import java.io.File;
25 import java.io.IOException;
26 import java.io.Writer;
27 import java.util.Set;
28
29 import org.apache.commons.lang.ArrayUtils;
30 import org.apache.commons.lang.StringUtils;
31 import org.dom4j.Document;
32 import org.dom4j.Element;
33
34 import de.fu_berlin.ties.combi.CombinationState;
35 import de.fu_berlin.ties.combi.CombinationStrategy;
36 import de.fu_berlin.ties.ContextMap;
37 import de.fu_berlin.ties.ProcessingException;
38 import de.fu_berlin.ties.TiesConfiguration;
39
40 import de.fu_berlin.ties.classify.Prediction;
41 import de.fu_berlin.ties.classify.PredictionDistribution;
42 import de.fu_berlin.ties.classify.Probability;
43 import de.fu_berlin.ties.classify.TrainableClassifier;
44 import de.fu_berlin.ties.context.ContextDetails;
45 import de.fu_berlin.ties.context.Representation;
46 import de.fu_berlin.ties.eval.Accuracy;
47 import de.fu_berlin.ties.eval.AccuracyView;
48 import de.fu_berlin.ties.eval.FMetricsView;
49 import de.fu_berlin.ties.extract.amend.FinalReextractor;
50 import de.fu_berlin.ties.extract.reestimate.Reestimator;
51 import de.fu_berlin.ties.filter.DocumentRewriter;
52 import de.fu_berlin.ties.filter.EmbeddingElements;
53 import de.fu_berlin.ties.filter.FilteringTokenWalker;
54 import de.fu_berlin.ties.filter.Oracle;
55 import de.fu_berlin.ties.filter.TrainableFilter;
56 import de.fu_berlin.ties.filter.TrainableFilteringTokenWalker;
57 import de.fu_berlin.ties.text.TokenDetails;
58 import de.fu_berlin.ties.text.TokenizerFactory;
59 import de.fu_berlin.ties.util.CollUtils;
60 import de.fu_berlin.ties.util.Util;
61
62 /***
63 * A trainer trains a local {@link de.fu_berlin.ties.classify.Classifier}
64 * to be used for extraction.
65 *
66 * <p>Instances of this class are not thread-safe and cannot handle training on
67 * several documents in parallel.
68 *
69 * @author Christian Siefkes
70 * @version $Revision: 1.72 $, $Date: 2006/10/21 16:04:14 $, $Author: siefkes $
71 */
72 public class Trainer extends ExtractorBase implements Oracle {
73
74 /***
75 * Configuration key for determining the training mode
76 * ({@link #isTrainingOnlyErrors()}).
77 */
78 public static final String CONFIG_TOE = "train.only-errors";
79
80 /***
81 * Configuration key determining whether the trainer only ensures that all
82 * answer keys exist and can be located in the document instead of doing
83 * any training.
84 */
85 public static final String CONFIG_TEST_ONLY = "train.test-only";
86
87 /***
88 * Prefix used for serializing the global (overall) accuracy.
89 */
90 public static final String PREFIX_GLOBAL_ACC = "Overall ";
91
92 /***
93 * Prefix used for serializing the local (document-specific) accuracy.
94 */
95 public static final String PREFIX_LOCAL_ACC = "Document ";
96
97 /***
98 * Training the embedded sentence filter (if used) can be disabled by
99 * setting this to <code>false</code>.
100 */
101 private boolean sentenceTrainingEnabled = true;
102
103 /***
104 * The trainable classifier(s) used for the local classification decisions.
105 */
106 private final TrainableClassifier[] trainableClassifiers;
107
108 /***
109 * Whether to train only errors (TOE mode, recommended) or to train all
110 * instances (brute-force mode).
111 */
112 private final boolean trainingOnlyErrors;
113
114 /***
115 * If <code>true</code> the trainer only ensures that all answer keys exist
116 * and can be located in the document instead of doing any training.
117 */
118 private final boolean testingOnly;
119
120 /***
121 * The token accuracy of all documents trained so far by each classifier,
122 * measured in {@linkplain #isTrainingOnlyErrors() TOE mode}.
123 */
124 private Accuracy[] globalAccuracies = null;
125
126 /***
127 * The token accuracy of the current document so far by each classifier,
128 * measured in {@linkplain #isTrainingOnlyErrors() TOE mode}.
129 */
130 private Accuracy[] localAccuracies = null;
131
132 /***
133 * Used to locate extractions.
134 */
135 private ExtractionLocator locator;
136
137 /***
138 * Used to determine which elements contain extractions if sentence
139 * filtering is used (first step of a double classification approach).
140 */
141 private EmbeddingElements embeddingElements;
142
143 /***
144 * A copy of the current extraction that contains only the part that has
145 * already been recognized, to be included in the context representation.
146 */
147 private Extraction partialExtraction;
148
/**
 * Creates a trainer without a meaningful output extension. Training does
 * not produce output files, so the placeholder extension
 * <code>"tmp"</code> is simply handed on to {@link #Trainer(String)}.
 *
 * @throws IllegalArgumentException if the combination strategy cannot be
 * initialized
 * (cf. {@link CombinationStrategy#createStrategy(java.util.Set,
 * TiesConfiguration)})
 * @throws ProcessingException if an error occurs during initialization
 */
public Trainer() throws IllegalArgumentException, ProcessingException {
    // Dummy extension: required by the constructor chain, never used for
    // writing anything
    this("tmp");
}
164
/**
 * Creates a trainer that is configured from the
 * {@linkplain TiesConfiguration#CONF standard configuration}. All work is
 * done by {@link #Trainer(String, TiesConfiguration)}.
 *
 * @param outExt the extension to use for output files
 * @throws IllegalArgumentException if the combination strategy cannot be
 * initialized
 * (cf. {@link CombinationStrategy#createStrategy(java.util.Set,
 * TiesConfiguration)})
 * @throws ProcessingException if an error occurs during initialization
 */
public Trainer(final String outExt)
throws IllegalArgumentException, ProcessingException {
    // Fall back to the shared standard configuration
    this(outExt, TiesConfiguration.CONF);
}
181
/**
 * Creates a trainer from an explicit configuration, without naming a
 * <code>runDirectory</code>. Forwards to
 * {@link #Trainer(String, File, TiesConfiguration)} with a
 * <code>null</code> directory.
 *
 * @param outExt the extension to use for output files
 * @param config the configuration to use
 * @throws IllegalArgumentException if the combination strategy cannot be
 * initialized
 * (cf. {@link CombinationStrategy#createStrategy(java.util.Set,
 * TiesConfiguration)})
 * @throws ProcessingException if an error occurs during initialization
 */
public Trainer(final String outExt, final TiesConfiguration config)
throws IllegalArgumentException, ProcessingException {
    // The cast only documents which parameter the null stands for
    this(outExt, (File) null, config);
}
199
200 /***
201 * Creates a new instance. Sets the training mode
202 * ({@link #isTrainingOnlyErrors()}) to the value of the {@link #CONFIG_TOE}
203 * configuration key in the provided configuration and delegates to the
204 * corresponding {@linkplain
205 * ExtractorBase#ExtractorBase(String, TiesConfiguration) super constructor}
206 * to configure the fields.
207 *
208 * @param outExt the extension to use for output files
209 * @param runDirectory the directory to run the classifier in; used instead
210 * of the
211 * {@linkplain de.fu_berlin.ties.classify.ExternalClassifier#CONFIG_DIR
212 * configured directory} if not <code>null</code>
213 * @param config the configuration to use
214 * @throws IllegalArgumentException if the combination strategy cannot be
215 * initialized
216 * (cf. {@link CombinationStrategy#createStrategy(java.util.Set,
217 * TiesConfiguration)})
218 * @throws ProcessingException if an error occurs during initialization
219 */
220 public Trainer(final String outExt, final File runDirectory,
221 final TiesConfiguration config)
222 throws IllegalArgumentException, ProcessingException {
223 super(outExt, runDirectory, config);
224 trainingOnlyErrors = config.getBoolean(CONFIG_TOE);
225 testingOnly = config.getBoolean(CONFIG_TEST_ONLY);
226
227
228 trainableClassifiers = new TrainableClassifier[getClassifiers().length];
229
230 for (int i = 0; i < trainableClassifiers.length; i++) {
231 trainableClassifiers[i] =
232 (TrainableClassifier) getClassifiers()[i];
233 }
234
235
236 resetGlobalAccuracy();
237 }
238
/**
 * Creates a trainer from ready-made components. Everything that is not
 * passed in explicitly (document rewriters, relevant punctuation, training
 * mode, test-only flag) is taken from the
 * {@link TiesConfiguration#CONF standard configuration}.
 *
 * @param outExt the extension to use for output files
 * @param targetStruct the target structure specifying the classes to
 * recognize
 * @param theClassifiers the array of classifiers to train
 * @param theRepresentation the context representation to use for training
 * @param combiStrat the combination strategy to use
 * @param reextract an optional re-extractor that can modify extractions in
 * any suitable way
 * @param tFactory used to instantiate tokenizers
 * @param estimator the last element of the re-estimator chain, or
 * <code>null</code> if the chain is empty
 * @param sentFilter the filter used in the first step of a double
 * classification approach ("sentence filtering"); if <code>null</code>,
 * no sentence filtering is used
 * @throws ProcessingException if an error occurs during initialization
 */
public Trainer(final String outExt, final TargetStructure targetStruct,
        final TrainableClassifier[] theClassifiers,
        final Representation theRepresentation,
        final CombinationStrategy combiStrat,
        final FinalReextractor reextract, final TokenizerFactory tFactory,
        final Reestimator estimator, final TrainableFilter sentFilter)
throws ProcessingException {
    this(outExt, targetStruct, theClassifiers, theRepresentation,
        combiStrat, reextract, tFactory, estimator,
        // document rewriters as configured in the standard configuration
        createDocumentRewriters(TiesConfiguration.CONF),
        sentFilter,
        // punctuation tokens configured as relevant
        CollUtils.arrayAsSet(
            TiesConfiguration.CONF.getStringArray(
                CONFIG_RELEVANT_PUNCTUATION)),
        // training mode and test-only flag
        TiesConfiguration.CONF.getBoolean(CONFIG_TOE),
        TiesConfiguration.CONF.getBoolean(CONFIG_TEST_ONLY),
        TiesConfiguration.CONF);
}
276
/**
 * Creates a trainer with every component and flag given explicitly.
 * The component arguments are handed to the superclass constructor; the
 * classifier array is additionally stored as is (it is not copied) for
 * use during training.
 *
 * @param outExt the extension to use for output files
 * @param targetStruct the target structure specifying the classes to
 * recognize
 * @param theClassifiers the array of classifiers to train
 * @param theRepresentation the context representation to use for training
 * @param combiStrat the combination strategy to use
 * @param reextract an optional re-extractor that can modify extractions in
 * any suitable way
 * @param tFactory used to instantiate tokenizers
 * @param estimator the last element of the re-estimator chain, or
 * <code>null</code> if the chain is empty
 * @param docFilters a list (possibly empty) of document processors that are
 * invoked to modify the XML representations of the documents to process
 * @param sentFilter the filter used in the first step of a double
 * classification approach ("sentence filtering"); if <code>null</code>,
 * no sentence filtering is used
 * @param relevantPunct a set of punctuation tokens that have been found to
 * be relevant for token classification; might be empty but not
 * <code>null</code>
 * @param trainOnlyErrors whether to train only errors (TOE mode,
 * recommended) or to train all instances (brute-force mode)
 * @param testOnly if <code>true</code> the trainer only ensures that all
 * answer keys exist and can be located in the document instead of doing
 * any training
 * @param config used to configure superclasses; if <code>null</code>,
 * the {@linkplain TiesConfiguration#CONF standard configuration} is used
 */
public Trainer(final String outExt, final TargetStructure targetStruct,
        final TrainableClassifier[] theClassifiers,
        final Representation theRepresentation,
        final CombinationStrategy combiStrat,
        final FinalReextractor reextract, final TokenizerFactory tFactory,
        final Reestimator estimator, final DocumentRewriter[] docFilters,
        final TrainableFilter sentFilter, final Set<String> relevantPunct,
        final boolean trainOnlyErrors, final boolean testOnly,
        final TiesConfiguration config) {
    super(outExt, targetStruct, theClassifiers, theRepresentation,
        combiStrat, reextract, tFactory, estimator, docFilters,
        sentFilter, relevantPunct, config);

    // Trainer-specific state
    trainableClassifiers = theClassifiers;
    trainingOnlyErrors = trainOnlyErrors;
    testingOnly = testOnly;

    // Start the overall accuracy statistics (only has an effect in TOE mode)
    resetGlobalAccuracy();
}
326
327
328 /***
329 * {@inheritDoc}
330 */
331 protected FilteringTokenWalker createFilteringTokenWalker(
332 final TrainableFilter repFilter) {
333
334 return new TrainableFilteringTokenWalker(this, getFactory(),
335 repFilter, this, this, sentenceTrainingEnabled);
336 }
337
338 /***
339 * Disables training the embedded sentence filter, if sentence filtering is
340 * used.
341 */
342 public void disableSentenceTraining() {
343 sentenceTrainingEnabled = false;
344 }
345
346 /***
347 * Re-enables training the embedded filter, if sentence filtering is
348 * used.
349 */
350 public void enableSentenceTraining() {
351 sentenceTrainingEnabled = true;
352 }
353
354 /***
355 * Evaluates precision and recall for {@linkplain #isSentenceFiltering()
356 * sentence filtering} on the last processed document.
357 *
358 * @return the calculated statistics for sentence filtering on the
359 * last document; <code>null</code> if {@linkplain
360 * #isSentenceFiltering() sentence filtering} is disabled
361 */
362 public FMetricsView evaluateSentenceFiltering() {
363
364 return evaluateSentenceFiltering(embeddingElements);
365 }
366
367 /***
368 * Returns the trainable classifiers used for the local classification
369 * decisions. Delegates to {@link ExtractorBase#getClassifiers()} and
370 * casts to result (the constructor ensures that only
371 * {@link TrainableClassifier}s are accepted).
372 *
373 * @return the trainable classifier
374 */
375 private TrainableClassifier[] getTrainableClassifiers() {
376 return trainableClassifiers;
377 }
378
379 /***
380 * Initializes accuracy statistics for each classifier.
381 *
382 * @param prefix the prefix to use for these accuracy statistics
383 * @return an array of ccuracy statistics, one for each classifier
384 */
385 private Accuracy[] initAccuracies(final String prefix) {
386 final Accuracy[] result = new Accuracy[getClassifiers().length];
387
388 for (int i = 0; i < result.length; i++) {
389 result[i] = new Accuracy(prefix);
390 }
391
392 return result;
393 }
394
395 /***
396 * If <code>true</code> the trainer only ensures that all answer keys exist
397 * and can be located in the document instead of doing any training.
398 * @return the value of the attribute
399 */
400 public boolean isTestingOnly() {
401 return testingOnly;
402 }
403
404 /***
405 * Whether to train only errors (TOE mode, recommmended) or to train all
406 * instances (brute-force mode).
407 * @return the value of the attribute
408 */
409 public boolean isTrainingOnlyErrors() {
410 return trainingOnlyErrors;
411 }
412
413 /***
414 * Trains the local classifier with the correct extractions of an XML
415 * document, using the provided context representation. In
416 * {@linkplain #isTrainingOnlyErrors() TOE mode}, training statistics
417 * are serialized to the <code>writer</code>. The answer keys must be
418 * in a corresponding file ending in {@link AnswerBuilder#EXT_ANSWERS} in
419 * the same directory (when processing a local file) or in the current
420 * working directory (when processin an URL).
421 *
422 * @param document the document to read
423 * @param writer ignored by this method
424 * @param context a map of objects that are made available for processing
425 * @throws IOException if an I/O error occurs
426 * @throws ProcessingException if an error occurs during processing
427 */
428 public void process(final Document document, final Writer writer,
429 final ContextMap context) throws IOException, ProcessingException {
430
431 final ExtractionContainer answerKeys =
432 AnswerBuilder.readCorrespondingAnswerKeys(getTargetStructure(),
433 new File((File) context.get(KEY_DIRECTORY),
434 (String) context.get(KEY_LOCAL_NAME)),
435 getConfig());
436
437
438 final File filename = new File((File) context.get(KEY_DIRECTORY),
439 (String) context.get(KEY_LOCAL_NAME));
440
441
442 train(document, filename, answerKeys);
443 }
444
445 /***
446 * {@inheritDoc}
447 */
448 public void processToken(final Element element, final String left,
449 final TokenDetails details, final String right,
450 final ContextMap context) throws ProcessingException {
451
452 updateState(element, left, details.getToken(), right);
453 final String currentType;
454 final boolean startOfExtraction =
455 locator.startOfExtraction(details.getToken(), details.getRep());
456 final boolean endOfExtraction;
457 final CombinationState currentState;
458
459 if (startOfExtraction) {
460 Util.LOG.debug("Starting extraction (" + details.getToken()
461 + " token)");
462 final Extraction currentExtraction = locator.getCurrentExtraction();
463
464
465 if (currentExtraction.getIndex() != details.getIndex()) {
466 currentExtraction.setIndex(details.getIndex());
467 }
468 }
469
470 if (locator.inExtraction()) {
471
472 boolean updatedExtraction =
473 locator.updateExtraction(details.getToken(), details.getRep());
474
475 if (!updatedExtraction) {
476
477 currentType = null;
478 } else if (startOfExtraction) {
479
480
481 currentType = locator.getCurrentExtraction().getType();
482 final TokenDetails newDetails =
483 new TokenDetails(details.getToken(),
484 locator.getCurrentExtraction().getFirstTokenRep(),
485 locator.getCurrentExtraction().getIndex(), false);
486 partialExtraction = new Extraction(currentType, newDetails);
487 getPriorRecognitions().add(partialExtraction);
488 } else {
489
490 currentType = locator.getCurrentExtraction().getType();
491
492
493 partialExtraction.addToken(details, true);
494 }
495
496
497 endOfExtraction = locator.endOfExtraction();
498 } else {
499
500 currentType = null;
501 endOfExtraction = false;
502 }
503
504 currentState = (currentType == null) ? CombinationState.OUTSIDE
505 : new CombinationState(currentType, startOfExtraction,
506 endOfExtraction, null);
507
508 if ((currentType == null) && (partialExtraction != null)
509 && (!partialExtraction.isSealed())) {
510
511 partialExtraction.setSealed(true);
512 }
513
514
515 boolean relevant = isRelevant(details.getToken());
516
517 if (!relevant && locator.inExtraction()
518 && (startOfExtraction || endOfExtraction)) {
519
520 markRelevant(details.getToken());
521 relevant = true;
522
523 Util.LOG.debug("Marked punctuation token " + details.getToken()
524 + " as relevant since it is the "
525 + (startOfExtraction ? "first" : "last")
526 + " token of a " + currentType + " extraction");
527 }
528
529
530 if (relevant) {
531 final String[] translatedStates =
532 getStrategy().translateCurrentState(currentState);
533
534
535
536 final TrainableClassifier[] classifiers = getTrainableClassifiers();
537 final PredictionDistribution[] predDists =
538 new PredictionDistribution[translatedStates.length];
539
540 for (int i = 0; i < translatedStates.length; i++) {
541 if (!testingOnly) {
542
543 if (trainingOnlyErrors) {
544 final PredictionDistribution predDist =
545 classifiers[i].trainOnError(getFeatures(),
546 translatedStates[i], getActiveClasses()[i]);
547 predDists[i] = predDist;
548
549
550 if (predDist == null) {
551
552 localAccuracies[i].incTrueCount();
553 globalAccuracies[i].incTrueCount();
554 } else {
555
556 localAccuracies[i].incFalseCount();
557 globalAccuracies[i].incFalseCount();
558
559
560
561
562
563 }
564 } else {
565 classifiers[i].train(getFeatures(),
566 translatedStates[i]);
567
568 }
569 } else {
570
571 for (int j = 0; j < predDists.length; j++) {
572 predDists[j] = new PredictionDistribution(
573 new Prediction(translatedStates[j],
574 new Probability(1.0)));
575 }
576
577 final CombinationState retrans =
578 getStrategy().translateResult(predDists, details);
579
580
581
582 if ((!StringUtils.equals(currentState.getType(),
583 retrans.getType()))
584 || ((currentState.getType() != null)
585 && (currentState.isBegin() != retrans.isBegin()))) {
586 Util.LOG.error("Error in combination strategy:"
587 + " incorrect re-translation " + retrans
588 + " of current state " + currentState);
589 }
590 }
591 }
592
593
594 getStrategy().updateState(currentState, predDists, details);
595
596 if ((currentState == CombinationState.OUTSIDE)
597 && (getReestimator() != null)) {
598
599 getReestimator().trainOtherToken(new ContextDetails(details,
600 getFeatures(), currentState, relevant));
601 }
602 }
603
604
605
606
607 if (locator.endOfExtraction()) {
608 final Extraction currentExtraction = locator.getCurrentExtraction();
609
610
611 if (currentExtraction.getLastIndex() != details.getIndex()) {
612 currentExtraction.setLastIndex(details.getIndex());
613 }
614
615 if (getReestimator() != null) {
616
617 getReestimator().train(partialExtraction);
618 }
619
620
621 locator.switchToNextExtraction();
622 }
623
624
625 addContextDetails(new ContextDetails(details, getFeatures(),
626 currentState, relevant));
627 }
628
629 /***
630 * Resets the internal classifer, completely deleting the prediction model.
631 * @throws ProcessingException if an error occurs during reset
632 */
633 public void reset() throws ProcessingException {
634 TrainableClassifier[] classifiers = getTrainableClassifiers();
635 for (int i = 0; i < classifiers.length; i++) {
636 classifiers[i].reset();
637 }
638 }
639
640 /***
641 * Resets the global (overall) accuracies measured so far by each
642 * classifier. This can be used to restart accuracy measurements after
643 * each round (iteration) of TUNE training, for example. This method
644 * is only relevant in TOE (training-only/mainly-errors) mode, otherwise
645 * it does nothing.
646 */
647 public void resetGlobalAccuracy() {
648 if (trainingOnlyErrors) {
649 globalAccuracies = initAccuracies(PREFIX_GLOBAL_ACC);
650 }
651 }
652
653 /***
654 * Reset the combination strategy, logging a warning if it tells me to
655 * discard the last extraction.
656 */
657 protected void resetStrategy() {
658 final boolean discardLast = getStrategy().reset();
659 if (discardLast) {
660 Util.LOG.warn("Combination strategy " + getStrategy()
661 + " ordered to discard the last extraction -- "
662 + "this is not supposed to happen when training");
663 }
664 }
665
666 /***
667 * {@inheritDoc}
668 */
669 public boolean shouldMatch(final Element element) {
670
671
672 return embeddingElements.containsExtraction(element);
673 }
674
675 /***
676 * Trains the local classifier with the correct extractions of an XML
677 * document, using the provided context representation.
678 *
679 * @param doc a document whose contents should be classified
680 * @param filename the name of the document
681 * @param correctExtractions a container of all correct extractions for the
682 * document
683 * @return The token accuracies of each classifier of the trained document
684 * if in {@linkplain #isTrainingOnlyErrors() TOE mode}; <code>null</code>
685 * otherwise
686 * @throws IOException if an I/O error occurs
687 * @throws ProcessingException if an error occurs during processing
688 */
689 public Accuracy[] train(final Document doc, final File filename,
690 final ExtractionContainer correctExtractions)
691 throws IOException, ProcessingException {
692
693 initFields(filename);
694 final Document document = filterDocument(doc, filename);
695 locator = new ExtractionLocator(correctExtractions,
696 getFactory().createTokenizer(""));
697
698
699 if (isSentenceFiltering()) {
700 embeddingElements = new EmbeddingElements(document,
701 correctExtractions, getFactory());
702 }
703
704
705 if (trainingOnlyErrors) {
706 localAccuracies = initAccuracies(PREFIX_LOCAL_ACC);
707 }
708
709
710 getWalker().walk(document, null);
711
712
713 locator.reachedEndOfDocument();
714
715
716 if (getReextractor() != null) {
717 final ContextMap reexContext =
718 getStrategy().contextForReextractor();
719 getReextractor().train(correctExtractions, getContextDetails(),
720 reexContext);
721 }
722
723 if (trainingOnlyErrors) {
724
725 Util.LOG.debug("Finished training in TOE mode: "
726 + ArrayUtils.toString(localAccuracies) + ", "
727 + ArrayUtils.toString(globalAccuracies));
728 return localAccuracies;
729 } else {
730 return null;
731 }
732 }
733
734 /***
735 * Returns a view on the global (overall) accuracies measured so far (or
736 * after the last call to {@link #resetGlobalAccuracy()}) by
737 * each classifier. This is not a snapshot but will change whenever the
738 * underlying values are changed.
739 *
740 * @return A view on the global accuracies measured so far if in
741 * {@linkplain #isTrainingOnlyErrors() TOE mode}; <code>null</code>
742 * otherwise
743 */
744 public AccuracyView[] viewGlobalAccuracy() {
745 return globalAccuracies;
746 }
747
748 }