View Javadoc

1   /*
2    * Copyright (C) 2003-2006 Christian Siefkes <christian@siefkes.net>.
3    * Development of this software is supported by the German Research Society,
4    * Berlin-Brandenburg Graduate School in Distributed Information Systems
5    * (DFG grant no. GRK 316).
6    *
7    * This program is free software; you can redistribute it and/or modify
8    * it under the terms of the GNU General Public License as published by
9    * the Free Software Foundation; either version 2 of the License, or
10   * (at your option) any later version.
11   *
12   * This program is distributed in the hope that it will be useful,
13   * but WITHOUT ANY WARRANTY; without even the implied warranty of
14   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15   * GNU General Public License for more details.
16   *
17   * You should have received a copy of the GNU General Public License
18   * along with this program; if not, visit
19   * http://www.gnu.org/licenses/gpl.html or write to the Free Software
20   * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
21   */
22  package de.fu_berlin.ties.classify;
23  
24  import java.io.File;
25  import java.util.Collections;
26  import java.util.HashSet;
27  import java.util.Iterator;
28  import java.util.Set;
29  
30  import org.apache.commons.lang.ArrayUtils;
31  import org.apache.commons.lang.builder.ToStringBuilder;
32  import org.dom4j.Element;
33  import org.dom4j.QName;
34  
35  import de.fu_berlin.ties.ContextMap;
36  import de.fu_berlin.ties.ProcessingException;
37  import de.fu_berlin.ties.TiesConfiguration;
38  import de.fu_berlin.ties.classify.feature.FeatureTransformer;
39  import de.fu_berlin.ties.classify.feature.FeatureVector;
40  import de.fu_berlin.ties.classify.winnow.UltraconservativeWinnow;
41  import de.fu_berlin.ties.classify.winnow.Winnow;
42  import de.fu_berlin.ties.io.ObjectElement;
43  import de.fu_berlin.ties.io.XMLStorable;
44  import de.fu_berlin.ties.text.TextUtils;
45  import de.fu_berlin.ties.util.CollUtils;
46  import de.fu_berlin.ties.util.Util;
47  import de.fu_berlin.ties.xml.dom.DOMUtils;
48  
49  /***
50   * Classifiers extending this abstract class must provide a training mechanism
51   * by implementing the {@link #doTrain(FeatureVector, String, ContextMap)}
52   * method. This class supports error-driven learning ("train only errors")
53   * which often leads to better prediction models than brute-force training.
54   *
55   * <p>The code in this class is thread-safe.
56   *
57   * @author Christian Siefkes
58   * @version $Revision: 1.50 $, $Date: 2006/11/26 21:14:58 $, $Author: siefkes $
59   */
60  public abstract class TrainableClassifier implements Classifier, XMLStorable {
61  
    /**
     * Qualified name of the main element that represents a classifier in
     * its XML serialization (see {@link #toElement()}).
     */
    public static final QName ELEMENT_MAIN =
        DOMUtils.defaultName("classifier");

    /**
     * Name of the XML attribute in which {@link #toElement()} stores the
     * flattened set of all valid classes.
     */
    static final QName ATTRIB_CLASSES = DOMUtils.defaultName("classes");

    /**
     * Name of the XML attribute in which {@link #toElement()} stores the
     * "train all classes" flag.
     */
    static final QName ATTRIB_TRAIN_ALL = DOMUtils.defaultName("train-all");

    /**
     * Specification keyword (in lower case) used to load the
     * {@link MetaClassifier}.
     */
    public static final String META_CLASSIFIER = "meta";

    /**
     * Specification keyword (in lower case) used to load the
     * {@link MultiBinaryClassifier}.
     */
    public static final String MULTI_CLASSIFIER = "multi";

    /**
     * Specification keyword (in lower case) used to load the
     * {@link OneAgainstTheRestClassifier}.
     */
    public static final String OAR_CLASSIFIER = "oar";

    /**
     * Specification keyword (in lower case) used to load the
     * {@link TieClassifier}.
     */
    public static final String TIE_CLASSIFIER = "tie";
97  
98      /***
99       * Names of classifiers wrapping inner classifiers.
100      */
101     private static final Set<String> WRAPPING_CLASSIFIERS;
102 
103 
104     /***
105      * Static initialization of set of wrapping classifiers.
106      */
107     static {
108         final Set<String> wrapping = new HashSet<String>();
109         wrapping.add(META_CLASSIFIER);
110         wrapping.add(MULTI_CLASSIFIER);
111         wrapping.add(OAR_CLASSIFIER);
112         wrapping.add(TIE_CLASSIFIER);
113         WRAPPING_CLASSIFIERS = Collections.unmodifiableSet(wrapping);
114     }
115 
116     /***
117      * Factory method that delegates to
118      * {@link #createClassifier(Set, TiesConfiguration)} using the
119      * {@linkplain TiesConfiguration#CONF standard configuration}.
120      *
121      * @param allValidClasses the set of all valid classes
122      * @return the created classifier
123      * @throws IllegalArgumentException if the value of the
124      * {@link #CONFIG_CLASSIFIER} key is missing or invalid
125      * @throws ProcessingException if an error occurred while creating the
126      * classifier
127      */
128     public static TrainableClassifier createClassifier(
129             final Set<String> allValidClasses)
130             throws IllegalArgumentException, ProcessingException {
131         return createClassifier(allValidClasses, TiesConfiguration.CONF);
132     }
133 
134     /***
135      * Factory method that delegates to
136      * {@link #createClassifier(Set, TiesConfiguration, String)}
137      * without specifying a suffix.
138      *
139      * @param allValidClasses the set of all valid classes
140      * @param conf the configuration to use
141      * @return the created classifier
142      * @throws IllegalArgumentException if the value of the
143      * {@link #CONFIG_CLASSIFIER} key is missing or invalid
144      * @throws ProcessingException if an error occurred while creating the
145      * classifier
146      */
147     public static TrainableClassifier createClassifier(
148             final Set<String> allValidClasses, final TiesConfiguration conf)
149     throws IllegalArgumentException, ProcessingException {
150         return createClassifier(allValidClasses, conf, null);
151     }
152 
153     /***
154      * Factory method that delegates to
155      * {@link #createClassifier(Set, File, TiesConfiguration, String)}
156      * without specifying an run directory.
157      *
158      * @param allValidClasses the set of all valid classes
159      * @param conf the configuration to use
160      * @param suffix an optional
161      * {@linkplain TiesConfiguration#adaptKey(String, String) suffix} that is
162      * appended to the {@link Classifier#CONFIG_CLASSIFIER} key if not
163      * <code>null</code>
164      * @return the created classifier
165      * @throws IllegalArgumentException if the value of the
166      * {@link #CONFIG_CLASSIFIER} key is missing or invalid
167      * @throws ProcessingException if an error occurred while creating the
168      * classifier
169      */
170     public static TrainableClassifier createClassifier(
171             final Set<String> allValidClasses, final TiesConfiguration conf,
172             final String suffix)
173     throws IllegalArgumentException, ProcessingException {
174         return createClassifier(allValidClasses, null, conf, suffix);
175     }
176 
177     /***
178      * Factory method that delegates to
179      * {@link #createClassifier(Set, File, FeatureTransformer, String[],
180      * TiesConfiguration)}. It reads the specification of the classifier from
181      * the {@link #CONFIG_CLASSIFIER} key in the provided configuration. It
182      * calls {@link FeatureTransformer#createTransformer(TiesConfiguration)} to
183      * initialize a transformer chain, if configured.
184      *
185      * @param allValidClasses the set of all valid classes
186      * @param runDirectory the directory to run the classifier in; used for
187      * {@link ExternalClassifier} instead of the
188      * {@linkplain ExternalClassifier#CONFIG_DIR configured directory}
189      * if not <code>null</code>; ignored otherwise
190      * @param conf the configuration to use
191      * @param suffix an optional
192      * {@linkplain TiesConfiguration#adaptKey(String, String) suffix} that is
193      * appended to the {@link Classifier#CONFIG_CLASSIFIER} key if not
194      * <code>null</code>
195      * @return the created classifier
196      * @throws IllegalArgumentException if the value of the
197      * {@link #CONFIG_CLASSIFIER} key is missing or invalid
198      * @throws ProcessingException if an error occurred while creating the
199      * classifier
200      */
201     public static TrainableClassifier createClassifier(
202             final Set<String> allValidClasses, final File runDirectory,
203             final TiesConfiguration conf, final String suffix)
204             throws IllegalArgumentException, ProcessingException {
205         // read specification, considering suffix if given
206         final String[] spec =
207             conf.getStringArray(conf.adaptKey(CONFIG_CLASSIFIER, suffix));
208 
209         // create transformer chain from the configuration, if specified
210         final FeatureTransformer trans =
211             FeatureTransformer.createTransformer(conf);
212         return createClassifier(allValidClasses, runDirectory, trans,
213                 spec, conf);
214     }
215 
216     /***
217      * Factory method that creates a trainable classifier based on the
218      * provided specification.
219      *
220      * <p>Currently supported values in the first element of the specification:
221      *
222      * <ul>
223      * <li>"Ext" for {@link ExternalClassifier}
224      * <li>"Winnow" for {@link Winnow}
225      * <li>"ucWinnow" for {@link UltraconservativeWinnow}
226      * <li>"Moon" for the {@link MoonClassifier}
227      * <li>"Tie" or "Meta" followed by the specification of the inner
228      * classifiers as further element(s) for {@link TieClassifier} resp.
229      * {@link MetaClassifier}
230      * <li>"Multi" or "OAR" followed by the specification of the inner
231      * classifiers as further element(s) for {@link MultiBinaryClassifier} resp.
232      * {@link OneAgainstTheRestClassifier} (if there are only two classes
233      * to classify, the outer classifer is skipped and the inner classifier is
234      * used directly).
235      * </ul>
236      *
237      * <p>Otherwise the first element must be the qualified name of a
238      * TrainableClassifier subclass accepting a {@link Set} (of all valid class
239      * names) as first argument, a {@link FeatureTransformer} as second argument
240      * and a {@link TiesConfiguration} as third argument.
241      *
242      * @param allValidClasses the set of all valid classes
243      * @param runDirectory the directory to run the classifier in; used for
244      * {@link ExternalClassifier} instead of the
245      * {@linkplain ExternalClassifier#CONFIG_DIR configured directory}
246      * if not <code>null</code>; ignored otherwise
247      * @param trans the last transformer in the transformer chain to use, or
248      * <code>null</code> if no feature transformers should be used
249      * @param spec the specification used to initialize the classifier, as
250      * described above
251      * @param conf passed to the created classifier to configure itself
252      * @return the created classifier
253      * @throws IllegalArgumentException if the value of the
254      * {@link #CONFIG_CLASSIFIER} key is missing or invalid
255      * @throws ProcessingException if an error occurred while creating the
256      * classifier
257      */
258     public static TrainableClassifier createClassifier(
259             final Set<String> allValidClasses, final File runDirectory,
260             final FeatureTransformer trans, final String[] spec,
261             final TiesConfiguration conf)
262             throws IllegalArgumentException, ProcessingException {
263         if ((spec == null) || spec.length < 1) {
264             throw new IllegalArgumentException(
265                 "Cannot create classifier -- specification is null or empty");
266         }
267 
268 
269         // convert 1st array element to lower case to compare with defined types
270         final String lowerValue = spec[0].toLowerCase();
271         final TrainableClassifier result;
272 
273         if ("ext".equals(lowerValue)) {
274             // load external classifier
275             result = new ExternalClassifier(allValidClasses, trans,
276                 runDirectory, conf);
277         } else if ("winnow".equals(lowerValue)) {
278             // load Winnow classifier
279             result = new Winnow(allValidClasses, trans, conf);
280         } else if ("ucwinnow".equals(lowerValue)) {
281             // load Ultraconservative Winnow
282             result =
283                 new UltraconservativeWinnow(allValidClasses, trans, conf);
284         } else if ("moon".equals(lowerValue)) {
285             // load Moonfilter proxy
286             result = new MoonClassifier(allValidClasses, trans, conf);
287         } else if (WRAPPING_CLASSIFIERS.contains(lowerValue)) {
288             // Load classifier wrapping inner classifiers
289             final String[] innerSpec = new String[spec.length - 1];
290 
291             // Use rest of array as specification for the inner classifiers
292             for (int i = 0; i < innerSpec.length; i++) {
293                 innerSpec[i] = spec[i + 1];
294             }
295 
296             if (TIE_CLASSIFIER.equals(lowerValue)) {
297                 // load tie classifier
298                 result = new TieClassifier(allValidClasses, trans,
299                         runDirectory, innerSpec, conf);
300             } else if (META_CLASSIFIER.equals(lowerValue)) {
301                 // load meta classifier
302                 result = new MetaClassifier(allValidClasses, trans,
303                         runDirectory, innerSpec, conf);
304             } else if (allValidClasses.size() > 2) {
305                 // must be multi-binary or one-against-the-rest classifier
306                 // whch are unnecessary for 2 classes
307 
308                 if (MULTI_CLASSIFIER.equals(lowerValue)) {
309                     // multi-binary classifier (with background class)
310                     result = new MultiBinaryClassifier(allValidClasses, trans,
311                             runDirectory, innerSpec, conf);
312                 } else if (OAR_CLASSIFIER.equals(lowerValue))  {
313                     // one-against-the-rest classifier (no background class)
314                     result = new OneAgainstTheRestClassifier(allValidClasses,
315                             trans, runDirectory, innerSpec, conf);
316                 } else {
317                     // not supposed to happen
318                     throw new RuntimeException("Implementation error: "
319                             + "unknown wrapping classifier" + spec[0]);
320                 }
321             } else {
322                 // no need to use outer classifier: recursively call this method
323                 // to return an instance of the inner classifier
324                 result = createClassifier(allValidClasses, runDirectory, trans,
325                     innerSpec, conf);
326             }
327         } else {
328             // should be the qualified name of a TrainableClassifier subclass
329             // accepting a Set (of valid class names) as 1st argument, a
330             // FeatureTransformer as 2nd and a TiesConfiguration as 3rd argument
331             try {
332                 result = (TrainableClassifier) Util.createObject(
333                     Class.forName(spec[0]),
334                     new Object[] {allValidClasses, trans, conf},
335                     new Class[] {Set.class, FeatureTransformer.class,
336                         TiesConfiguration.class});
337             } catch (ClassNotFoundException cnfe) {
338                 // convert and rethrow exception
339                 throw new ProcessingException(
340                     "Cannot create classifier from specification "
341                     + ArrayUtils.toString(spec) + ": " + cnfe.toString());
342             } catch (InstantiationException ie) {
343                 // convert and rethrow exception
344                 throw new ProcessingException(
345                     "Cannot create classifier from specification "
346                     + ArrayUtils.toString(spec), ie);
347             }
348         }
349         return result;
350     }
351 
    /**
     * The set of all valid classes, wrapped so that it cannot be modified
     * through this reference. Each target or candidate class must be
     * contained in this set.
     */
    private final Set<String> allClasses;

    /**
     * Used to configure this instance; returned by {@link #getConfig()}.
     */
    private final TiesConfiguration config;

    /**
     * If <code>true</code> the classifier considers all classes for
     * error-driven training, not only the candidate classes (results are
     * filtered to the candidate classes prior to returning them).
     */
    private final boolean trainingAll;

    /**
     * The last transformer in a transformer chain, or <code>null</code> if
     * no feature transformers are used.
     */
    private final FeatureTransformer transformer;

    /**
     * The prediction distribution returned by the last call to
     * {@link #classify(FeatureVector, Set)}; kept to re-use the results of
     * the last classification for error-driven training, if possible.
     */
    private PredictionDistribution cachedPredictions = null;

    /**
     * The original (untransformed) feature vector passed to the last call
     * to {@link #classify(FeatureVector, Set)}; kept to re-use the results
     * of the last classification for error-driven training, if possible.
     */
    private FeatureVector cachedOrgFeatures = null;

    /**
     * The feature vector actually classified in the last call to
     * {@link #classify(FeatureVector, Set)}, i.e. after applying the
     * transformer (if any); kept to re-use the results of the last
     * classification for error-driven training, if possible.
     */
    private FeatureVector cachedActualFeatures = null;

    /**
     * The context map filled during the last call to
     * {@link #classify(FeatureVector, Set)}; kept to re-use the results of
     * the last classification for error-driven training, if possible.
     */
    private ContextMap cachedContext = null;
399 
400 
    /**
     * Creates a new instance from an XML element, fulfilling the
     * recommendation of the {@link XMLStorable} interface. The set of
     * classes and the "train all" flag are read from the
     * {@link #ATTRIB_CLASSES} and {@link #ATTRIB_TRAIN_ALL} attributes; the
     * {@linkplain TiesConfiguration#CONF standard configuration} is used to
     * configure the instance.
     *
     * @param element the XML element containing the serialized representation
     * @throws InstantiationException if the given element does not contain
     * a valid classifier description
     */
    public TrainableClassifier(final Element element)
    throws InstantiationException {
        // init feature transformer from nested element if present
        this(CollUtils.asStringSet(element.attributeValue(ATTRIB_CLASSES)),
                (FeatureTransformer) ObjectElement.createNextObject(
                    element.elementIterator(FeatureTransformer.ELEMENT_MAIN)),
                Util.asBoolean(element.attributeValue(ATTRIB_TRAIN_ALL)),
                TiesConfiguration.CONF);
    }
418 
    /**
     * Creates a new instance. Whether all classes are considered for
     * error-driven training is read from the
     * <code>classifier.train.all</code> key of the given configuration.
     *
     * @param allValidClasses the set of all valid classes
     * @param trans the last transformer in the transformer chain to use, or
     * <code>null</code> if no feature transformers should be used
     * @param conf used to configure this instance
     */
    public TrainableClassifier(final Set<String> allValidClasses,
            final FeatureTransformer trans, final TiesConfiguration conf) {
        this(allValidClasses, trans, conf.getBoolean("classifier.train.all"),
                conf);
    }
432 
433     /***
434      * Creates a new instance.
435      *
436      * @param allValidClasses the set of all valid classes; all class names
437      * must be {@link TextUtils#ensurePrintableName(String) printable names}
438      * @param trans the last transformer in the transformer chain to use, or
439      * <code>null</code> if no feature transformers should be used
440      * @param trainAll set to <code>true</code> iff the classifier should
441      * consider all classes for error-driven training, not only the candidate
442      * classes (results are filtered to the candidate classes prior to
443      * returning them)
444      * @param conf used to configure this instance
445      */
446     public TrainableClassifier(final Set<String> allValidClasses,
447             final FeatureTransformer trans, final boolean trainAll,
448             final TiesConfiguration conf) {
449         super();
450         // check that class names are printable
451         final Iterator<String> iter = allValidClasses.iterator();
452         while (iter.hasNext()) {
453             TextUtils.ensurePrintableName(iter.next());
454         }
455 
456         // make set immutable
457         allClasses = Collections.unmodifiableSet(allValidClasses);
458         config = conf;
459         transformer = trans;
460         trainingAll = trainAll;
461     }
462 
463     /***
464      * Ensure that all candidate classes are valid (contained in the set of all
465      * classes), throwing an exception otherwise.
466      *
467      * @param candidateClasses an set of classes that are allowed for this item
468      * (the actual <code>targetClass</code> must be one of them)
469      * @throws IllegalArgumentException if not all candidate classes are members
470      * of the {@linkplain #getAllClasses() set of valid classes}
471      */
472     private void checkCandidateClass(final Set candidateClasses)
473             throws IllegalArgumentException {
474         final Iterator classIter = candidateClasses.iterator();
475         String currentClass;
476 
477         while (classIter.hasNext()) {
478             currentClass = (String) classIter.next();
479             if (!allClasses.contains(currentClass)) {
480                 throw new IllegalArgumentException("Candidate class "
481                     + currentClass
482                     + " is not in the set of valid classes: " + allClasses);
483             }
484         }
485     }
486 
487     /***
488      * Ensure that the target class is valid (contained in the set of all
489      * classes), throwing an exception otherwise.
490      *
491      * @param targetClass the expected class of an instance; must be
492      * contained in the set of <code>candidateClasses</code>
493      * @throws IllegalArgumentException if the target class is not a member of
494      * the {@linkplain #getAllClasses() set of valid classes}
495      */
496     private void checkTargetClass(final String targetClass)
497             throws IllegalArgumentException {
498         // ensure that the target class is valid
499         if (!allClasses.contains(targetClass)) {
500             throw new IllegalArgumentException("Target class " + targetClass
501                 + " is not in the set of valid classes: " + allClasses);
502         }
503     }
504 
505     /***
506      * Classifies an item that is represented by a feature vector by choosing
507      * the most probable class among a set of candidate classes. Delegates to
508      * the abstract {@link #doClassify(FeatureVector, Set, ContextMap)} method.
509      *
510      * @param features the feature vector to consider
511      * @param candidateClasses an set of classes that are allowed for this item
512      * @return the result of the classification; you can call
513      * {@link PredictionDistribution#best()} to get the most probably class
514      * @throws IllegalArgumentException if the
515      * {@linkplain #getAllClasses() set of valid classes} does not contain all
516      * candidate classes
517      * @throws ProcessingException if an error occurs during classification
518      */
519     public final PredictionDistribution classify(final FeatureVector features,
520             final Set candidateClasses)
521             throws IllegalArgumentException, ProcessingException {
522         // ensure that all candidate classes are valid
523         checkCandidateClass(candidateClasses);
524 
525         // transform features
526         final FeatureVector actualFeatures = (transformer != null)
527             ? transformer.transform(features) : features;
528 
529         // delegate to abstract method
530         final ContextMap context = new ContextMap();
531         final PredictionDistribution result =
532             doClassify(actualFeatures, candidateClasses, context);
533 
534         // cache decision + context to allow re-use for error-driven training
535         cachedPredictions = result;
536         cachedOrgFeatures = features;
537         cachedActualFeatures = actualFeatures;
538         cachedContext = context;
539 
540         return result;
541     }
542 
    /**
     * Destroys the classifier. This method must be called only if the
     * classifier will never be used again. The default implementation
     * delegates to {@link #reset()}, but subclasses can override this
     * behaviour if appropriate.
     *
     * @throws ProcessingException if an error occurs while the classifier is
     * being destroyed
     */
    public void destroy() throws ProcessingException {
        reset();
    }
555 
    /**
     * Classifies an item that is represented by a feature vector by choosing
     * the most probable class among a set of candidate classes. Invoked by
     * {@link #classify(FeatureVector, Set)} after the candidate classes have
     * been checked and the features have been transformed (if a transformer
     * is used).
     *
     * @param features the feature vector to consider
     * @param candidateClasses a set of classes that are allowed for this item
     * @param context can be used to transport implementation-specific
     * contextual information between the
     * {@link #doClassify(FeatureVector, Set, ContextMap)},
     * {@link #doTrain(FeatureVector, String, ContextMap)}, and
     * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
     * Set, ContextMap)} methods
     * @return the result of the classification; you can call
     * {@link PredictionDistribution#best()} to get the most probable class
     * @throws ProcessingException if an error occurs during classification
     */
    protected abstract PredictionDistribution doClassify(
            final FeatureVector features, final Set candidateClasses,
            final ContextMap context)
            throws ProcessingException;
576 
    /**
     * Incorporates an item that is represented by a feature vector into the
     * classification model. This is the training mechanism each concrete
     * classifier must provide.
     *
     * @param features the feature vector to consider
     * @param targetClass the class of this feature vector
     * @param context can be used to transport implementation-specific
     * contextual information between the
     * {@link #doClassify(FeatureVector, Set, ContextMap)},
     * {@link #doTrain(FeatureVector, String, ContextMap)}, and
     * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
     * Set, ContextMap)} methods
     * @throws ProcessingException if an error occurs during training
     */
    protected abstract void doTrain(final FeatureVector features,
            final String targetClass, final ContextMap context)
            throws ProcessingException;
594 
595     /***
596      * The core of the {@link #trainOnError(FeatureVector, String, Set)} method.
597      * Generally there is no need for subclasses to modify this method.
598      *
599      * @param predDist the prediction distribution returned by
600      * {@link #classify(FeatureVector, Set)}
601      * @param features the feature vector to consider
602      * @param targetClass the expected class of this feature vector; must be
603      * contained in the set of <code>candidateClasses</code>
604      * @param candidateClasses an set of classes that are allowed for this item
605      * (the actual <code>targetClass</code> must be one of them)
606      * @param context can be used to transport implementation-specific
607      * contextual information between the
608      * {@link #doClassify(FeatureVector, Set, ContextMap)},
609      * {@link #doTrain(FeatureVector, String, ContextMap)}, and
610      * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
611      * Set, ContextMap)} methods
612      * @return the result of the {@link TrainableClassifier#shouldTrain(String,
613      * PredictionDistribution, ContextMap)} method
614      * @throws ProcessingException if an error occurs during training
615      */
616     protected boolean doTrainOnError(final PredictionDistribution predDist,
617             final FeatureVector features, final String targetClass,
618             final Set candidateClasses, final ContextMap context)
619             throws ProcessingException {
620         // call hook for subclasses
621         final boolean hookHandledTraining = trainOnErrorHook(predDist,
622                 features, targetClass, candidateClasses, context);
623         final boolean shouldTrain = shouldTrain(targetClass, predDist, context);
624 
625         // train if requested, unless the hook kicked in
626         if (shouldTrain && !hookHandledTraining) {
627             doTrain(features, targetClass, context);
628         }
629 
630         return shouldTrain;
631     }
632 
    /**
     * Returns the set of all valid classes. Each target or candidate class
     * must be contained in this set.
     *
     * @return the set containing all valid class names; it cannot be
     * modified through the returned reference
     */
    public Set<String> getAllClasses() {
        return allClasses;
    }
642 
    /**
     * Returns the configuration used by this instance, as passed to the
     * constructor.
     *
     * @return the used configuration
     */
    public TiesConfiguration getConfig() {
        return config;
    }
650 
    /**
     * Resets the classifier, completely deleting the prediction model.
     * Also invoked by the default implementation of {@link #destroy()}.
     *
     * @throws ProcessingException if an error occurs during reset
     */
    public abstract void reset() throws ProcessingException;
656 
657     /***
658      * Invoked by {@link #trainOnError(FeatureVector, String, Set)} to decide
659      * whether to train an instance. The default behavior is to train if the
660      * best prediction was wrong or didn't yield a positive probability
661      * ("train only errors"). Subclasses can override this method to
662      * add their own behavior, e.g. reinforcement training (thick threshold
663      * heuristic).
664      *
665      * @param targetClass the expected class of this feature vector; must be
666      * contained in the set of <code>candidateClasses</code>
667      * @param predDist the prediction distribution returned by
668      * {@link #doClassify(FeatureVector, Set, ContextMap)}
669      * @param context can be used to transport implementation-specific
670      * contextual information between the
671      * {@link #doClassify(FeatureVector, Set, ContextMap)},
672      * {@link #doTrain(FeatureVector, String, ContextMap)}, and
673      * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
674      * Set, ContextMap)} methods
675      * @return whether to train this instance
676      */
677     protected boolean shouldTrain(final String targetClass,
678             final PredictionDistribution predDist, final ContextMap context) {
679         final Prediction best = predDist.best();
680         final double bestProb = best.getProbability().getProb();
681         return !best.getType().equals(targetClass) || Double.isNaN(bestProb)
682             || (bestProb <= 0.0);
683     }
684 
685     /**
686      * {@inheritDoc}
687      * Subclasses of {@link TrainableClassifier} should extend both this
688      * method and the matching constructor taking an
689      * {@link org.dom4j.Element}, so that (de)serialization keeps working.
690      */
691     public ObjectElement toElement() {
692         final ObjectElement element =
693             new ObjectElement(ELEMENT_MAIN, getClass());
694         // global attributes: the flattened class set and the train-all flag
695         element.addAttribute(ATTRIB_CLASSES,
696                 CollUtils.flatten(allClasses.iterator()));
697         element.addAttribute(ATTRIB_TRAIN_ALL, Boolean.toString(trainingAll));
698         if (transformer == null) {
699             return element; // no feature transformer to store
700         }
701         // the transformer is stored as the first child element
702         element.add(transformer.toElement());
703         return element;
704     }
705 
706     /**
707      * Builds a textual representation of this classifier.
708      *
709      * @return a textual representation
710      */
711     public String toString() {
712         final ToStringBuilder description =
713             new ToStringBuilder(this).append("all classes", allClasses);
714 
715         // the optional parts are mentioned only when they are set
716         if (trainingAll) {
717             description.append("training all classes", trainingAll);
718         }
719         if (transformer != null) {
720             description.append("transformer", transformer);
721         }
722 
723         return description.toString();
724     }
725 
726     /**
727      * Trains the model on an item represented by a feature vector, delegating
728      * to the abstract {@link #doTrain(FeatureVector, String, ContextMap)}.
729      *
730      * @param features the feature vector to consider
731      * @param targetClass the class of this feature vector
732      * @throws IllegalArgumentException if the target class is not in the
733      * {@linkplain #getAllClasses() set of valid classes}
734      * @throws ProcessingException if an error occurs during training
735      */
736     public final void train(final FeatureVector features,
737             final String targetClass)
738             throws IllegalArgumentException, ProcessingException {
739         // make sure that the target class is valid
740         checkTargetClass(targetClass);
741         // apply the feature transformer, if there is one
742         final FeatureVector actualFeatures;
743         if (transformer == null) {
744             actualFeatures = features;
745         } else {
746             actualFeatures = transformer.transform(features);
747         }
748         // hand over to the abstract method, using a fresh context
749         doTrain(actualFeatures, targetClass, new ContextMap());
750     }
751 
752     /**
753      * Performs error-driven learning ("train only errors"): the given feature
754      * vector is trained into the model only if the class predicted for it
755      * differs from the given target class. A correct prediction leaves the
756      * model unchanged.
757      *
758      * @param features the feature vector to consider
759      * @param targetClass the expected class of this feature vector; must be
760      * contained in the set of <code>candidateClasses</code>
761      * @param candidateClasses a set of classes that are allowed for this item
762      * (the actual <code>targetClass</code> must be one of them)
763      * @return the original prediction distribution if the best prediction was
764      * wrong, i.e. if training was necessary; or <code>null</code> if no
765      * training was necessary (the prediction was already correct)
766      * @throws ProcessingException if an error occurs during training
767      */
768     public final PredictionDistribution trainOnError(
769             final FeatureVector features, final String targetClass,
770             final Set candidateClasses) throws ProcessingException {
771         // check that all classes are valid before touching the features
772         checkTargetClass(targetClass);
773         checkCandidateClass(candidateClasses);
774 
775         // consider all classes for training if configured
776         final Set consideredClasses;
777         if (trainingAll) {
778             consideredClasses = allClasses;
779         } else {
780             consideredClasses = candidateClasses;
781         }
782 
783         final FeatureVector actualFeatures;
784         final ContextMap context;
785         final PredictionDistribution predDist;
786 
787         if (features != cachedOrgFeatures) {
788             // nothing cached for these features: transform them and call
789             // the abstract method to classify
790             if (transformer == null) {
791                 actualFeatures = features;
792             } else {
793                 actualFeatures = transformer.transform(features);
794             }
795             context = new ContextMap();
796             predDist = doClassify(actualFeatures, consideredClasses, context);
797         } else {
798             // re-use the cached result of the last classification; only
799             // identical features count, equality is not considered
800             actualFeatures = cachedActualFeatures;
801             context = cachedContext;
802             predDist = cachedPredictions;
803         }
804 
805         // delegate to core method
806         final boolean trainingNecessary = doTrainOnError(predDist,
807             actualFeatures, targetClass, consideredClasses, context);
808 
809         if (!trainingNecessary) {
810             // null signals that prediction was fine
811             return null;
812         }
813 
814         if (trainingAll) {
815             // filter prediction distribution to candidate classes
816             for (final Iterator predIter = predDist.iterator();
817                     predIter.hasNext();) {
818                 final Prediction pred = (Prediction) predIter.next();
819 
820                 if (!candidateClasses.contains(pred.getType())) {
821                     predIter.remove(); // not a candidate class
822                 }
823             }
824         }
825 
826         // check whether best prediction is correct after filtering; a wrong
827         // distribution is returned for analysis, null means it was fine
828         final boolean wrong = (predDist.size() > 0)
829             && !predDist.best().getType().equals(targetClass);
830         return wrong ? predDist : null;
831     }
832 
833     /**
834      * Hook for more refined error-driven learning, invoked as part of
835      * {@link #trainOnError(FeatureVector, String, Set)} after classifying.
836      * An overriding method may do any necessary training itself and return
837      * <code>true</code> to signal that no further action is necessary. This
838      * implementation is merely a placeholder that always returns
839      * <code>false</code>.
840      *
841      * @param predDist the prediction distribution returned by
842      * {@link #classify(FeatureVector, Set)}
843      * @param features the feature vector to consider
844      * @param targetClass the expected class of this feature vector; must be
845      * contained in the set of <code>candidateClasses</code>
846      * @param candidateClasses a set of classes that are allowed for this item
847      * (the actual <code>targetClass</code> must be one of them)
848      * @param context can be used to transport implementation-specific
849      * contextual information between the
850      * {@link #doClassify(FeatureVector, Set, ContextMap)},
851      * {@link #doTrain(FeatureVector, String, ContextMap)}, and
852      * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
853      * Set, ContextMap)} methods
854      * @return this implementation always returns <code>false</code>; subclasses
855      * can return <code>true</code> to signal that any error-driven learning was
856      * already handled
857      * @throws ProcessingException if an error occurs during training
858      */
859     protected boolean trainOnErrorHook(
860             final PredictionDistribution predDist, final FeatureVector features,
861             final String targetClass, final Set candidateClasses,
862             final ContextMap context) throws ProcessingException {
863         // placeholder: nothing is trained here
864         return false;
865     }
866 
867 }