1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package de.fu_berlin.ties.extract;
23
24 import java.io.File;
25 import java.io.IOException;
26 import java.io.Writer;
27 import java.util.Set;
28
29 import org.apache.commons.lang.ArrayUtils;
30 import org.apache.commons.lang.StringUtils;
31 import org.apache.commons.lang.builder.ToStringBuilder;
32 import org.dom4j.Document;
33 import org.dom4j.Element;
34
35 import de.fu_berlin.ties.combi.CombinationState;
36 import de.fu_berlin.ties.combi.CombinationStrategy;
37 import de.fu_berlin.ties.ContextMap;
38 import de.fu_berlin.ties.ProcessingException;
39 import de.fu_berlin.ties.TiesConfiguration;
40
41 import de.fu_berlin.ties.classify.Prediction;
42 import de.fu_berlin.ties.classify.PredictionDistribution;
43 import de.fu_berlin.ties.classify.Probability;
44 import de.fu_berlin.ties.classify.TrainableClassifier;
45 import de.fu_berlin.ties.context.Representation;
46 import de.fu_berlin.ties.eval.Accuracy;
47 import de.fu_berlin.ties.eval.AccuracyView;
48 import de.fu_berlin.ties.eval.FMetricsView;
49 import de.fu_berlin.ties.filter.EmbeddingElements;
50 import de.fu_berlin.ties.filter.FilteringTokenWalker;
51 import de.fu_berlin.ties.filter.Oracle;
52 import de.fu_berlin.ties.filter.TrainableFilter;
53 import de.fu_berlin.ties.filter.TrainableFilteringTokenWalker;
54 import de.fu_berlin.ties.text.TokenDetails;
55 import de.fu_berlin.ties.text.TokenizerFactory;
56 import de.fu_berlin.ties.util.CollectionUtils;
57 import de.fu_berlin.ties.util.Util;
58 import de.fu_berlin.ties.xml.dom.DOMUtils;
59
60 /***
61 * A trainer trains a local {@link de.fu_berlin.ties.classify.Classifier}
62 * to be used for extraction.
63 *
64 * <p>Instances of this class are not thread-safe and cannot handle training on
65 * several documents in parallel.
66 *
67 * @author Christian Siefkes
68 * @version $Revision: 1.49 $, $Date: 2004/12/07 12:01:48 $, $Author: siefkes $
69 */
70 public class Trainer extends ExtractorBase implements Oracle {
71
72 /***
73 * Configuration key for determining the training mode
74 * ({@link #isTrainingOnlyErrors()}).
75 */
76 public static final String CONFIG_TOE = "train.only-errors";
77
78 /***
79 * Configuration key determining whether the trainer only ensures that all
80 * answer keys exist and can be located in the document instead of doing
81 * any training.
82 */
83 public static final String CONFIG_TEST_ONLY = "train.test-only";
84
85 /***
86 * Prefix used for serializing the global (overall) accuracy.
87 */
88 public static final String PREFIX_GLOBAL_ACC = "Overall ";
89
90 /***
91 * Prefix used for serializing the local (document-specific) accuracy.
92 */
93 public static final String PREFIX_LOCAL_ACC = "Document ";
94
95 /***
96 * Training the embedded sentence filter (if used) can be disabled by
97 * setting this to <code>false</code>.
98 */
99 private boolean sentenceTrainingEnabled = true;
100
101 /***
102 * The trainalbe classifier(s) used for the local classification decisions.
103 */
104 private final TrainableClassifier[] trainableClassifiers;
105
106 /***
107 * Whether to train only errors (TOE mode, recommmended) or to train all
108 * instances (brute-force mode).
109 */
110 private final boolean trainingOnlyErrors;
111
112 /***
113 * If <code>true</code> the trainer only ensures that all answer keys exist
114 * and can be located in the document instead of doing any training.
115 */
116 private final boolean testingOnly;
117
118 /***
119 * The token accuracy of all documents trained so far by each classifier,
120 * measured in {@linkplain #isTrainingOnlyErrors() TOE mode}.
121 */
122 private Accuracy[] globalAccuracies = null;
123
124 /***
125 * The token accuracy of the current document so far by each classifier,
126 * measured in {@linkplain #isTrainingOnlyErrors() TOE mode}.
127 */
128 private Accuracy[] localAccuracies = null;
129
130 /***
131 * Used to locate extractions.
132 */
133 private ExtractionLocator locator;
134
135 /***
136 * Used to determine which elements contain extractions if sentence
137 * filtering is used (first step of a double classification approach).
138 */
139 private EmbeddingElements embeddingElements;
140
141 /***
142 * A copy of the current extraction that contains only the part that has
143 * already been recognized, to be included in the context representation.
144 */
145 private Extraction partialExtraction;
146
147 /***
148 * Creates a new instance without specifying an output extension (which
149 * isn't needed anyway, because this class doesn't produce output).
150 * Delegates to {@link #Trainer(String)} using a dummy
151 * extension.
152 *
153 * @throws IllegalArgumentException if the combination strategy cannot be
154 * initialized
155 * (cf. {@link CombinationStrategy#createStrategy(java.util.Set,
156 * TiesConfiguration)})
157 * @throws ProcessingException if an error occurs during initialization
158 */
159 public Trainer() throws IllegalArgumentException, ProcessingException {
160 this("tmp");
161 }
162
163 /***
164 * Creates a new instance. Delegates to
165 * {@link #Trainer(String, TiesConfiguration)} using the
166 * {@linkplain TiesConfiguration#CONF standard configuration}.
167 *
168 * @param outExt the extension to use for output files
169 * @throws IllegalArgumentException if the combination strategy cannot be
170 * initialized
171 * (cf. {@link CombinationStrategy#createStrategy(java.util.Set,
172 * TiesConfiguration)})
173 * @throws ProcessingException if an error occurs during initialization
174 */
175 public Trainer(final String outExt)
176 throws IllegalArgumentException, ProcessingException {
177 this(outExt, TiesConfiguration.CONF);
178 }
179
180 /***
181 * Creates a new instance. Delegates to the
182 * {@link #Trainer(String, File, TiesConfiguration)} constructor without
183 * specifying a <code>runDirectory</code>.
184 *
185 * @param outExt the extension to use for output files
186 * @param config the configuration to use
187 * @throws IllegalArgumentException if the combination strategy cannot be
188 * initialized
189 * (cf. {@link CombinationStrategy#createStrategy(java.util.Set,
190 * TiesConfiguration)})
191 * @throws ProcessingException if an error occurs during initialization
192 */
193 public Trainer(final String outExt, final TiesConfiguration config)
194 throws IllegalArgumentException, ProcessingException {
195 this(outExt, null, config);
196 }
197
198 /***
199 * Creates a new instance. Sets the training mode
200 * ({@link #isTrainingOnlyErrors()}) to the value of the {@link #CONFIG_TOE}
201 * configuration key in the provided configuration and delegates to the
202 * corresponding {@linkplain
203 * ExtractorBase#ExtractorBase(String, TiesConfiguration) super constructor}
204 * to configure the fields.
205 *
206 * @param outExt the extension to use for output files
207 * @param runDirectory the directory to run the classifier in; used instead
208 * of the
209 * {@linkplain de.fu_berlin.ties.classify.ExternalClassifier#CONFIG_DIR
210 * configured directory} if not <code>null</code>
211 * @param config the configuration to use
212 * @throws IllegalArgumentException if the combination strategy cannot be
213 * initialized
214 * (cf. {@link CombinationStrategy#createStrategy(java.util.Set,
215 * TiesConfiguration)})
216 * @throws ProcessingException if an error occurs during initialization
217 */
218 public Trainer(final String outExt, final File runDirectory,
219 final TiesConfiguration config)
220 throws IllegalArgumentException, ProcessingException {
221 super(outExt, runDirectory, config);
222 trainingOnlyErrors = config.getBoolean(CONFIG_TOE);
223 testingOnly = config.getBoolean(CONFIG_TEST_ONLY);
224
225
226 trainableClassifiers = new TrainableClassifier[getClassifiers().length];
227
228 for (int i = 0; i < trainableClassifiers.length; i++) {
229 trainableClassifiers[i] =
230 (TrainableClassifier) getClassifiers()[i];
231 }
232
233
234 resetGlobalAccuracy();
235 }
236
237 /***
238 * Creates a new instance, using the
239 * {@link TiesConfiguration#CONF standard configuration} to configure the
240 * training mode and the superclasses.
241 *
242 * @param outExt the extension to use for output files
243 * @param targetStruct the target structure specifying the classes to
244 * recognize
245 * @param theClassifiers the array of classifiers to train
246 * @param theRepresentation the context representation to use training
247 * @param combiStrat the combination strategy to use
248 * @param tFactory used to instantiate tokenizers
249 * @param sentFilter the filter used in the first step of a double
250 * classification approach ("sentence filtering"); if <code>null</code>,
251 * no sentence filtering is used
252 */
253 public Trainer(final String outExt, final TargetStructure targetStruct,
254 final TrainableClassifier[] theClassifiers,
255 final Representation theRepresentation,
256 final CombinationStrategy combiStrat,
257 final TokenizerFactory tFactory, final TrainableFilter sentFilter) {
258 this(outExt, targetStruct, theClassifiers, theRepresentation,
259 combiStrat, tFactory, sentFilter,
260 CollectionUtils.arrayAsSet(TiesConfiguration.CONF
261 .getStringArray(CONFIG_RELEVANT_PUNCTUATION)),
262 TiesConfiguration.CONF.getBoolean(CONFIG_TOE),
263 TiesConfiguration.CONF.getBoolean(CONFIG_TEST_ONLY),
264 TiesConfiguration.CONF);
265 }
266
267 /***
268 * Creates a new instance.
269 *
270 * @param outExt the extension to use for output files
271 * @param targetStruct the target structure specifying the classes to
272 * recognize
273 * @param theClassifiers the array of classifiers to train
274 * @param theRepresentation the context representation to use training
275 * @param combiStrat the combination strategy to use
276 * @param tFactory used to instantiate tokenizers
277 * @param sentFilter the filter used in the first step of a double
278 * classification approach ("sentence filtering"); if <code>null</code>,
279 * no sentence filtering is used
280 * @param relevantPunct a set of punctuation tokens that have been found to
281 * be relevant for token classification; might be empty but not
282 * <code>null</code>
283 * @param trainOnlyErrors whether to train only errors (TOE mode,
284 * recommmended) or to train all instances (brute-force mode)
285 * @param testOnly if <code>true</code> the trainer only ensures that all
286 * answer keys exist and can be located in the document instead of doing
287 * any training
288 * @param config used to configure superclasses; if <code>null</code>,
289 * the {@linkplain TiesConfiguration#CONF standard configuration} is used
290 */
291 public Trainer(final String outExt, final TargetStructure targetStruct,
292 final TrainableClassifier[] theClassifiers,
293 final Representation theRepresentation,
294 final CombinationStrategy combiStrat,
295 final TokenizerFactory tFactory, final TrainableFilter sentFilter,
296 final Set<String> relevantPunct,
297 final boolean trainOnlyErrors, final boolean testOnly,
298 final TiesConfiguration config) {
299 super(outExt, targetStruct, theClassifiers, theRepresentation,
300 combiStrat, tFactory, sentFilter, relevantPunct, config);
301 trainingOnlyErrors = trainOnlyErrors;
302 testingOnly = testOnly;
303 trainableClassifiers = theClassifiers;
304
305
306 resetGlobalAccuracy();
307 }
308
309
310 /***
311 * {@inheritDoc}
312 */
313 protected FilteringTokenWalker createFilteringTokenWalker(
314 final TrainableFilter repFilter) {
315
316 return new TrainableFilteringTokenWalker(this, getFactory(),
317 repFilter, this, this, sentenceTrainingEnabled);
318 }
319
320 /***
321 * Disables training the embedded sentence filter, if sentence filtering is
322 * used.
323 */
324 public void disableSentenceTraining() {
325 sentenceTrainingEnabled = false;
326 }
327
328 /***
329 * Re-enables training the embedded filter, if sentence filtering is
330 * used.
331 */
332 public void enableSentenceTraining() {
333 sentenceTrainingEnabled = true;
334 }
335
336 /***
337 * Evaluates precision and recall for {@linkplain #isSentenceFiltering()
338 * sentence filtering} on the last processed document.
339 *
340 * @return the calculated statistics for sentence filtering on the
341 * last document; <code>null</code> if {@linkplain
342 * #isSentenceFiltering() sentence filtering} is disabled
343 */
344 public FMetricsView evaluateSentenceFiltering() {
345
346 return evaluateSentenceFiltering(embeddingElements);
347 }
348
349 /***
350 * Returns the trainable classifiers used for the local classification
351 * decisions. Delegates to {@link ExtractorBase#getClassifiers()} and
352 * casts to result (the constructor ensures that only
353 * {@link TrainableClassifier}s are accepted).
354 *
355 * @return the trainable classifier
356 */
357 private TrainableClassifier[] getTrainableClassifiers() {
358 return trainableClassifiers;
359 }
360
361 /***
362 * Initializes accuracy statistics for each classifier.
363 *
364 * @param prefix the prefix to use for these accuracy statistics
365 * @return an array of ccuracy statistics, one for each classifier
366 */
367 private Accuracy[] initAccuracies(final String prefix) {
368 final Accuracy[] result = new Accuracy[getClassifiers().length];
369
370 for (int i = 0; i < result.length; i++) {
371 result[i] = new Accuracy(prefix);
372 }
373
374 return result;
375 }
376
377 /***
378 * If <code>true</code> the trainer only ensures that all answer keys exist
379 * and can be located in the document instead of doing any training.
380 * @return the value of the attribute
381 */
382 public boolean isTestingOnly() {
383 return testingOnly;
384 }
385
386 /***
387 * Whether to train only errors (TOE mode, recommmended) or to train all
388 * instances (brute-force mode).
389 * @return the value of the attribute
390 */
391 public boolean isTrainingOnlyErrors() {
392 return trainingOnlyErrors;
393 }
394
395 /***
396 * Trains the local classifier with the correct extractions of an XML
397 * document, using the provided context representation. In
398 * {@linkplain #isTrainingOnlyErrors() TOE mode}, training statistics
399 * are serialized to the <code>writer</code>. The answer keys must be
400 * in a corresponding file ending in {@link AnswerBuilder#EXT_ANSWERS} in
401 * the same directory (when processing a local file) or in the current
402 * working directory (when processin an URL).
403 *
404 * @param document the document to read
405 * @param writer ignored by this method
406 * @param context a map of objects that are made available for processing
407 * @throws IOException if an I/O error occurs
408 * @throws ProcessingException if an error occurs during processing
409 */
410 public void process(final Document document, final Writer writer,
411 final ContextMap context) throws IOException, ProcessingException {
412
413 final ExtractionContainer answerKeys =
414 AnswerBuilder.readCorrespondingAnswerKeys(getTargetStructure(),
415 new File((File) context.get(KEY_DIRECTORY),
416 (String) context.get(KEY_LOCAL_NAME)),
417 getConfig());
418
419
420 train(document, answerKeys);
421 }
422
423 /***
424 * {@inheritDoc}
425 */
426 public void processToken(final Element element, final String left,
427 final TokenDetails details, final String right,
428 final ContextMap context) throws ProcessingException {
429
430 updateState(element, left, details.getToken(), right);
431 final String currentType;
432 final boolean startOfExtraction =
433 locator.startOfExtraction(details.getToken(), details.getRep());
434 final boolean endOfExtraction;
435
436 if (startOfExtraction) {
437 Util.LOG.debug("Starting extraction (" + details.getToken()
438 + " token)");
439 }
440
441 if (locator.inExtraction()) {
442
443 boolean updatedExtraction =
444 locator.updateExtraction(details.getToken(), details.getRep());
445
446 if (!updatedExtraction) {
447
448 currentType = null;
449 } else if (startOfExtraction) {
450
451
452 currentType = locator.getCurrentExtraction().getType();
453 final TokenDetails newDetails =
454 new TokenDetails(details.getToken(),
455 locator.getCurrentExtraction().getFirstTokenRep(),
456 locator.getCurrentExtraction().getIndex(), false);
457 partialExtraction = new Extraction(currentType, newDetails);
458 getPriorRecognitions().add(partialExtraction);
459 } else {
460
461 currentType = locator.getCurrentExtraction().getType();
462
463
464 partialExtraction.addToken(details, true);
465 }
466
467
468 endOfExtraction = locator.endOfExtraction();
469 } else {
470
471 currentType = null;
472 endOfExtraction = false;
473 }
474
475 final CombinationState currentState = (currentType == null)
476 ? CombinationState.OUTSIDE
477 : new CombinationState(currentType, startOfExtraction,
478 endOfExtraction);
479
480 if ((currentType == null) && (partialExtraction != null)
481 && (!partialExtraction.isSealed())) {
482
483 partialExtraction.setSealed(true);
484 }
485
486
487 boolean relevant = isRelevant(details.getToken());
488
489 if (!relevant && locator.inExtraction()
490 && (startOfExtraction || endOfExtraction)) {
491
492 markRelevant(details.getToken());
493 relevant = true;
494
495 Util.LOG.debug("Marked punctuation token " + details.getToken()
496 + " as relevant since it is the "
497 + (startOfExtraction ? "first" : "last" )
498 + " token of a " + currentType + " extraction");
499 }
500
501
502 if (relevant) {
503 final String[] translatedStates =
504 getStrategy().translateCurrentState(currentState);
505 Util.LOG.debug("Current state: " + currentState
506 + "; translated state: '"
507 + ArrayUtils.toString(translatedStates) + "'");
508 TrainableClassifier[] classifiers = getTrainableClassifiers();
509
510 for (int i = 0; i < translatedStates.length; i++) {
511 if (!testingOnly) {
512
513 if (trainingOnlyErrors) {
514 final PredictionDistribution predDist =
515 classifiers[i].trainOnError(getFeatures(),
516 translatedStates[i], getActiveClasses()[i]);
517
518
519 if (predDist == null) {
520
521 localAccuracies[i].incTrueCount();
522 globalAccuracies[i].incTrueCount();
523 } else {
524
525 localAccuracies[i].incFalseCount();
526 globalAccuracies[i].incFalseCount();
527 final Prediction prediction = predDist.best();
528 Util.LOG.debug(DOMUtils.showToken(element,
529 details.getToken()) + " classifier " + i
530 + " misclassified " + prediction
531 + " instead of " + translatedStates[i]);
532 }
533 } else {
534 classifiers[i].train(getFeatures(),
535 translatedStates[i]);
536 Util.LOG.debug("Trained classifier");
537 }
538 } else {
539
540 final PredictionDistribution[] predDists =
541 new PredictionDistribution[translatedStates.length];
542
543 for (int j = 0; j < predDists.length; j++) {
544 predDists[j] = new PredictionDistribution(
545 new Prediction(translatedStates[j],
546 new Probability(1.0)));
547 }
548
549 final CombinationState retrans =
550 getStrategy().translateResult(predDists);
551
552
553
554 if ((!StringUtils.equals(currentState.getType(),
555 retrans.getType()))
556 || ((currentState.getType() != null)
557 && (currentState.isBegin() != retrans.isBegin()))) {
558 Util.LOG.error("Error in combination strategy:"
559 + " incorrect re-translation " + retrans
560 + " of current state " + currentState);
561 }
562 }
563 }
564
565
566 getStrategy().updateState(currentState);
567 } else {
568 Util.LOG.debug("Skipping over irrelevant punctuation token "
569 + details.getToken());
570 }
571
572 if (locator.endOfExtraction()) {
573
574 locator.switchToNextExtraction();
575 }
576 }
577
578 /***
579 * Resets the internal classifer, completely deleting the prediction model.
580 * @throws ProcessingException if an error occurs during reset
581 */
582 public void reset() throws ProcessingException {
583 TrainableClassifier[] classifiers = getTrainableClassifiers();
584 for (int i = 0; i < classifiers.length; i++) {
585 classifiers[i].reset();
586 }
587 }
588
589 /***
590 * Resets the global (overall) accuracies measured so far by each
591 * classifier. This can be used to restart accuracy measurements after
592 * each round (iteration) of TUNE training, for example. This method
593 * is only relevant in TOE (training-only/mainly-errors) mode, otherwise
594 * it does nothing.
595 */
596 public void resetGlobalAccuracy() {
597 if (trainingOnlyErrors) {
598 globalAccuracies = initAccuracies(PREFIX_GLOBAL_ACC);
599 }
600 }
601
602 /***
603 * Reset the combination strategy, logging a warning if it tells me to
604 * discard the last extraction.
605 */
606 protected void resetStrategy() {
607 final boolean discardLast = getStrategy().reset();
608 if (discardLast) {
609 Util.LOG.warn("Combination strategy " + getStrategy()
610 + " ordered to discard the last extraction -- "
611 + "this is not supposed to happen when training");
612 }
613 }
614
615 /***
616 * {@inheritDoc}
617 */
618 public boolean shouldMatch(final Element element) {
619
620
621 return embeddingElements.containsExtraction(element);
622 }
623
624 /***
625 * Returns a string representation of this object.
626 * @return a textual representation
627 */
628 public String toString() {
629 final ToStringBuilder result = new ToStringBuilder(this)
630 .appendSuper(super.toString())
631 .append("training only errors", trainingOnlyErrors)
632 .append("testing only", testingOnly)
633 .append("locator", locator);
634
635 if (getSentenceFilter() != null) {
636 result.append("sentence training enabled", sentenceTrainingEnabled);
637 }
638
639 return result.toString();
640 }
641
642 /***
643 * Trains the local classifier with the correct extractions of an XML
644 * document, using the provided context representation.
645 *
646 * @param document a document whose contents should be classified
647 * @param correctExtractions a container of all correct extractions for the
648 * document
649 * @return The token accuracies of each classifier of the trained document
650 * if in {@linkplain #isTrainingOnlyErrors() TOE mode}; <code>null</code>
651 * otherwise
652 * @throws IOException if an I/O error occurs
653 * @throws ProcessingException if an error occurs during processing
654 */
655 public Accuracy[] train(final Document document,
656 final ExtractionContainer correctExtractions)
657 throws IOException, ProcessingException {
658
659 initFields();
660 locator = new ExtractionLocator(document, correctExtractions,
661 getFactory().createTokenizer(""));
662
663
664 if (isSentenceFiltering()) {
665 embeddingElements = new EmbeddingElements(document,
666 correctExtractions, getFactory());
667 }
668
669
670 if (trainingOnlyErrors) {
671 localAccuracies = initAccuracies(PREFIX_LOCAL_ACC);
672 }
673
674
675 getWalker().walk(document, null);
676
677
678 locator.reachedEndOfDocument();
679
680 if (trainingOnlyErrors) {
681
682 Util.LOG.debug("Finished training in TOE mode: "
683 + ArrayUtils.toString(localAccuracies) + ", "
684 + ArrayUtils.toString(globalAccuracies));
685 return localAccuracies;
686 } else {
687 return null;
688 }
689 }
690
691 /***
692 * Returns a view on the global (overall) accuracies measured so far (or
693 * after the last call to {@link #resetGlobalAccuracy()}) by
694 * each classifier. This is not a snapshot but will change whenever the
695 * underlying values are changed.
696 *
697 * @return A view on the global accuracies measured so far if in
698 * {@linkplain #isTrainingOnlyErrors() TOE mode}; <code>null</code>
699 * otherwise
700 */
701 public AccuracyView[] viewGlobalAccuracy() {
702 return globalAccuracies;
703 }
704
705 }