1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package de.fu_berlin.ties.classify;
23
import java.io.File;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Set;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.builder.ToStringBuilder;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.feature.FeatureTransformer;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.classify.winnow.UltraconservativeWinnow;
import de.fu_berlin.ties.classify.winnow.Winnow;
import de.fu_berlin.ties.util.Util;
import de.fu_berlin.ties.xml.io.ObjectElement;
import de.fu_berlin.ties.xml.io.XMLStorable;
43
/**
 * Abstract base class for classifiers that can be trained. Classifiers
 * extending this abstract class must provide a training mechanism
 * by implementing the {@link #doTrain(FeatureVector, String, ContextMap)}
 * method. This class supports error-driven learning ("train only errors"),
 * which often leads to better prediction models than brute-force training.
 *
 * <p>The code in this class is thread-safe.
 *
 * @author Christian Siefkes
 * @version $Revision: 1.31 $, $Date: 2004/12/09 18:09:14 $, $Author: siefkes $
 */
55 public abstract class TrainableClassifier implements Classifier, XMLStorable {
56
/**
 * Specification keyword that selects the {@link MetaClassifier}.
 */
public static final String META_CLASSIFIER = "meta";

/**
 * Specification keyword that selects the {@link MultiBinaryClassifier}.
 */
public static final String MULTI_CLASSIFIER = "multi";

/**
 * Specification keyword that selects the
 * {@link OneAgainstTheRestClassifier}.
 */
public static final String OAR_CLASSIFIER = "oar";

/**
 * Unmodifiable set with the keywords of all classifiers that wrap inner
 * classifiers.
 */
private static final Set<String> wrappingClassifiers;

/**
 * Fills the set of wrapping classifier keywords once, when the class is
 * initialized.
 */
static {
    final Set<String> keywords = new HashSet<String>(
        java.util.Arrays.asList(
            META_CLASSIFIER, MULTI_CLASSIFIER, OAR_CLASSIFIER));
    wrappingClassifiers = Collections.unmodifiableSet(keywords);
}
87
88 /***
89 * Factory method that delegates to
90 * {@link #createClassifier(Set, TiesConfiguration)} using the
91 * {@linkplain TiesConfiguration#CONF standard configuration}.
92 *
93 * @param allValidClasses the set of all valid classes
94 * @return the created classifier
95 * @throws IllegalArgumentException if the value of the
96 * {@link #CONFIG_CLASSIFIER} key is missing or invalid
97 * @throws ProcessingException if an error occurred while creating the
98 * classifier
99 */
100 public static TrainableClassifier createClassifier(
101 final Set<String> allValidClasses)
102 throws IllegalArgumentException, ProcessingException {
103 return createClassifier(allValidClasses, TiesConfiguration.CONF);
104 }
105
106 /***
107 * Factory method that delegates to
108 * {@link #createClassifier(Set, File, TiesConfiguration)} without
109 * specifying an run directory.
110 *
111 * @param allValidClasses the set of all valid classes
112 * @param conf the configuration to use
113 * @return the created classifier
114 * @throws IllegalArgumentException if the value of the
115 * {@link #CONFIG_CLASSIFIER} key is missing or invalid
116 * @throws ProcessingException if an error occurred while creating the
117 * classifier
118 */
119 public static TrainableClassifier createClassifier(
120 final Set<String> allValidClasses, final TiesConfiguration conf)
121 throws IllegalArgumentException, ProcessingException {
122 return createClassifier(allValidClasses, null, conf);
123 }
124
125 /***
126 * Factory method that delegates to
127 * {@link #createClassifier(Set, File, FeatureTransformer, String[],
128 * TiesConfiguration)}. It reads the specification of the classifier from
129 * the {@link #CONFIG_CLASSIFIER} key in the provided configuration. It
130 * calls {@link FeatureTransformer#createTransformer(TiesConfiguration)} to
131 * initialize a transformer chain, if configured.
132 *
133 * @param allValidClasses the set of all valid classes
134 * @param runDirectory the directory to run the classifier in; used for
135 * {@link ExternalClassifier} instead of the
136 * {@linkplain ExternalClassifier#CONFIG_DIR configured directory}
137 * if not <code>null</code>; ignored otherwise
138 * @param conf the configuration to use
139 * @return the created classifier
140 * @throws IllegalArgumentException if the value of the
141 * {@link #CONFIG_CLASSIFIER} key is missing or invalid
142 * @throws ProcessingException if an error occurred while creating the
143 * classifier
144 */
145 public static TrainableClassifier createClassifier(
146 final Set<String> allValidClasses, final File runDirectory,
147 final TiesConfiguration conf)
148 throws IllegalArgumentException, ProcessingException {
149
150 final String[] spec = conf.getStringArray(CONFIG_CLASSIFIER);
151
152
153 final FeatureTransformer trans =
154 FeatureTransformer.createTransformer(conf);
155 return createClassifier(allValidClasses, runDirectory, trans,
156 spec, conf);
157 }
158
159 /***
160 * Factory method that creates a trainable classifier based on the
161 * provided specification.
162 *
163 * <p>Currently supported values in the first element of the specification:
164 *
165 * <ul>
166 * <li>"Ext" for {@link ExternalClassifier}
167 * <li>"Winnow" for {@link Winnow}
168 * <li>"ucWinnow" for {@link UltraconservativeWinnow}
169 * <li>"Meta" followed by the specification of the inner
170 * classifiers as further element(s) for {@link MetaClassifier}
171 * <li>"Multi" or "OAR" followed by the specification of the inner
172 * classifiers as further element(s) for {@link MultiBinaryClassifier} resp.
173 * {@link OneAgainstTheRestClassifier} (if there are only two classes
174 * to classify, the outer classifer is skipped and the inner classifier is
175 * used directly).
176 * </ul>
177 *
178 * <p>Otherwise the first element must be the qualified name of a
179 * TrainableClassifier subclass accepting a {@link Set} (of all valid class
180 * names) as first argument, a {@link FeatureTransformer} as second argument
181 * and a {@link TiesConfiguration} as third argument.
182 *
183 * @param allValidClasses the set of all valid classes
184 * @param runDirectory the directory to run the classifier in; used for
185 * {@link ExternalClassifier} instead of the
186 * {@linkplain ExternalClassifier#CONFIG_DIR configured directory}
187 * if not <code>null</code>; ignored otherwise
188 * @param trans the last transformer in the transformer chain to use, or
189 * <code>null</code> if no feature transformers should be used
190 * @param spec the specification used to initialize the classifier, as
191 * described above
192 * @param conf passed to the created classifier to configure itself
193 * @return the created classifier
194 * @throws IllegalArgumentException if the value of the
195 * {@link #CONFIG_CLASSIFIER} key is missing or invalid
196 * @throws ProcessingException if an error occurred while creating the
197 * classifier
198 */
199 public static TrainableClassifier createClassifier(
200 final Set<String> allValidClasses, final File runDirectory,
201 final FeatureTransformer trans, final String[] spec,
202 final TiesConfiguration conf)
203 throws IllegalArgumentException, ProcessingException {
204 if ((spec == null) || spec.length < 1) {
205 throw new IllegalArgumentException(
206 "Cannot create classifier -- specification is null or empty");
207 }
208
209
210
211 final String lowerValue = spec[0].toLowerCase();
212 final TrainableClassifier result;
213
214 if ("ext".equals(lowerValue)) {
215
216 result = new ExternalClassifier(allValidClasses, trans,
217 runDirectory, conf);
218 } else if ("winnow".equals(lowerValue)) {
219
220 result = new Winnow(allValidClasses, trans, conf);
221 } else if ("ucwinnow".equals(lowerValue)) {
222
223 result =
224 new UltraconservativeWinnow(allValidClasses, trans, conf);
225 } else if (wrappingClassifiers.contains(lowerValue)) {
226
227 final String[] innerSpec = new String[spec.length - 1];
228
229
230 for (int i = 0; i < innerSpec.length; i++) {
231 innerSpec[i] = spec[i + 1];
232 }
233
234 if (META_CLASSIFIER.equals(lowerValue)) {
235
236 result = new MetaClassifier(allValidClasses, trans,
237 runDirectory, innerSpec, conf);
238 } else if (allValidClasses.size() > 2) {
239
240
241
242 if (MULTI_CLASSIFIER.equals(lowerValue)) {
243
244 result = new MultiBinaryClassifier(allValidClasses, trans,
245 runDirectory, innerSpec, conf);
246 } else if (OAR_CLASSIFIER.equals(lowerValue)) {
247
248 result = new OneAgainstTheRestClassifier(allValidClasses,
249 trans, runDirectory, innerSpec, conf);
250 } else {
251
252 throw new RuntimeException("Implementation error: "
253 + "unknown wrapping classifier" + spec[0]);
254 }
255 } else {
256
257
258 result = createClassifier(allValidClasses, runDirectory, trans,
259 innerSpec, conf);
260 }
261 } else {
262
263
264
265 try {
266 result = (TrainableClassifier) Util.createObject(
267 Class.forName(spec[0]),
268 new Object[] {allValidClasses, trans, conf},
269 new Class[] {Set.class, FeatureTransformer.class,
270 TiesConfiguration.class});
271 } catch (ClassNotFoundException cnfe) {
272
273 throw new ProcessingException(
274 "Cannot create classifier from specification "
275 + ArrayUtils.toString(spec) + ": " + cnfe.toString());
276 } catch (InstantiationException ie) {
277
278 throw new ProcessingException(
279 "Cannot create classifier from specification "
280 + ArrayUtils.toString(spec), ie);
281 }
282 }
283 return result;
284 }
285
286 /***
287 * The immutable set of all valid classes. Each target or candidate class
288 * must be contained in this set.
289 */
290 private final Set<String> allClasses;
291
292 /***
293 * Used to configure this instance.
294 */
295 private final TiesConfiguration config;
296
297 /***
298 * If <code>true</code> the classifier considers all classes for
299 * error-driven training, not only the candidate classes (results are
300 * filtered to the candidate classes prior to returning them).
301 */
302 private final boolean trainingAll;
303
304 /***
305 * The last transformer in a transformer chain, or <code>null</code> if
306 * no feature transformers are used.
307 */
308 private final FeatureTransformer transformer;
309
310 /***
311 * Creates a new instance.
312 *
313 * @param allValidClasses the set of all valid classes
314 * @param trans the last transformer in the transformer chain to use, or
315 * <code>null</code> if no feature transformers should be used
316 * @param conf used to configure this instance
317 */
318 public TrainableClassifier(final Set<String> allValidClasses,
319 final FeatureTransformer trans, final TiesConfiguration conf) {
320 super();
321
322
323 allClasses = Collections.unmodifiableSet(allValidClasses);
324 config = conf;
325 transformer = trans;
326 trainingAll = conf.getBoolean("classifier.train.all");
327 }
328
329 /***
330 * Ensure that all candidate classes are valid (contained in the set of all
331 * classes), throwing an exception otherwise.
332 *
333 * @param candidateClasses an set of classes that are allowed for this item
334 * (the actual <code>targetClass</code> must be one of them)
335 * @throws IllegalArgumentException if not all candidate classes are members
336 * of the {@linkplain #getAllClasses() set of valid classes}
337 */
338 private void checkCandidateClass(final Set candidateClasses)
339 throws IllegalArgumentException {
340 final Iterator classIter = candidateClasses.iterator();
341 String currentClass;
342
343 while (classIter.hasNext()) {
344 currentClass = (String) classIter.next();
345 if (!allClasses.contains(currentClass)) {
346 throw new IllegalArgumentException("Candidate class "
347 + currentClass
348 + " is not in the set of valid classes: " + allClasses);
349 }
350 }
351 }
352
353 /***
354 * Ensure that the target class is valid (contained in the set of all
355 * classes), throwing an exception otherwise.
356 *
357 * @param targetClass the expected class of an instance; must be
358 * contained in the set of <code>candidateClasses</code>
359 * @throws IllegalArgumentException if the target class is not a member of
360 * the {@linkplain #getAllClasses() set of valid classes}
361 */
362 private void checkTargetClass(final String targetClass)
363 throws IllegalArgumentException {
364
365 if (!allClasses.contains(targetClass)) {
366 throw new IllegalArgumentException("Target class " + targetClass
367 + " is not in the set of valid classes: " + allClasses);
368 }
369 }
370
371 /***
372 * Classifies an item that is represented by a feature vector by choosing
373 * the most probable class among a set of candidate classes. Delegates to
374 * the abstract {@link #doClassify(FeatureVector, Set, ContextMap)} method.
375 *
376 * @param features the feature vector to consider
377 * @param candidateClasses an set of classes that are allowed for this item
378 * @return the result of the classification; you can call
379 * {@link PredictionDistribution#best()} to get the most probably class
380 * @throws IllegalArgumentException if the
381 * {@linkplain #getAllClasses() set of valid classes} does not contain all
382 * candidate classes
383 * @throws ProcessingException if an error occurs during classification
384 */
385 public final PredictionDistribution classify(final FeatureVector features,
386 final Set candidateClasses)
387 throws IllegalArgumentException, ProcessingException {
388
389 checkCandidateClass(candidateClasses);
390
391
392 final FeatureVector actualFeatures = (transformer != null)
393 ? transformer.transform(features) : features;
394
395
396 final ContextMap context = new ContextMap();
397 return doClassify(actualFeatures, candidateClasses, context);
398 }
399
400 /***
401 * Classifies an item that is represented by a feature vector by choosing
402 * the most probable class among a set of candidate classes.
403 *
404 * @param features the feature vector to consider
405 * @param candidateClasses an set of classes that are allowed for this item
406 * @param context can be used to transport implementation-specific
407 * contextual information between the
408 * {@link #doClassify(FeatureVector, Set, ContextMap)},
409 * {@link #doTrain(FeatureVector, String, ContextMap)}, and
410 * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
411 * Set, ContextMap)} methods
412 * @return the result of the classification; you can call
413 * {@link PredictionDistribution#best()} to get the most probably class
414 * @throws ProcessingException if an error occurs during classification
415 */
416 protected abstract PredictionDistribution doClassify(
417 final FeatureVector features, final Set candidateClasses,
418 final ContextMap context)
419 throws ProcessingException;
420
421 /***
422 * Incorporates an item that is represented by a feature vector into the
423 * classification model.
424 *
425 * @param features the feature vector to consider
426 * @param targetClass the class of this feature vector
427 * @param context can be used to transport implementation-specific
428 * contextual information between the
429 * {@link #doClassify(FeatureVector, Set, ContextMap)},
430 * {@link #doTrain(FeatureVector, String, ContextMap)}, and
431 * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
432 * Set, ContextMap)} methods
433 * @throws ProcessingException if an error occurs during training
434 */
435 protected abstract void doTrain(final FeatureVector features,
436 final String targetClass, final ContextMap context)
437 throws ProcessingException;
438
439 /***
440 * Returns the set of all valid classes. Each target or candidate class
441 * must be contained in this set.
442 *
443 * @return an immutable set containing all valid class names
444 */
445 public Set getAllClasses() {
446 return allClasses;
447 }
448
449 /***
450 * Returns the configuration used by this instance.
451 * @return the used configuration
452 */
453 public TiesConfiguration getConfig() {
454 return config;
455 }
456
457 /***
458 * Resets the classifer, completely deleting the prediction model.
459 * @throws ProcessingException if an error occurs during reset
460 */
461 public abstract void reset() throws ProcessingException;
462
463 /***
464 * Invoked by {@link #trainOnError(FeatureVector, String, Set)} to decide
465 * whether to train an instance. The default behavior is to train if the
466 * best prediction was wrong or didn't yield a positive probability
467 * ("train only errors"). Subclasses can override this method to
468 * add their own behavior, e.g. reinforcement training (thick threshold
469 * heuristic).
470 *
471 * @param targetClass the expected class of this feature vector; must be
472 * contained in the set of <code>candidateClasses</code>
473 * @param predDist the prediction distribution returned by
474 * {@link #doClassify(FeatureVector, Set, ContextMap)}
475 * @param context can be used to transport implementation-specific
476 * contextual information between the
477 * {@link #doClassify(FeatureVector, Set, ContextMap)},
478 * {@link #doTrain(FeatureVector, String, ContextMap)}, and
479 * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
480 * Set, ContextMap)} methods
481 * @return whether to train this instance
482 */
483 protected boolean shouldTrain(final String targetClass,
484 final PredictionDistribution predDist, final ContextMap context) {
485 final Prediction best = predDist.best();
486 final double bestProb = best.getProbability().getProb();
487 return !best.getType().equals(targetClass) || Double.isNaN(bestProb)
488 || (bestProb <= 0.0);
489 }
490
491 /***
492 * {@inheritDoc}
493 * Subclasses of {@link TrainableClassifier} should extend this method and
494 * the corresponding constructor from {@link org.dom4j.Element} to
495 * ensure (de)serialization works as expected.
496 */
497 public ObjectElement toElement() {
498 final ObjectElement result =
499 new ObjectElement("classifier", this.getClass());
500
501 return result;
502 }
503
504 /***
505 * Returns a string representation of this object.
506 *
507 * @return a textual representation
508 */
509 public String toString() {
510 final ToStringBuilder builder = new ToStringBuilder(this);
511 builder.append("all classes", allClasses);
512
513 if (trainingAll) {
514 builder.append("training all classes", trainingAll);
515 }
516
517 if (transformer != null) {
518 builder.append("transformer", transformer);
519 }
520
521 return builder.toString();
522 }
523
524 /***
525 * Incorporates an item that is represented by a feature vector into the
526 * classification model. Delegates to the abstract
527 * {@link #doTrain(FeatureVector, String, ContextMap)} method.
528 *
529 * @param features the feature vector to consider
530 * @param targetClass the class of this feature vector
531 * @throws IllegalArgumentException if the target class is not in the
532 * {@linkplain #getAllClasses() set of valid classes}
533 * @throws ProcessingException if an error occurs during training
534 */
535 public final void train(final FeatureVector features,
536 final String targetClass)
537 throws IllegalArgumentException, ProcessingException {
538
539 checkTargetClass(targetClass);
540
541
542 final FeatureVector actualFeatures = (transformer != null)
543 ? transformer.transform(features) : features;
544
545
546 final ContextMap context = new ContextMap();
547 doTrain(actualFeatures, targetClass, context);
548 }
549
550 /***
551 * Handles error-driven learning ("train only errors"): the specified
552 * feature vector is trained into the model only if the predicted class
553 * for the feature vector differs from the specified target class. If the
554 * prediction was correct, the model is not changed.
555 *
556 * @param features the feature vector to consider
557 * @param targetClass the expected class of this feature vector; must be
558 * contained in the set of <code>candidateClasses</code>
559 * @param candidateClasses an set of classes that are allowed for this item
560 * (the actual <code>targetClass</code> must be one of them)
561 * @return the original prediction distribution if the best prediction was
562 * wrong, i.e. if training was necessary; or <code>null</code> if no
563 * training was necessary (the prediction was already correct)
564 * @throws ProcessingException if an error occurs during training
565 */
566 public final PredictionDistribution trainOnError(
567 final FeatureVector features, final String targetClass,
568 final Set candidateClasses) throws ProcessingException {
569
570 checkTargetClass(targetClass);
571 checkCandidateClass(candidateClasses);
572 final FeatureVector actualFeatures = (transformer != null)
573 ? transformer.transform(features) : features;
574
575
576 final Set consideredClasses = trainingAll
577 ? allClasses : candidateClasses;
578 final ContextMap context = new ContextMap();
579
580
581 final PredictionDistribution predDist =
582 doClassify(actualFeatures, consideredClasses, context);
583
584
585 final boolean hookHandledTraining = trainOnErrorHook(predDist,
586 actualFeatures, targetClass, consideredClasses, context);
587
588 if (shouldTrain(targetClass, predDist, context)) {
589
590 if (!hookHandledTraining) {
591 doTrain(actualFeatures, targetClass, context);
592 }
593
594
595 if (trainingAll) {
596 final Iterator predIter = predDist.iterator();
597 Prediction pred;
598
599 while (predIter.hasNext()) {
600 pred = (Prediction) predIter.next();
601
602 if (!candidateClasses.contains(pred.getType())) {
603
604 predIter.remove();
605 }
606 }
607 }
608
609
610 if (((predDist.size() > 0)
611 && !predDist.best().getType().equals(targetClass))) {
612
613 return predDist;
614 } else {
615
616 return null;
617 }
618 } else {
619
620 return null;
621 }
622 }
623
624 /***
625 * Subclasses can implement this hook for more refined error-driven
626 * learning. It is called from the
627 * {@link #trainOnError(FeatureVector, String, Set)} method after
628 * classifying. This method can do any necessary training itself and
629 * return <code>true</code> to signal that no further action is necessary.
630 * This implementation is just a placeholder that always returns
631 * <code>false</code>.
632 *
633 * @param predDist the prediction distribution returned by
634 * {@link #classify(FeatureVector, Set)}
635 * @param features the feature vector to consider
636 * @param targetClass the expected class of this feature vector; must be
637 * contained in the set of <code>candidateClasses</code>
638 * @param candidateClasses an set of classes that are allowed for this item
639 * (the actual <code>targetClass</code> must be one of them)
640 * @param context can be used to transport implementation-specific
641 * contextual information between the
642 * {@link #doClassify(FeatureVector, Set, ContextMap)},
643 * {@link #doTrain(FeatureVector, String, ContextMap)}, and
644 * {@link #trainOnErrorHook(PredictionDistribution, FeatureVector, String,
645 * Set, ContextMap)} methods
646 * @return this implementation always returns <code>false</code>; subclasses
647 * can return <code>true</code> to signal that any error-driven learning was
648 * already handled
649 * @throws ProcessingException if an error occurs during training
650 */
651 protected boolean trainOnErrorHook(final PredictionDistribution predDist,
652 final FeatureVector features, final String targetClass,
653 final Set candidateClasses, final ContextMap context)
654 throws ProcessingException {
655 return false;
656 }
657
658 }