1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package de.fu_berlin.ties.classify;
23
24 import java.io.File;
25 import java.util.Collections;
26 import java.util.HashSet;
27 import java.util.Iterator;
28 import java.util.Set;
29
30 import org.apache.commons.lang.ArrayUtils;
31 import org.apache.commons.lang.builder.ToStringBuilder;
32 import org.dom4j.Element;
33 import org.dom4j.QName;
34
35 import de.fu_berlin.ties.ContextMap;
36 import de.fu_berlin.ties.ProcessingException;
37 import de.fu_berlin.ties.TiesConfiguration;
38 import de.fu_berlin.ties.classify.feature.FeatureTransformer;
39 import de.fu_berlin.ties.classify.feature.FeatureVector;
40 import de.fu_berlin.ties.classify.winnow.UltraconservativeWinnow;
41 import de.fu_berlin.ties.classify.winnow.Winnow;
42 import de.fu_berlin.ties.io.ObjectElement;
43 import de.fu_berlin.ties.io.XMLStorable;
44 import de.fu_berlin.ties.text.TextUtils;
45 import de.fu_berlin.ties.util.CollUtils;
46 import de.fu_berlin.ties.util.Util;
47 import de.fu_berlin.ties.xml.dom.DOMUtils;
48
49 /***
50 * Classifiers extending this abstract class must provide a training mechanism
51 * by implementing the {@link #doTrain(FeatureVector, String, ContextMap)}
52 * method. This class supports error-driven learning ("train only errors")
53 * which often leads to better prediction models than brute-force training.
54 *
55 * <p>The code in this class is thread-safe.
56 *
57 * @author Christian Siefkes
58 * @version $Revision: 1.50 $, $Date: 2006/11/26 21:14:58 $, $Author: siefkes $
59 */
60 public abstract class TrainableClassifier implements Classifier, XMLStorable {
61
62 /***
63 * Name of the main element used for XML serialization.
64 */
65 public static final QName ELEMENT_MAIN =
66 DOMUtils.defaultName("classifier");
67
68 /***
69 * Attribute name used for XML serialization.
70 */
71 static final QName ATTRIB_CLASSES = DOMUtils.defaultName("classes");
72
73 /***
74 * Attribute name used for XML serialization.
75 */
76 static final QName ATTRIB_TRAIN_ALL = DOMUtils.defaultName("train-all");
77
78 /***
79 * Flag used to load the {@link MetaClassifier}.
80 */
81 public static final String META_CLASSIFIER = "meta";
82
83 /***
84 * Flag used to load the {@link MultiBinaryClassifier}.
85 */
86 public static final String MULTI_CLASSIFIER = "multi";
87
88 /***
89 * Flag used to load the {@link OneAgainstTheRestClassifier}.
90 */
91 public static final String OAR_CLASSIFIER = "oar";
92
93 /***
94 * Flag used to load the {@link TieClassifier}.
95 */
96 public static final String TIE_CLASSIFIER = "tie";
97
/**
 * Unmodifiable set of the specification keywords of all classifiers that
 * wrap inner classifiers.
 */
private static final Set<String> WRAPPING_CLASSIFIERS;

/*
 * Fills the set of wrapping classifier keywords once at class load time.
 */
static {
    final String[] keywords = new String[] {
        META_CLASSIFIER, MULTI_CLASSIFIER, OAR_CLASSIFIER, TIE_CLASSIFIER,
    };
    final Set<String> collected = new HashSet<String>();

    for (final String keyword : keywords) {
        collected.add(keyword);
    }
    WRAPPING_CLASSIFIERS = Collections.unmodifiableSet(collected);
}
115
116 /***
117 * Factory method that delegates to
118 * {@link #createClassifier(Set, TiesConfiguration)} using the
119 * {@linkplain TiesConfiguration#CONF standard configuration}.
120 *
121 * @param allValidClasses the set of all valid classes
122 * @return the created classifier
123 * @throws IllegalArgumentException if the value of the
124 * {@link #CONFIG_CLASSIFIER} key is missing or invalid
125 * @throws ProcessingException if an error occurred while creating the
126 * classifier
127 */
128 public static TrainableClassifier createClassifier(
129 final Set<String> allValidClasses)
130 throws IllegalArgumentException, ProcessingException {
131 return createClassifier(allValidClasses, TiesConfiguration.CONF);
132 }
133
134 /***
135 * Factory method that delegates to
136 * {@link #createClassifier(Set, TiesConfiguration, String)}
137 * without specifying a suffix.
138 *
139 * @param allValidClasses the set of all valid classes
140 * @param conf the configuration to use
141 * @return the created classifier
142 * @throws IllegalArgumentException if the value of the
143 * {@link #CONFIG_CLASSIFIER} key is missing or invalid
144 * @throws ProcessingException if an error occurred while creating the
145 * classifier
146 */
147 public static TrainableClassifier createClassifier(
148 final Set<String> allValidClasses, final TiesConfiguration conf)
149 throws IllegalArgumentException, ProcessingException {
150 return createClassifier(allValidClasses, conf, null);
151 }
152
153 /***
154 * Factory method that delegates to
155 * {@link #createClassifier(Set, File, TiesConfiguration, String)}
156 * without specifying an run directory.
157 *
158 * @param allValidClasses the set of all valid classes
159 * @param conf the configuration to use
160 * @param suffix an optional
161 * {@linkplain TiesConfiguration#adaptKey(String, String) suffix} that is
162 * appended to the {@link Classifier#CONFIG_CLASSIFIER} key if not
163 * <code>null</code>
164 * @return the created classifier
165 * @throws IllegalArgumentException if the value of the
166 * {@link #CONFIG_CLASSIFIER} key is missing or invalid
167 * @throws ProcessingException if an error occurred while creating the
168 * classifier
169 */
170 public static TrainableClassifier createClassifier(
171 final Set<String> allValidClasses, final TiesConfiguration conf,
172 final String suffix)
173 throws IllegalArgumentException, ProcessingException {
174 return createClassifier(allValidClasses, null, conf, suffix);
175 }
176
177 /***
178 * Factory method that delegates to
179 * {@link #createClassifier(Set, File, FeatureTransformer, String[],
180 * TiesConfiguration)}. It reads the specification of the classifier from
181 * the {@link #CONFIG_CLASSIFIER} key in the provided configuration. It
182 * calls {@link FeatureTransformer#createTransformer(TiesConfiguration)} to
183 * initialize a transformer chain, if configured.
184 *
185 * @param allValidClasses the set of all valid classes
186 * @param runDirectory the directory to run the classifier in; used for
187 * {@link ExternalClassifier} instead of the
188 * {@linkplain ExternalClassifier#CONFIG_DIR configured directory}
189 * if not <code>null</code>; ignored otherwise
190 * @param conf the configuration to use
191 * @param suffix an optional
192 * {@linkplain TiesConfiguration#adaptKey(String, String) suffix} that is
193 * appended to the {@link Classifier#CONFIG_CLASSIFIER} key if not
194 * <code>null</code>
195 * @return the created classifier
196 * @throws IllegalArgumentException if the value of the
197 * {@link #CONFIG_CLASSIFIER} key is missing or invalid
198 * @throws ProcessingException if an error occurred while creating the
199 * classifier
200 */
201 public static TrainableClassifier createClassifier(
202 final Set<String> allValidClasses, final File runDirectory,
203 final TiesConfiguration conf, final String suffix)
204 throws IllegalArgumentException, ProcessingException {
205
206 final String[] spec =
207 conf.getStringArray(conf.adaptKey(CONFIG_CLASSIFIER, suffix));
208
209
210 final FeatureTransformer trans =
211 FeatureTransformer.createTransformer(conf);
212 return createClassifier(allValidClasses, runDirectory, trans,
213 spec, conf);
214 }
215
216 /***
217 * Factory method that creates a trainable classifier based on the
218 * provided specification.
219 *
220 * <p>Currently supported values in the first element of the specification:
221 *
222 * <ul>
223 * <li>"Ext" for {@link ExternalClassifier}
224 * <li>"Winnow" for {@link Winnow}
225 * <li>"ucWinnow" for {@link UltraconservativeWinnow}
226 * <li>"Moon" for the {@link MoonClassifier}
227 * <li>"Tie" or "Meta" followed by the specification of the inner
228 * classifiers as further element(s) for {@link TieClassifier} resp.
229 * {@link MetaClassifier}
230 * <li>"Multi" or "OAR" followed by the specification of the inner
231 * classifiers as further element(s) for {@link MultiBinaryClassifier} resp.
232 * {@link OneAgainstTheRestClassifier} (if there are only two classes
233 * to classify, the outer classifer is skipped and the inner classifier is
234 * used directly).
235 * </ul>
236 *
237 * <p>Otherwise the first element must be the qualified name of a
238 * TrainableClassifier subclass accepting a {@link Set} (of all valid class
239 * names) as first argument, a {@link FeatureTransformer} as second argument
240 * and a {@link TiesConfiguration} as third argument.
241 *
242 * @param allValidClasses the set of all valid classes
243 * @param runDirectory the directory to run the classifier in; used for
244 * {@link ExternalClassifier} instead of the
245 * {@linkplain ExternalClassifier#CONFIG_DIR configured directory}
246 * if not <code>null</code>; ignored otherwise
247 * @param trans the last transformer in the transformer chain to use, or
248 * <code>null</code> if no feature transformers should be used
249 * @param spec the specification used to initialize the classifier, as
250 * described above
251 * @param conf passed to the created classifier to configure itself
252 * @return the created classifier
253 * @throws IllegalArgumentException if the value of the
254 * {@link #CONFIG_CLASSIFIER} key is missing or invalid
255 * @throws ProcessingException if an error occurred while creating the
256 * classifier
257 */
258 public static TrainableClassifier createClassifier(
259 final Set<String> allValidClasses, final File runDirectory,
260 final FeatureTransformer trans, final String[] spec,
261 final TiesConfiguration conf)
262 throws IllegalArgumentException, ProcessingException {
263 if ((spec == null) || spec.length < 1) {
264 throw new IllegalArgumentException(
265 "Cannot create classifier -- specification is null or empty");
266 }
267
268
269
270 final String lowerValue = spec[0].toLowerCase();
271 final TrainableClassifier result;
272
273 if ("ext".equals(lowerValue)) {
274
275 result = new ExternalClassifier(allValidClasses, trans,
276 runDirectory, conf);
277 } else if ("winnow".equals(lowerValue)) {
278
279 result = new Winnow(allValidClasses, trans, conf);
280 } else if ("ucwinnow".equals(lowerValue)) {
281
282 result =
283 new UltraconservativeWinnow(allValidClasses, trans, conf);
284 } else if ("moon".equals(lowerValue)) {
285
286 result = new MoonClassifier(allValidClasses, trans, conf);
287 } else if (WRAPPING_CLASSIFIERS.contains(lowerValue)) {
288
289 final String[] innerSpec = new String[spec.length - 1];
290
291
292 for (int i = 0; i < innerSpec.length; i++) {
293 innerSpec[i] = spec[i + 1];
294 }
295
296 if (TIE_CLASSIFIER.equals(lowerValue)) {
297
298 result = new TieClassifier(allValidClasses, trans,
299 runDirectory, innerSpec, conf);
300 } else if (META_CLASSIFIER.equals(lowerValue)) {
301
302 result = new MetaClassifier(allValidClasses, trans,
303 runDirectory, innerSpec, conf);
304 } else if (allValidClasses.size() > 2) {
305
306
307
308 if (MULTI_CLASSIFIER.equals(lowerValue)) {
309
310 result = new MultiBinaryClassifier(allValidClasses, trans,
311 runDirectory, innerSpec, conf);
312 } else if (OAR_CLASSIFIER.equals(lowerValue)) {
313
314 result = new OneAgainstTheRestClassifier(allValidClasses,
315 trans, runDirectory, innerSpec, conf);
316 } else {
317
318 throw new RuntimeException("Implementation error: "
319 + "unknown wrapping classifier" + spec[0]);
320 }
321 } else {
322
323
324 result = createClassifier(allValidClasses, runDirectory, trans,
325 innerSpec, conf);
326 }
327 } else {
328
329
330
331 try {
332 result = (TrainableClassifier) Util.createObject(
333 Class.forName(spec[0]),
334 new Object[] {allValidClasses, trans, conf},
335 new Class[] {Set.class, FeatureTransformer.class,
336 TiesConfiguration.class});
337 } catch (ClassNotFoundException cnfe) {
338
339 throw new ProcessingException(
340 "Cannot create classifier from specification "
341 + ArrayUtils.toString(spec) + ": " + cnfe.toString());
342 } catch (InstantiationException ie) {
343
344 throw new ProcessingException(
345 "Cannot create classifier from specification "
346 + ArrayUtils.toString(spec), ie);
347 }
348 }
349 return result;
350 }
351
352 /***
353 * The immutable set of all valid classes. Each target or candidate class
354 * must be contained in this set.
355 */
356 private final Set<String> allClasses;
357
358 /***
359 * Used to configure this instance.
360 */
361 private final TiesConfiguration config;
362
363 /***
364 * If <code>true</code> the classifier considers all classes for
365 * error-driven training, not only the candidate classes (results are
366 * filtered to the candidate classes prior to returning them).
367 */
368 private final boolean trainingAll;
369
370 /***
371 * The last transformer in a transformer chain, or <code>null</code> if
372 * no feature transformers are used.
373 */
374 private final FeatureTransformer transformer;
375
376 /***
377 * Used to re-use results of the last classification for error-driven
378 * training, if possible.
379 */
380 private PredictionDistribution cachedPredictions = null;
381
382 /***
383 * Used to re-use results of the last classification for error-driven
384 * training, if possible.
385 */
386 private FeatureVector cachedOrgFeatures = null;
387
388 /***
389 * Used to re-use results of the last classification for error-driven
390 * training, if possible.
391 */
392 private FeatureVector cachedActualFeatures = null;
393
394 /***
395 * Used to re-use results of the last classification for error-driven
396 * training, if possible.
397 */
398 private ContextMap cachedContext = null;
399
400
/**
 * Restores an instance from its XML representation, following the
 * recommendation of the {@link XMLStorable} interface. The set of classes
 * and the "train all" flag are read from attributes of the element; the
 * feature transformer is created from the child elements named
 * {@link FeatureTransformer#ELEMENT_MAIN}. The
 * {@linkplain TiesConfiguration#CONF standard configuration} is used.
 *
 * @param element the XML element containing the serialized representation
 * @throws InstantiationException if the given element does not contain
 * a valid classifier description
 */
public TrainableClassifier(final Element element)
        throws InstantiationException {
    this(
        CollUtils.asStringSet(element.attributeValue(ATTRIB_CLASSES)),
        (FeatureTransformer) ObjectElement.createNextObject(
            element.elementIterator(FeatureTransformer.ELEMENT_MAIN)),
        Util.asBoolean(element.attributeValue(ATTRIB_TRAIN_ALL)),
        TiesConfiguration.CONF);
}
418
/**
 * Creates a new instance whose "train all" flag is read from the
 * <code>classifier.train.all</code> key of the given configuration.
 *
 * @param allValidClasses the set of all valid classes
 * @param trans the last transformer in the transformer chain to use, or
 * <code>null</code> if no feature transformers should be used
 * @param conf used to configure this instance
 */
public TrainableClassifier(final Set<String> allValidClasses,
        final FeatureTransformer trans, final TiesConfiguration conf) {
    this(
        allValidClasses,
        trans,
        conf.getBoolean("classifier.train.all"),
        conf);
}
432
/**
 * Creates a new instance. The given set of classes is copied (preserving
 * its iteration order), so later changes to the caller's set do not affect
 * this classifier.
 *
 * @param allValidClasses the set of all valid classes; each class name is
 * passed to {@link TextUtils#ensurePrintableName(String)} for validation
 * @param trans the last transformer in the transformer chain to use, or
 * <code>null</code> if no feature transformers should be used
 * @param trainAll set to <code>true</code> iff the classifier should
 * consider all classes for error-driven training, not only the candidate
 * classes (results are filtered to the candidate classes prior to
 * returning them)
 * @param conf used to configure this instance
 */
public TrainableClassifier(final Set<String> allValidClasses,
        final FeatureTransformer trans, final boolean trainAll,
        final TiesConfiguration conf) {
    super();

    // Take a private snapshot before validating: an unmodifiable wrapper
    // around the caller's set would still reflect later modifications,
    // letting unvalidated names slip into the "immutable" class set.
    final Set<String> snapshot =
        new java.util.LinkedHashSet<String>(allValidClasses);

    for (final String className : snapshot) {
        TextUtils.ensurePrintableName(className);
    }

    allClasses = Collections.unmodifiableSet(snapshot);
    config = conf;
    transformer = trans;
    trainingAll = trainAll;
}
462
463 /***
464 * Ensure that all candidate classes are valid (contained in the set of all
465 * classes), throwing an exception otherwise.
466 *
467 * @param candidateClasses an set of classes that are allowed for this item
468 * (the actual <code>targetClass</code> must be one of them)
469 * @throws IllegalArgumentException if not all candidate classes are members
470 * of the {@linkplain #getAllClasses() set of valid classes}
471 */
472 private void checkCandidateClass(final Set candidateClasses)
473 throws IllegalArgumentException {
474 final Iterator classIter = candidateClasses.iterator();
475 String currentClass;
476
477 while (classIter.hasNext()) {
478 currentClass = (String) classIter.next();
479 if (!allClasses.contains(currentClass)) {
480 throw new IllegalArgumentException("Candidate class "
481 + currentClass
482 + " is not in the set of valid classes: " + allClasses);
483 }
484 }
485 }
486
487 /***
488 * Ensure that the target class is valid (contained in the set of all
489 * classes), throwing an exception otherwise.
490 *
491 * @param targetClass the expected class of an instance; must be
492 * contained in the set of <code>candidateClasses</code>
493 * @throws IllegalArgumentException if the target class is not a member of
494 * the {@linkplain #getAllClasses() set of valid classes}
495 */
496 private void checkTargetClass(final String targetClass)
497 throws IllegalArgumentException {
498
499 if (!allClasses.contains(targetClass)) {
500 throw new IllegalArgumentException("Target class " + targetClass
501 + " is not in the set of valid classes: " + allClasses);
502 }
503 }
504
505 /***
506 * Classifies an item that is represented by a feature vector by choosing
507 * the most probable class among a set of candidate classes. Delegates to
508 * the abstract {@link #doClassify(FeatureVector, Set, ContextMap)} method.
509 *
510 * @param features the feature vector to consider
511 * @param candidateClasses an set of classes that are allowed for this item
512 * @return the result of the classification; you can call
513 * {@link PredictionDistribution#best()} to get the most probably class
514 * @throws IllegalArgumentException if the
515 * {@linkplain #getAllClasses() set of valid classes} does not contain all
516 * candidate classes
517 * @throws ProcessingException if an error occurs during classification
518 */
519 public final PredictionDistribution classify(final FeatureVector features,
520 final Set candidateClasses)
521 throws IllegalArgumentException, ProcessingException {
522
523 checkCandidateClass(candidateClasses);
524
525
526 final FeatureVector actualFeatures = (transformer != null)
527 ? transformer.transform(features) : features;
528
529
530 final ContextMap context = new ContextMap();
531 final PredictionDistribution result =
532 doClassify(actualFeatures, candidateClasses, context);
533
534
535 cachedPredictions = result;
536 cachedOrgFeatures = features;
537 cachedActualFeatures = actualFeatures;
538 cachedContext = context;
539
540 return result;
541 }
542
543 /***
544 * Destroys the classifer. This method must be called only if the classifier
545 * will never be used again. The default implementation delegates to
546 * {@link #reset()}, but subclasses can overwrite this behaviour if
547 * appropriate.
548 *
549 * @throws ProcessingException if an error occurs while the classifier is
550 * being destroyed
551 */
552 public void destroy() throws ProcessingException {
553 reset();
554 }
555
556 /***
557 * Classifies an item that is represented by a feature vector by choosing
558 * the most probable class among a set of candidate classes.
559 *
560 * @param features the feature vector to consider
561 * @param candidateClasses an set of classes that are allowed for this item
562 * @param context can be used to transport implementation-specific
563 * contextual information between the
564 * {@link #doClassify(FeatureVector, Set, ContextMap)},
565 * {@link #doTrain(FeatureVector, String, ContextMap)}, and
566 * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
567 * Set, ContextMap)} methods
568 * @return the result of the classification; you can call
569 * {@link PredictionDistribution#best()} to get the most probably class
570 * @throws ProcessingException if an error occurs during classification
571 */
572 protected abstract PredictionDistribution doClassify(
573 final FeatureVector features, final Set candidateClasses,
574 final ContextMap context)
575 throws ProcessingException;
576
577 /***
578 * Incorporates an item that is represented by a feature vector into the
579 * classification model.
580 *
581 * @param features the feature vector to consider
582 * @param targetClass the class of this feature vector
583 * @param context can be used to transport implementation-specific
584 * contextual information between the
585 * {@link #doClassify(FeatureVector, Set, ContextMap)},
586 * {@link #doTrain(FeatureVector, String, ContextMap)}, and
587 * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
588 * Set, ContextMap)} methods
589 * @throws ProcessingException if an error occurs during training
590 */
591 protected abstract void doTrain(final FeatureVector features,
592 final String targetClass, final ContextMap context)
593 throws ProcessingException;
594
595 /***
596 * The core of the {@link #trainOnError(FeatureVector, String, Set)} method.
597 * Generally there is no need for subclasses to modify this method.
598 *
599 * @param predDist the prediction distribution returned by
600 * {@link #classify(FeatureVector, Set)}
601 * @param features the feature vector to consider
602 * @param targetClass the expected class of this feature vector; must be
603 * contained in the set of <code>candidateClasses</code>
604 * @param candidateClasses an set of classes that are allowed for this item
605 * (the actual <code>targetClass</code> must be one of them)
606 * @param context can be used to transport implementation-specific
607 * contextual information between the
608 * {@link #doClassify(FeatureVector, Set, ContextMap)},
609 * {@link #doTrain(FeatureVector, String, ContextMap)}, and
610 * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
611 * Set, ContextMap)} methods
612 * @return the result of the {@link TrainableClassifier#shouldTrain(String,
613 * PredictionDistribution, ContextMap)} method
614 * @throws ProcessingException if an error occurs during training
615 */
616 protected boolean doTrainOnError(final PredictionDistribution predDist,
617 final FeatureVector features, final String targetClass,
618 final Set candidateClasses, final ContextMap context)
619 throws ProcessingException {
620
621 final boolean hookHandledTraining = trainOnErrorHook(predDist,
622 features, targetClass, candidateClasses, context);
623 final boolean shouldTrain = shouldTrain(targetClass, predDist, context);
624
625
626 if (shouldTrain && !hookHandledTraining) {
627 doTrain(features, targetClass, context);
628 }
629
630 return shouldTrain;
631 }
632
633 /***
634 * Returns the set of all valid classes. Each target or candidate class
635 * must be contained in this set.
636 *
637 * @return an immutable set containing all valid class names
638 */
639 public Set<String> getAllClasses() {
640 return allClasses;
641 }
642
643 /***
644 * Returns the configuration used by this instance.
645 * @return the used configuration
646 */
647 public TiesConfiguration getConfig() {
648 return config;
649 }
650
651 /***
652 * Resets the classifer, completely deleting the prediction model.
653 * @throws ProcessingException if an error occurs during reset
654 */
655 public abstract void reset() throws ProcessingException;
656
657 /***
658 * Invoked by {@link #trainOnError(FeatureVector, String, Set)} to decide
659 * whether to train an instance. The default behavior is to train if the
660 * best prediction was wrong or didn't yield a positive probability
661 * ("train only errors"). Subclasses can override this method to
662 * add their own behavior, e.g. reinforcement training (thick threshold
663 * heuristic).
664 *
665 * @param targetClass the expected class of this feature vector; must be
666 * contained in the set of <code>candidateClasses</code>
667 * @param predDist the prediction distribution returned by
668 * {@link #doClassify(FeatureVector, Set, ContextMap)}
669 * @param context can be used to transport implementation-specific
670 * contextual information between the
671 * {@link #doClassify(FeatureVector, Set, ContextMap)},
672 * {@link #doTrain(FeatureVector, String, ContextMap)}, and
673 * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
674 * Set, ContextMap)} methods
675 * @return whether to train this instance
676 */
677 protected boolean shouldTrain(final String targetClass,
678 final PredictionDistribution predDist, final ContextMap context) {
679 final Prediction best = predDist.best();
680 final double bestProb = best.getProbability().getProb();
681 return !best.getType().equals(targetClass) || Double.isNaN(bestProb)
682 || (bestProb <= 0.0);
683 }
684
685 /***
686 * {@inheritDoc}
687 * Subclasses of {@link TrainableClassifier} should extend this method and
688 * the corresponding constructor from {@link org.dom4j.Element} to
689 * ensure (de)serialization works as expected.
690 */
691 public ObjectElement toElement() {
692 final ObjectElement result =
693 new ObjectElement(ELEMENT_MAIN, this.getClass());
694
695 result.addAttribute(ATTRIB_CLASSES,
696 CollUtils.flatten(allClasses.iterator()));
697 result.addAttribute(ATTRIB_TRAIN_ALL, Boolean.toString(trainingAll));
698
699
700 if (transformer != null) {
701 result.add(transformer.toElement());
702 }
703 return result;
704 }
705
706 /***
707 * Returns a string representation of this object.
708 *
709 * @return a textual representation
710 */
711 public String toString() {
712 final ToStringBuilder builder = new ToStringBuilder(this);
713 builder.append("all classes", allClasses);
714
715 if (trainingAll) {
716 builder.append("training all classes", trainingAll);
717 }
718
719 if (transformer != null) {
720 builder.append("transformer", transformer);
721 }
722
723 return builder.toString();
724 }
725
726 /***
727 * Incorporates an item that is represented by a feature vector into the
728 * classification model. Delegates to the abstract
729 * {@link #doTrain(FeatureVector, String, ContextMap)} method.
730 *
731 * @param features the feature vector to consider
732 * @param targetClass the class of this feature vector
733 * @throws IllegalArgumentException if the target class is not in the
734 * {@linkplain #getAllClasses() set of valid classes}
735 * @throws ProcessingException if an error occurs during training
736 */
737 public final void train(final FeatureVector features,
738 final String targetClass)
739 throws IllegalArgumentException, ProcessingException {
740
741 checkTargetClass(targetClass);
742
743
744 final FeatureVector actualFeatures = (transformer != null)
745 ? transformer.transform(features) : features;
746
747
748 final ContextMap context = new ContextMap();
749 doTrain(actualFeatures, targetClass, context);
750 }
751
752 /***
753 * Handles error-driven learning ("train only errors"): the specified
754 * feature vector is trained into the model only if the predicted class
755 * for the feature vector differs from the specified target class. If the
756 * prediction was correct, the model is not changed.
757 *
758 * @param features the feature vector to consider
759 * @param targetClass the expected class of this feature vector; must be
760 * contained in the set of <code>candidateClasses</code>
761 * @param candidateClasses an set of classes that are allowed for this item
762 * (the actual <code>targetClass</code> must be one of them)
763 * @return the original prediction distribution if the best prediction was
764 * wrong, i.e. if training was necessary; or <code>null</code> if no
765 * training was necessary (the prediction was already correct)
766 * @throws ProcessingException if an error occurs during training
767 */
768 public final PredictionDistribution trainOnError(
769 final FeatureVector features, final String targetClass,
770 final Set candidateClasses) throws ProcessingException {
771
772 checkTargetClass(targetClass);
773 checkCandidateClass(candidateClasses);
774
775
776 final Set consideredClasses = trainingAll
777 ? allClasses : candidateClasses;
778
779 final FeatureVector actualFeatures;
780 final ContextMap context;
781 final PredictionDistribution predDist;
782
783
784
785 if (features == cachedOrgFeatures) {
786
787 actualFeatures = cachedActualFeatures;
788 context = cachedContext;
789 predDist = cachedPredictions;
790 } else {
791
792 actualFeatures = (transformer != null)
793 ? transformer.transform(features) : features;
794 context = new ContextMap();
795 predDist = doClassify(actualFeatures, consideredClasses, context);
796 }
797
798
799 final boolean shouldTrain = doTrainOnError(predDist,
800 actualFeatures, targetClass, consideredClasses, context);
801
802 if (shouldTrain) {
803
804 if (trainingAll) {
805 final Iterator predIter = predDist.iterator();
806 Prediction pred;
807
808 while (predIter.hasNext()) {
809 pred = (Prediction) predIter.next();
810
811 if (!candidateClasses.contains(pred.getType())) {
812
813 predIter.remove();
814 }
815 }
816 }
817
818
819 if (((predDist.size() > 0)
820 && !predDist.best().getType().equals(targetClass))) {
821
822 return predDist;
823 } else {
824
825 return null;
826 }
827 } else {
828
829 return null;
830 }
831 }
832
833 /***
834 * Subclasses can implement this hook for more refined error-driven
835 * learning. It is called from the
836 * {@link #trainOnError(FeatureVector, String, Set)} method after
837 * classifying. This method can do any necessary training itself and
838 * return <code>true</code> to signal that no further action is necessary.
839 * This implementation is just a placeholder that always returns
840 * <code>false</code>.
841 *
842 * @param predDist the prediction distribution returned by
843 * {@link #classify(FeatureVector, Set)}
844 * @param features the feature vector to consider
845 * @param targetClass the expected class of this feature vector; must be
846 * contained in the set of <code>candidateClasses</code>
847 * @param candidateClasses an set of classes that are allowed for this item
848 * (the actual <code>targetClass</code> must be one of them)
849 * @param context can be used to transport implementation-specific
850 * contextual information between the
851 * {@link #doClassify(FeatureVector, Set, ContextMap)},
852 * {@link #doTrain(FeatureVector, String, ContextMap)}, and
853 * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
854 * Set, ContextMap)} methods
855 * @return this implementation always returns <code>false</code>; subclasses
856 * can return <code>true</code> to signal that any error-driven learning was
857 * already handled
858 * @throws ProcessingException if an error occurs during training
859 */
860 protected boolean trainOnErrorHook(final PredictionDistribution predDist,
861 final FeatureVector features, final String targetClass,
862 final Set candidateClasses, final ContextMap context)
863 throws ProcessingException {
864 return false;
865 }
866
867 }