de.fu_berlin.ties.classify
Class TrainableClassifier

java.lang.Object
  extended by de.fu_berlin.ties.classify.TrainableClassifier
All Implemented Interfaces:
Classifier
Direct Known Subclasses:
ExternalClassifier, Winnow

public abstract class TrainableClassifier
extends Object
implements Classifier

Classifiers extending this abstract class must provide a classification mechanism by implementing the doClassify(FeatureVector, Set) method and a training mechanism by implementing the doTrain(FeatureVector, String) method. This class also supports error-driven learning ("train only errors"), which often leads to better prediction models than brute-force training on every example.

The code in this class is thread-safe.
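The following is a minimal sketch of the structure described above, not the TIES source. A public final train method validates its argument and delegates to the protected abstract doTrain hook, and the set of valid classes is copied into an immutable set. SimpleFeatureVector and TrainableClassifierSketch are hypothetical names, the sketch uses generics where the original API uses the raw Set type, and declaring the public method synchronized is only an assumption about how thread safety might be achieved.

```java
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Hypothetical stand-in for FeatureVector: feature name -> weight.
final class SimpleFeatureVector {
    final Map<String, Double> features;

    SimpleFeatureVector(Map<String, Double> features) {
        this.features = features;
    }
}

// Sketch of the template-method structure (not the TIES implementation).
abstract class TrainableClassifierSketch {
    private final Set<String> allClasses;

    TrainableClassifierSketch(Set<String> allValidClasses) {
        // Defensive, immutable copy: callers cannot alter the valid classes later.
        this.allClasses = Collections.unmodifiableSet(new HashSet<>(allValidClasses));
    }

    public Set<String> getAllClasses() {
        return allClasses;
    }

    // Public entry point: validate, then delegate to the subclass.
    // 'synchronized' serializes access to the model (an assumption of this sketch).
    public final synchronized void train(SimpleFeatureVector features, String targetClass) {
        if (!allClasses.contains(targetClass)) {
            throw new IllegalArgumentException("Unknown class: " + targetClass);
        }
        doTrain(features, targetClass);
    }

    // Subclasses supply the actual learning step.
    protected abstract void doTrain(SimpleFeatureVector features, String targetClass);
}
```

Because train is final, every subclass inherits the same validation and locking, and only the learning step itself varies.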

Version:
$Revision: 1.13 $, $Date: 2004/04/13 08:00:13 $, $Author: siefkes $
Author:
Christian Siefkes

Field Summary
 
Fields inherited from interface de.fu_berlin.ties.classify.Classifier
CONFIG_CLASSIFIER
 
Constructor Summary
TrainableClassifier(Set allValidClasses, FeatureTransformer trans)
          Creates a new instance.
 
Method Summary
 PredictionDistribution classify(FeatureVector features, Set candidateClasses)
          Classifies an item that is represented by a feature vector by choosing the most probable class among a set of candidate classes.
static TrainableClassifier createClassifier(Set allValidClasses)
          Factory method that delegates to createClassifier(Set, TiesConfiguration) using the standard configuration.
static TrainableClassifier createClassifier(Set allValidClasses, File runDirectory, TiesConfiguration config)
          Factory method that creates a trainable classifier based on the Classifier.CONFIG_CLASSIFIER key in the provided configuration.
static TrainableClassifier createClassifier(Set allValidClasses, TiesConfiguration config)
          Factory method that delegates to createClassifier(Set, File, TiesConfiguration) without specifying a run directory.
protected abstract  PredictionDistribution doClassify(FeatureVector features, Set candidateClasses)
          Classifies an item that is represented by a feature vector by choosing the most probable class among a set of candidate classes.
protected abstract  void doTrain(FeatureVector features, String targetClass)
          Incorporates an item that is represented by a feature vector into the classification model.
 Set getAllClasses()
          Returns the set of all valid classes.
 String toString()
          Returns a string representation of this object.
 void train(FeatureVector features, String targetClass)
          Incorporates an item that is represented by a feature vector into the classification model.
 PredictionDistribution trainOnError(FeatureVector features, String targetClass, Set candidateClasses)
          Handles error-driven learning ("train only errors"): the specified feature vector is trained into the model only if the predicted class for the feature vector differs from the specified target class.
protected  boolean trainOnErrorHook(PredictionDistribution predDist, FeatureVector features, String targetClass, Set candidateClasses)
          Subclasses can override this hook for more refined error-driven learning.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
 

Constructor Detail

TrainableClassifier

public TrainableClassifier(Set allValidClasses,
                           FeatureTransformer trans)
Creates a new instance.

Parameters:
allValidClasses - the set of all valid classes
trans - the last transformer in the transformer chain to use, or null if no feature transformers should be used
Method Detail

createClassifier

public static TrainableClassifier createClassifier(Set allValidClasses)
                                            throws IllegalArgumentException,
                                                   ProcessingException
Factory method that delegates to createClassifier(Set, TiesConfiguration) using the standard configuration.

Parameters:
allValidClasses - the set of all valid classes
Returns:
the created classifier
Throws:
IllegalArgumentException - if the value of the Classifier.CONFIG_CLASSIFIER key is missing or invalid
ProcessingException - if an error occurred while creating the classifier

createClassifier

public static TrainableClassifier createClassifier(Set allValidClasses,
                                                   TiesConfiguration config)
                                            throws IllegalArgumentException,
                                                   ProcessingException
Factory method that delegates to createClassifier(Set, File, TiesConfiguration) without specifying a run directory.

Parameters:
allValidClasses - the set of all valid classes
config - the configuration to use
Returns:
the created classifier
Throws:
IllegalArgumentException - if the value of the Classifier.CONFIG_CLASSIFIER key is missing or invalid
ProcessingException - if an error occurred while creating the classifier

createClassifier

public static TrainableClassifier createClassifier(Set allValidClasses,
                                                   File runDirectory,
                                                   TiesConfiguration config)
                                            throws IllegalArgumentException,
                                                   ProcessingException
Factory method that creates a trainable classifier based on the Classifier.CONFIG_CLASSIFIER key in the provided configuration.

Currently supported values: "Ext" for ExternalClassifier, "Winnow" for Winnow, "ucWinnow" for UltraconservativeWinnow.

Otherwise, the value must be the fully qualified name of a TrainableClassifier subclass with a constructor that takes a Set (of all valid class names), a FeatureTransformer, and a TiesConfiguration, in that order.

Parameters:
allValidClasses - the set of all valid classes
runDirectory - the directory to run the classifier in; if not null, ExternalClassifier uses it instead of the configured directory; it is ignored in all other cases
config - the configuration to use
Returns:
the created classifier
Throws:
IllegalArgumentException - if the value of the Classifier.CONFIG_CLASSIFIER key is missing or invalid
ProcessingException - if an error occurred while creating the classifier
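A sketch of how this configuration-driven dispatch could look. It relies on hypothetical stand-in types (ConfigStub for TiesConfiguration, TransformerStub for FeatureTransformer, ProcessingStubException for ProcessingException, and WinnowStub, UltraconservativeWinnowStub and ExternalStub for the real classifiers) and on hypothetical configuration keys; the actual key string behind Classifier.CONFIG_CLASSIFIER is not given here. The short names are matched first; any other value is loaded reflectively as a fully qualified class name with a (Set, FeatureTransformer, TiesConfiguration) constructor. The overload without a run directory simply passes null.

```java
import java.io.File;
import java.lang.reflect.InvocationTargetException;
import java.util.Map;
import java.util.Set;

// Hypothetical stand-ins for FeatureTransformer, TiesConfiguration and ProcessingException.
interface TransformerStub {}

final class ConfigStub {
    private final Map<String, String> values;
    ConfigStub(Map<String, String> values) { this.values = values; }
    String get(String key) { return values.get(key); }
}

final class ProcessingStubException extends Exception {
    ProcessingStubException(String msg, Throwable cause) { super(msg, cause); }
}

// Hypothetical stand-ins for TrainableClassifier and its subclasses.
abstract class ClassifierStub {
    final Set<String> classes;
    ClassifierStub(Set<String> classes) { this.classes = classes; }
}

final class WinnowStub extends ClassifierStub {
    public WinnowStub(Set<String> c, TransformerStub t, ConfigStub cfg) { super(c); }
}

final class UltraconservativeWinnowStub extends ClassifierStub {
    public UltraconservativeWinnowStub(Set<String> c, TransformerStub t, ConfigStub cfg) { super(c); }
}

final class ExternalStub extends ClassifierStub {
    final File dir;
    ExternalStub(Set<String> c, File dir) { super(c); this.dir = dir; }
}

final class ClassifierFactorySketch {
    // Hypothetical key names, not the real value of Classifier.CONFIG_CLASSIFIER.
    static final String CONFIG_CLASSIFIER = "classifier";
    static final String CONFIG_EXT_DIR = "classifier.ext.dir";

    // Overload without a run directory: delegates with null.
    static ClassifierStub createClassifier(Set<String> classes, ConfigStub config)
            throws ProcessingStubException {
        return createClassifier(classes, null, config);
    }

    static ClassifierStub createClassifier(Set<String> classes, File runDirectory,
            ConfigStub config) throws ProcessingStubException {
        String type = config.get(CONFIG_CLASSIFIER);
        if (type == null) {
            throw new IllegalArgumentException("Missing value for " + CONFIG_CLASSIFIER);
        }
        TransformerStub trans = null; // a real factory would build the transformer chain here
        switch (type) {
            case "Ext": {
                // A non-null run directory overrides the configured one.
                String configured = config.get(CONFIG_EXT_DIR);
                File dir = (runDirectory != null) ? runDirectory
                        : (configured != null ? new File(configured) : null);
                return new ExternalStub(classes, dir);
            }
            case "Winnow":
                return new WinnowStub(classes, trans, config);
            case "ucWinnow":
                return new UltraconservativeWinnowStub(classes, trans, config);
            default:
                return instantiate(type, classes, trans, config);
        }
    }

    // Any other value is treated as a fully qualified class name.
    private static ClassifierStub instantiate(String className, Set<String> classes,
            TransformerStub trans, ConfigStub config) throws ProcessingStubException {
        final Class<? extends ClassifierStub> clazz;
        try {
            clazz = Class.forName(className).asSubclass(ClassifierStub.class);
        } catch (ClassNotFoundException | ClassCastException e) {
            throw new IllegalArgumentException("Invalid classifier: " + className, e);
        }
        try {
            return clazz.getConstructor(Set.class, TransformerStub.class, ConfigStub.class)
                    .newInstance(classes, trans, config);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("Missing constructor in " + className, e);
        } catch (InstantiationException | IllegalAccessException
                | InvocationTargetException e) {
            throw new ProcessingStubException("Could not create " + className, e);
        }
    }
}
```

Matching the short names with a switch keeps the common cases free of reflection, while the reflective fallback lets users plug in their own subclasses without changing the factory. In this sketch an unknown or unsuitable class name maps to IllegalArgumentException, and a failure inside the constructor maps to the ProcessingException stand-in.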

classify

public final PredictionDistribution classify(FeatureVector features,
                                             Set candidateClasses)
                                      throws IllegalArgumentException,
                                             ProcessingException
Classifies an item that is represented by a feature vector by choosing the most probable class among a set of candidate classes. Delegates to the abstract doClassify(FeatureVector, Set) method.

Specified by:
classify in interface Classifier
Parameters:
features - the feature vector to consider
candidateClasses - a set of classes that are allowed for this item
Returns:
the result of the classification; you can call PredictionDistribution.best() to get the most probable class
Throws:
IllegalArgumentException - if the set of valid classes does not contain all candidate classes
ProcessingException - if an error occurs during classification
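A minimal sketch of the validate-then-delegate pattern behind classify. PredictionSketch and DistributionSketch are hypothetical stand-ins for the TIES prediction classes, features are modeled as a Map<String, Double>, and best() here simply returns the entry with the highest probability.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical stand-in for a single prediction: class name plus probability.
final class PredictionSketch {
    final String type;
    final double probability;

    PredictionSketch(String type, double probability) {
        this.type = type;
        this.probability = probability;
    }
}

// Hypothetical stand-in for PredictionDistribution, kept sorted best-first.
final class DistributionSketch {
    private final List<PredictionSketch> preds = new ArrayList<>();

    void add(PredictionSketch p) {
        preds.add(p);
        preds.sort(Comparator.comparingDouble((PredictionSketch q) -> q.probability).reversed());
    }

    PredictionSketch best() {
        return preds.get(0);
    }
}

abstract class ClassifySketch {
    private final Set<String> allClasses;

    ClassifySketch(Set<String> allClasses) {
        this.allClasses = new HashSet<>(allClasses);
    }

    // Final entry point: reject unknown candidate classes, then delegate.
    public final synchronized DistributionSketch classify(Map<String, Double> features,
            Set<String> candidateClasses) {
        if (!allClasses.containsAll(candidateClasses)) {
            throw new IllegalArgumentException(
                    "Candidates " + candidateClasses + " not all in " + allClasses);
        }
        return doClassify(features, candidateClasses);
    }

    protected abstract DistributionSketch doClassify(Map<String, Double> features,
            Set<String> candidateClasses);
}
```

Callers then read the winning class via classify(features, candidates).best().type.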

doClassify

protected abstract PredictionDistribution doClassify(FeatureVector features,
                                                     Set candidateClasses)
                                              throws ProcessingException
Classifies an item that is represented by a feature vector by choosing the most probable class among a set of candidate classes.

Parameters:
features - the feature vector to consider
candidateClasses - a set of classes that are allowed for this item
Returns:
the result of the classification; you can call PredictionDistribution.best() to get the most probable class
Throws:
ProcessingException - if an error occurs during classification

doTrain

protected abstract void doTrain(FeatureVector features,
                                String targetClass)
                         throws ProcessingException
Incorporates an item that is represented by a feature vector into the classification model.

Parameters:
features - the feature vector to consider
targetClass - the class of this feature vector
Throws:
ProcessingException - if an error occurs during training
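To illustrate what a subclass has to supply, here is a hypothetical toy classifier (AdditiveClassifier, not Winnow or any other TIES class) that keeps one additive weight vector per class: doTrain adds the feature values to the target class's weights, and doClassify scores each candidate by a dot product. The minimal base class TrainableBase stands in for TrainableClassifier and returns plain score maps instead of a PredictionDistribution.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Minimal hypothetical base: only the two abstract hooks a subclass must supply.
abstract class TrainableBase {
    protected abstract Map<String, Double> doClassify(Map<String, Double> features,
            Set<String> candidateClasses);

    protected abstract void doTrain(Map<String, Double> features, String targetClass);
}

// Toy subclass: one additive weight vector per class.
final class AdditiveClassifier extends TrainableBase {
    private final Map<String, Map<String, Double>> weights = new HashMap<>();

    @Override
    protected Map<String, Double> doClassify(Map<String, Double> features,
            Set<String> candidateClasses) {
        Map<String, Double> scores = new HashMap<>();
        for (String cls : candidateClasses) {
            Map<String, Double> w = weights.getOrDefault(cls, Collections.emptyMap());
            double score = 0.0;
            // Dot product of the feature vector with this class's weights.
            for (Map.Entry<String, Double> f : features.entrySet()) {
                score += f.getValue() * w.getOrDefault(f.getKey(), 0.0);
            }
            scores.put(cls, score);
        }
        return scores;
    }

    @Override
    protected void doTrain(Map<String, Double> features, String targetClass) {
        Map<String, Double> w = weights.computeIfAbsent(targetClass, k -> new HashMap<>());
        // Add each feature value to the target class's weight for that feature.
        for (Map.Entry<String, Double> f : features.entrySet()) {
            w.merge(f.getKey(), f.getValue(), Double::sum);
        }
    }

    // Highest-scoring candidate; ties are resolved arbitrarily.
    static String best(Map<String, Double> scores) {
        return scores.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey();
    }
}
```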

getAllClasses

public Set getAllClasses()
Returns the set of all valid classes. Each target or candidate class must be contained in this set.

Returns:
an immutable set containing all valid class names

toString

public String toString()
Returns a string representation of this object.

Returns:
a textual representation

train

public final void train(FeatureVector features,
                        String targetClass)
                 throws IllegalArgumentException,
                        ProcessingException
Incorporates an item that is represented by a feature vector into the classification model. Delegates to the abstract doTrain(FeatureVector, String) method.

Parameters:
features - the feature vector to consider
targetClass - the class of this feature vector
Throws:
IllegalArgumentException - if the target class is not in the set of valid classes
ProcessingException - if an error occurs during training

trainOnError

public final PredictionDistribution trainOnError(FeatureVector features,
                                                 String targetClass,
                                                 Set candidateClasses)
                                          throws ProcessingException
Handles error-driven learning ("train only errors"): the specified feature vector is trained into the model only if the predicted class for the feature vector differs from the specified target class. If the prediction was correct, the model is not changed.

Parameters:
features - the feature vector to consider
targetClass - the expected class of this feature vector; must be contained in the set of candidateClasses
candidateClasses - a set of classes that are allowed for this item (the actual targetClass must be one of them)
Returns:
the original prediction distribution if the best prediction was wrong, i.e. if training was necessary; or null if no training was necessary (the prediction was already correct)
Throws:
ProcessingException - if an error occurs during training
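The error-driven rule can be sketched as follows, assuming score maps in place of PredictionDistribution and a hypothetical ErrorDrivenClassifier base class. Rejecting a target class that is not among the candidates with an IllegalArgumentException is an assumption of this sketch; the documented contract only states that it must be one of them.

```java
import java.util.Map;
import java.util.Set;

abstract class ErrorDrivenClassifier {
    protected abstract Map<String, Double> doClassify(Map<String, Double> features,
            Set<String> candidateClasses);

    protected abstract void doTrain(Map<String, Double> features, String targetClass);

    // Highest-scoring candidate; ties are resolved arbitrarily.
    static String best(Map<String, Double> scores) {
        return scores.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey();
    }

    public final synchronized Map<String, Double> trainOnError(Map<String, Double> features,
            String targetClass, Set<String> candidateClasses) {
        if (!candidateClasses.contains(targetClass)) {
            throw new IllegalArgumentException(targetClass + " is not a candidate class");
        }
        Map<String, Double> scores = doClassify(features, candidateClasses);
        if (targetClass.equals(best(scores))) {
            return null; // prediction already correct: leave the model unchanged
        }
        doTrain(features, targetClass); // prediction wrong: learn from this example
        return scores; // the prediction made before training (if doClassify returns a fresh map)
    }
}
```

Compared with calling train on every example, this only touches the model when it actually makes a mistake, which is the "train only errors" strategy described above.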

trainOnErrorHook

protected boolean trainOnErrorHook(PredictionDistribution predDist,
                                   FeatureVector features,
                                   String targetClass,
                                   Set candidateClasses)
                            throws ProcessingException
Subclasses can override this hook for more refined error-driven learning. It is called from the trainOnError(FeatureVector, String, Set) method after classifying. An overriding method can do any necessary training itself and return true to signal that no further action is necessary. This implementation is just a placeholder that always returns false.

Parameters:
predDist - the prediction distribution returned by classify(FeatureVector, Set)
features - the feature vector to consider
targetClass - the expected class of this feature vector; must be contained in the set of candidateClasses
candidateClasses - a set of classes that are allowed for this item (the actual targetClass must be one of them)
Returns:
this implementation always returns false; subclasses can return true to signal that any error-driven learning was already handled
Throws:
ProcessingException - if an error occurs during training
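One way the hook could be wired into trainOnError, together with one possible refinement a subclass might implement, is sketched below. HookedClassifier and MarginHookClassifier are hypothetical: the margin rule (also train when the correct class wins by less than minMargin) is a common refinement of error-driven learning chosen for illustration, not necessarily what any TIES subclass does. The sketch assumes trainOnError still returns null whenever the best prediction was correct, even if the hook decided to train.

```java
import java.util.Map;
import java.util.Set;

// Hypothetical base: score maps stand in for PredictionDistribution.
abstract class HookedClassifier {
    protected abstract Map<String, Double> doClassify(Map<String, Double> features,
            Set<String> candidateClasses);

    protected abstract void doTrain(Map<String, Double> features, String targetClass);

    // Highest-scoring candidate; ties are resolved arbitrarily.
    static String best(Map<String, Double> scores) {
        return scores.entrySet().stream().max(Map.Entry.comparingByValue()).get().getKey();
    }

    // Default hook: a placeholder that never claims the case.
    protected boolean trainOnErrorHook(Map<String, Double> scores, Map<String, Double> features,
            String targetClass, Set<String> candidateClasses) {
        return false;
    }

    public final synchronized Map<String, Double> trainOnError(Map<String, Double> features,
            String targetClass, Set<String> candidateClasses) {
        Map<String, Double> scores = doClassify(features, candidateClasses);
        boolean wrong = !targetClass.equals(best(scores));
        // The hook runs right after classifying; true means "already handled".
        boolean handled = trainOnErrorHook(scores, features, targetClass, candidateClasses);
        if (!handled && wrong) {
            doTrain(features, targetClass);
        }
        return wrong ? scores : null;
    }
}

// Hypothetical refinement: also reinforce correct predictions won by a thin margin.
abstract class MarginHookClassifier extends HookedClassifier {
    private final double minMargin;

    MarginHookClassifier(double minMargin) {
        this.minMargin = minMargin;
    }

    @Override
    protected boolean trainOnErrorHook(Map<String, Double> scores, Map<String, Double> features,
            String targetClass, Set<String> candidateClasses) {
        if (!targetClass.equals(best(scores))) {
            return false; // a real error: let trainOnError do the usual training
        }
        double runnerUp = Double.NEGATIVE_INFINITY;
        for (Map.Entry<String, Double> e : scores.entrySet()) {
            if (!e.getKey().equals(targetClass)) {
                runnerUp = Math.max(runnerUp, e.getValue());
            }
        }
        if (scores.get(targetClass) - runnerUp < minMargin) {
            doTrain(features, targetClass); // correct, but not by a wide enough margin
        }
        return true; // handled: no further action needed
    }
}
```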


Copyright © 2003-2004 Christian Siefkes. All Rights Reserved.