de.fu_berlin.ties.classify
Class TrainableClassifier

java.lang.Object
  extended by de.fu_berlin.ties.classify.TrainableClassifier
All Implemented Interfaces:
Classifier, XMLStorable
Direct Known Subclasses:
ExternalClassifier, MetaClassifier, MoonClassifier, MultiBinaryClassifier, OneAgainstTheRestClassifier, TieClassifier, Winnow

public abstract class TrainableClassifier
extends Object
implements Classifier, XMLStorable

Classifiers extending this abstract class must provide a training mechanism by implementing the doTrain(FeatureVector, String, ContextMap) method. This class supports error-driven learning ("train only errors"), which often leads to better prediction models than brute-force training on every instance.

The code in this class is thread-safe.
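The split between final public entry points (train, trainOnError) and the abstract doClassify/doTrain methods follows the template-method pattern. The following self-contained sketch illustrates that flow under simplifying assumptions and is not the TIES source: `List<String>` stands in for FeatureVector, a `HashMap` for ContextMap, and the `Prediction`, `SketchTrainableClassifier` and `MemorizingClassifier` classes are hypothetical. Thread safety is not modelled.

```java
import java.util.*;

// Hypothetical stand-in for PredictionDistribution: only the best class and its probability.
class Prediction {
    final String bestClass;
    final double probability;

    Prediction(String bestClass, double probability) {
        this.bestClass = bestClass;
        this.probability = probability;
    }
}

// Simplified sketch of the template-method structure; not the actual TIES implementation.
abstract class SketchTrainableClassifier {
    private final Set<String> allClasses;

    SketchTrainableClassifier(Set<String> allValidClasses) {
        allClasses = Collections.unmodifiableSet(new HashSet<>(allValidClasses));
    }

    protected abstract Prediction doClassify(List<String> features, Set<String> candidates,
                                             Map<String, Object> context);

    protected abstract void doTrain(List<String> features, String targetClass,
                                    Map<String, Object> context);

    // "Train only errors": wrong best class or non-positive probability.
    protected boolean shouldTrain(String targetClass, Prediction pred, Map<String, Object> context) {
        return !targetClass.equals(pred.bestClass) || pred.probability <= 0.0;
    }

    // Brute-force training: always updates the model after validating the class.
    public final void train(List<String> features, String targetClass) {
        if (!allClasses.contains(targetClass)) {
            throw new IllegalArgumentException("Unknown class: " + targetClass);
        }
        doTrain(features, targetClass, new HashMap<>());
    }

    // Error-driven training: classify first, train only if shouldTrain agrees.
    // Returns the original prediction if training happened, null otherwise.
    public final Prediction trainOnError(List<String> features, String targetClass,
                                         Set<String> candidates) {
        Map<String, Object> context = new HashMap<>();
        Prediction pred = doClassify(features, candidates, context);
        if (shouldTrain(targetClass, pred, context)) {
            doTrain(features, targetClass, context);
            return pred;
        }
        return null;
    }
}

// Hypothetical toy subclass: remembers the class last trained for each exact feature list.
class MemorizingClassifier extends SketchTrainableClassifier {
    private final Map<List<String>, String> memory = new HashMap<>();

    MemorizingClassifier(Set<String> allValidClasses) {
        super(allValidClasses);
    }

    @Override
    protected Prediction doClassify(List<String> features, Set<String> candidates,
                                    Map<String, Object> context) {
        String known = memory.get(features);
        if (known != null && candidates.contains(known)) {
            return new Prediction(known, 1.0);
        }
        // Unseen item: guess the first candidate (requires a non-empty set) with zero probability.
        return new Prediction(candidates.iterator().next(), 0.0);
    }

    @Override
    protected void doTrain(List<String> features, String targetClass, Map<String, Object> context) {
        memory.put(new ArrayList<>(features), targetClass);
    }
}
```

With this toy model, the first trainOnError call for an unseen item trains it (its prediction has zero probability) and returns that prediction; repeating the call returns null because the prediction is now correct.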

Version:
$Revision: 1.50 $, $Date: 2006/11/26 21:14:58 $, $Author: siefkes $
Author:
Christian Siefkes

Field Summary
(package private) static QName ATTRIB_CLASSES
          Attribute name used for XML serialization.
(package private) static QName ATTRIB_TRAIN_ALL
          Attribute name used for XML serialization.
static QName ELEMENT_MAIN
          Name of the main element used for XML serialization.
static String META_CLASSIFIER
          Flag used to load the MetaClassifier.
static String MULTI_CLASSIFIER
          Flag used to load the MultiBinaryClassifier.
static String OAR_CLASSIFIER
          Flag used to load the OneAgainstTheRestClassifier.
static String TIE_CLASSIFIER
          Flag used to load the TieClassifier.
 
Fields inherited from interface de.fu_berlin.ties.classify.Classifier
CONFIG_CLASSIFIER
 
Constructor Summary
TrainableClassifier(Element element)
          Creates a new instance from an XML element, fulfilling the recommendation of the XMLStorable interface.
TrainableClassifier(Set<String> allValidClasses, FeatureTransformer trans, boolean trainAll, TiesConfiguration conf)
          Creates a new instance.
TrainableClassifier(Set<String> allValidClasses, FeatureTransformer trans, TiesConfiguration conf)
          Creates a new instance.
 
Method Summary
 PredictionDistribution classify(FeatureVector features, Set candidateClasses)
          Classifies an item that is represented by a feature vector by choosing the most probable class among a set of candidate classes.
static TrainableClassifier createClassifier(Set<String> allValidClasses)
          Factory method that delegates to createClassifier(Set, TiesConfiguration) using the standard configuration.
static TrainableClassifier createClassifier(Set<String> allValidClasses, File runDirectory, FeatureTransformer trans, String[] spec, TiesConfiguration conf)
          Factory method that creates a trainable classifier based on the provided specification.
static TrainableClassifier createClassifier(Set<String> allValidClasses, File runDirectory, TiesConfiguration conf, String suffix)
          Factory method that delegates to createClassifier(Set, File, FeatureTransformer, String[], TiesConfiguration).
static TrainableClassifier createClassifier(Set<String> allValidClasses, TiesConfiguration conf)
          Factory method that delegates to createClassifier(Set, TiesConfiguration, String) without specifying a suffix.
static TrainableClassifier createClassifier(Set<String> allValidClasses, TiesConfiguration conf, String suffix)
          Factory method that delegates to createClassifier(Set, File, TiesConfiguration, String) without specifying a run directory.
 void destroy()
          Destroys the classifier.
protected abstract  PredictionDistribution doClassify(FeatureVector features, Set candidateClasses, ContextMap context)
          Classifies an item that is represented by a feature vector by choosing the most probable class among a set of candidate classes.
protected abstract  void doTrain(FeatureVector features, String targetClass, ContextMap context)
          Incorporates an item that is represented by a feature vector into the classification model.
protected  boolean doTrainOnError(PredictionDistribution predDist, FeatureVector features, String targetClass, Set candidateClasses, ContextMap context)
          The core of the trainOnError(FeatureVector, String, Set) method.
 Set<String> getAllClasses()
          Returns the set of all valid classes.
 TiesConfiguration getConfig()
          Returns the configuration used by this instance.
abstract  void reset()
          Resets the classifier, completely deleting the prediction model.
protected  boolean shouldTrain(String targetClass, PredictionDistribution predDist, ContextMap context)
          Invoked by trainOnError(FeatureVector, String, Set) to decide whether to train an instance.
 ObjectElement toElement()
          Stores all relevant fields of this object in an XML element for serialization. Subclasses of TrainableClassifier should extend this method and the corresponding constructor from Element to ensure (de)serialization works as expected.
 String toString()
          Returns a string representation of this object.
 void train(FeatureVector features, String targetClass)
          Incorporates an item that is represented by a feature vector into the classification model.
 PredictionDistribution trainOnError(FeatureVector features, String targetClass, Set candidateClasses)
          Handles error-driven learning ("train only errors"): the specified feature vector is trained into the model only if the predicted class for the feature vector differs from the specified target class.
protected  boolean trainOnErrorHook(PredictionDistribution predDist, FeatureVector features, String targetClass, Set candidateClasses, ContextMap context)
          Subclasses can implement this hook for more refined error-driven learning.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
 

Field Detail

ELEMENT_MAIN

public static final QName ELEMENT_MAIN
Name of the main element used for XML serialization.


ATTRIB_CLASSES

static final QName ATTRIB_CLASSES
Attribute name used for XML serialization.


ATTRIB_TRAIN_ALL

static final QName ATTRIB_TRAIN_ALL
Attribute name used for XML serialization.


META_CLASSIFIER

public static final String META_CLASSIFIER
Flag used to load the MetaClassifier.

See Also:
Constant Field Values

MULTI_CLASSIFIER

public static final String MULTI_CLASSIFIER
Flag used to load the MultiBinaryClassifier.

See Also:
Constant Field Values

OAR_CLASSIFIER

public static final String OAR_CLASSIFIER
Flag used to load the OneAgainstTheRestClassifier.

See Also:
Constant Field Values

TIE_CLASSIFIER

public static final String TIE_CLASSIFIER
Flag used to load the TieClassifier.

See Also:
Constant Field Values
Constructor Detail

TrainableClassifier

public TrainableClassifier(Element element)
                    throws InstantiationException
Creates a new instance from an XML element, fulfilling the recommendation of the XMLStorable interface.

Parameters:
element - the XML element containing the serialized representation
Throws:
InstantiationException - if the given element does not contain a valid classifier description

TrainableClassifier

public TrainableClassifier(Set<String> allValidClasses,
                           FeatureTransformer trans,
                           TiesConfiguration conf)
Creates a new instance.

Parameters:
allValidClasses - the set of all valid classes
trans - the last transformer in the transformer chain to use, or null if no feature transformers should be used
conf - used to configure this instance

TrainableClassifier

public TrainableClassifier(Set<String> allValidClasses,
                           FeatureTransformer trans,
                           boolean trainAll,
                           TiesConfiguration conf)
Creates a new instance.

Parameters:
allValidClasses - the set of all valid classes; all class names must be printable names
trans - the last transformer in the transformer chain to use, or null if no feature transformers should be used
trainAll - set to true iff the classifier should consider all classes for error-driven training, not only the candidate classes (results are filtered to the candidate classes prior to returning them)
conf - used to configure this instance
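The effect of trainAll on what the classifier scores and what it returns can be sketched as a selection step followed by a filtering step. This assumes scores are kept in a plain `Map<String, Double>` (a hypothetical stand-in for PredictionDistribution) and that the best class is the one with the highest score; `TrainAllSketch` and its methods are illustrative names only.

```java
import java.util.*;

final class TrainAllSketch {
    private TrainAllSketch() {}

    // Classes to score: all valid classes if trainAll is set, otherwise only the candidates.
    static Set<String> classesToScore(boolean trainAll, Set<String> allValidClasses,
                                      Set<String> candidates) {
        return trainAll ? allValidClasses : candidates;
    }

    // Restricts the scored classes to the candidates before the result is returned.
    static Map<String, Double> filterToCandidates(Map<String, Double> scores,
                                                  Set<String> candidates) {
        Map<String, Double> filtered = new LinkedHashMap<>();
        for (Map.Entry<String, Double> e : scores.entrySet()) {
            if (candidates.contains(e.getKey())) {
                filtered.put(e.getKey(), e.getValue());
            }
        }
        return filtered;
    }

    // Best class among the given scores (highest score wins).
    static String best(Map<String, Double> scores) {
        return Collections.max(scores.entrySet(), Map.Entry.comparingByValue()).getKey();
    }
}
```

For example, if "person" scores 0.6, "location" 0.3 and "date" 0.1, but only "location" and "date" are candidates, the best class of the filtered result is "location".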
Method Detail

createClassifier

public static TrainableClassifier createClassifier(Set<String> allValidClasses)
                                            throws IllegalArgumentException,
                                                   ProcessingException
Factory method that delegates to createClassifier(Set, TiesConfiguration) using the standard configuration.

Parameters:
allValidClasses - the set of all valid classes
Returns:
the created classifier
Throws:
IllegalArgumentException - if the value of the Classifier.CONFIG_CLASSIFIER key is missing or invalid
ProcessingException - if an error occurred while creating the classifier

createClassifier

public static TrainableClassifier createClassifier(Set<String> allValidClasses,
                                                   TiesConfiguration conf)
                                            throws IllegalArgumentException,
                                                   ProcessingException
Factory method that delegates to createClassifier(Set, TiesConfiguration, String) without specifying a suffix.

Parameters:
allValidClasses - the set of all valid classes
conf - the configuration to use
Returns:
the created classifier
Throws:
IllegalArgumentException - if the value of the Classifier.CONFIG_CLASSIFIER key is missing or invalid
ProcessingException - if an error occurred while creating the classifier

createClassifier

public static TrainableClassifier createClassifier(Set<String> allValidClasses,
                                                   TiesConfiguration conf,
                                                   String suffix)
                                            throws IllegalArgumentException,
                                                   ProcessingException
Factory method that delegates to createClassifier(Set, File, TiesConfiguration, String) without specifying a run directory.

Parameters:
allValidClasses - the set of all valid classes
conf - the configuration to use
suffix - an optional suffix that is appended to the Classifier.CONFIG_CLASSIFIER key if not null
Returns:
the created classifier
Throws:
IllegalArgumentException - if the value of the Classifier.CONFIG_CLASSIFIER key is missing or invalid
ProcessingException - if an error occurred while creating the classifier

createClassifier

public static TrainableClassifier createClassifier(Set<String> allValidClasses,
                                                   File runDirectory,
                                                   TiesConfiguration conf,
                                                   String suffix)
                                            throws IllegalArgumentException,
                                                   ProcessingException
Factory method that delegates to createClassifier(Set, File, FeatureTransformer, String[], TiesConfiguration). It reads the specification of the classifier from the Classifier.CONFIG_CLASSIFIER key in the provided configuration. It calls FeatureTransformer.createTransformer(TiesConfiguration) to initialize a transformer chain, if configured.

Parameters:
allValidClasses - the set of all valid classes
runDirectory - the directory to run the classifier in; used for ExternalClassifier instead of the configured directory if not null; ignored otherwise
conf - the configuration to use
suffix - an optional suffix that is appended to the Classifier.CONFIG_CLASSIFIER key if not null
Returns:
the created classifier
Throws:
IllegalArgumentException - if the value of the Classifier.CONFIG_CLASSIFIER key is missing or invalid
ProcessingException - if an error occurred while creating the classifier

createClassifier

public static TrainableClassifier createClassifier(Set<String> allValidClasses,
                                                   File runDirectory,
                                                   FeatureTransformer trans,
                                                   String[] spec,
                                                   TiesConfiguration conf)
                                            throws IllegalArgumentException,
                                                   ProcessingException
Factory method that creates a trainable classifier based on the provided specification.

Currently supported values in the first element of the specification:

Otherwise, the first element must be the fully qualified name of a TrainableClassifier subclass whose constructor accepts a Set (of all valid class names) as first argument, a FeatureTransformer as second argument, and a TiesConfiguration as third argument.

Parameters:
allValidClasses - the set of all valid classes
runDirectory - the directory to run the classifier in; used for ExternalClassifier instead of the configured directory if not null; ignored otherwise
trans - the last transformer in the transformer chain to use, or null if no feature transformers should be used
spec - the specification used to initialize the classifier, as described above
conf - passed to the created classifier to configure itself
Returns:
the created classifier
Throws:
IllegalArgumentException - if the value of the Classifier.CONFIG_CLASSIFIER key is missing or invalid
ProcessingException - if an error occurred while creating the classifier
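When the first element of the specification is a class name, the factory has to instantiate that class through the (Set, FeatureTransformer, TiesConfiguration) constructor described above. The standard Java technique for this is reflection. The sketch below shows it with hypothetical stand-in types (`Transformer`, `Config`, `BaseClassifier`, `DemoClassifier`) instead of the real TIES classes; mapping every reflection failure to IllegalArgumentException is an assumption of this sketch.

```java
import java.lang.reflect.Constructor;
import java.util.*;

// Hypothetical stand-ins for FeatureTransformer, TiesConfiguration and TrainableClassifier.
class Transformer {}
class Config {}

abstract class BaseClassifier {
    final Set<String> classes;

    BaseClassifier(Set<String> classes) {
        this.classes = classes;
    }
}

// Hypothetical subclass with the constructor signature the factory expects.
class DemoClassifier extends BaseClassifier {
    public DemoClassifier(Set<String> allValidClasses, Transformer trans, Config conf) {
        super(allValidClasses);
    }
}

final class ReflectiveFactory {
    private ReflectiveFactory() {}

    static BaseClassifier create(String qualifiedName, Set<String> allValidClasses,
                                 Transformer trans, Config conf) {
        try {
            // Load the class and make sure it really is a classifier subclass.
            Class<? extends BaseClassifier> type =
                Class.forName(qualifiedName).asSubclass(BaseClassifier.class);
            // Look up the public (Set, Transformer, Config) constructor and invoke it.
            Constructor<? extends BaseClassifier> ctor =
                type.getConstructor(Set.class, Transformer.class, Config.class);
            return ctor.newInstance(allValidClasses, trans, conf);
        } catch (ReflectiveOperationException | ClassCastException e) {
            // Assumption: unknown classes, wrong types or missing constructors count as an invalid spec.
            throw new IllegalArgumentException("Cannot create classifier " + qualifiedName, e);
        }
    }
}
```

Passing a name such as "java.lang.String" fails the asSubclass check and is reported as an invalid specification rather than producing an unusable object.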

classify

public final PredictionDistribution classify(FeatureVector features,
                                             Set candidateClasses)
                                      throws IllegalArgumentException,
                                             ProcessingException
Classifies an item that is represented by a feature vector by choosing the most probable class among a set of candidate classes. Delegates to the abstract doClassify(FeatureVector, Set, ContextMap) method.

Specified by:
classify in interface Classifier
Parameters:
features - the feature vector to consider
candidateClasses - a set of classes that are allowed for this item
Returns:
the result of the classification; you can call PredictionDistribution.best() to get the most probable class
Throws:
IllegalArgumentException - if the set of valid classes does not contain all candidate classes
ProcessingException - if an error occurs during classification
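Since classify is final and delegates to doClassify, the documented precondition on candidate classes is checked once for all subclasses. A minimal sketch of such a check, with the hypothetical helper `CandidateCheck.requireValidCandidates`, reporting the offending class names in the exception message:

```java
import java.util.*;

final class CandidateCheck {
    private CandidateCheck() {}

    // Throws if any candidate class is missing from the set of valid classes.
    static void requireValidCandidates(Set<String> allValidClasses, Set<String> candidateClasses) {
        if (!allValidClasses.containsAll(candidateClasses)) {
            // Collect the unknown names in sorted order for a readable message.
            Set<String> unknown = new TreeSet<>(candidateClasses);
            unknown.removeAll(allValidClasses);
            throw new IllegalArgumentException("Unknown candidate classes: " + unknown);
        }
    }
}
```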

destroy

public void destroy()
             throws ProcessingException
Destroys the classifier. This method must be called only if the classifier will never be used again. The default implementation delegates to reset(), but subclasses can override this behaviour if appropriate.

Specified by:
destroy in interface Classifier
Throws:
ProcessingException - if an error occurs while the classifier is being destroyed

doClassify

protected abstract PredictionDistribution doClassify(FeatureVector features,
                                                     Set candidateClasses,
                                                     ContextMap context)
                                              throws ProcessingException
Classifies an item that is represented by a feature vector by choosing the most probable class among a set of candidate classes.

Parameters:
features - the feature vector to consider
candidateClasses - a set of classes that are allowed for this item
context - can be used to transport implementation-specific contextual information between the doClassify(FeatureVector, Set, ContextMap), doTrain(FeatureVector, String, ContextMap), and trainOnErrorHook(PredictionDistribution, FeatureVector, String, Set, ContextMap) methods
Returns:
the result of the classification; you can call PredictionDistribution.best() to get the most probable class
Throws:
ProcessingException - if an error occurs during classification
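The context parameter lets one error-driven training step share intermediate results between doClassify and doTrain instead of recomputing them. A sketch of that pattern, assuming a `HashMap<String, Object>` in place of ContextMap and a hypothetical context key `"normalized"`; the class `ContextSharingClassifier` and its overlap-based scoring are toy placeholders, not TIES code.

```java
import java.util.*;

// Toy classifier (hypothetical) showing how a per-call context map can carry
// intermediate results from doClassify to doTrain.
class ContextSharingClassifier {
    private static final String KEY_NORMALIZED = "normalized"; // hypothetical context key
    private final Map<String, Set<String>> featuresByClass = new HashMap<>();
    int normalizations = 0; // how often the feature preprocessing actually ran

    // Lower-cases and de-duplicates features once per call, caching the result in the context.
    private Set<String> normalized(List<String> features, Map<String, Object> context) {
        @SuppressWarnings("unchecked")
        Set<String> cached = (Set<String>) context.get(KEY_NORMALIZED);
        if (cached != null) {
            return cached;
        }
        normalizations++;
        Set<String> result = new HashSet<>();
        for (String f : features) {
            result.add(f.toLowerCase(Locale.ROOT));
        }
        context.put(KEY_NORMALIZED, result);
        return result;
    }

    // Picks the candidate whose known features overlap most with the item (ties: alphabetical).
    String doClassify(List<String> features, Set<String> candidates, Map<String, Object> context) {
        Set<String> items = normalized(features, context);
        String best = null;
        int bestOverlap = -1;
        for (String cls : new TreeSet<>(candidates)) {
            Set<String> known = featuresByClass.getOrDefault(cls, Collections.emptySet());
            int overlap = 0;
            for (String f : items) {
                if (known.contains(f)) {
                    overlap++;
                }
            }
            if (overlap > bestOverlap) {
                best = cls;
                bestOverlap = overlap;
            }
        }
        return best;
    }

    // Reuses the normalized features computed during doClassify.
    void doTrain(List<String> features, String targetClass, Map<String, Object> context) {
        featuresByClass.computeIfAbsent(targetClass, k -> new HashSet<>())
                       .addAll(normalized(features, context));
    }

    // One error-driven step: a fresh context per item, shared by both calls.
    boolean trainOnError(List<String> features, String targetClass, Set<String> candidates) {
        Map<String, Object> context = new HashMap<>();
        String predicted = doClassify(features, candidates, context);
        if (!targetClass.equals(predicted)) {
            doTrain(features, targetClass, context);
            return true;
        }
        return false;
    }
}
```

Creating a fresh map per item keeps cached values from leaking between unrelated items, which matters because the cache key is not tied to the feature vector itself.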

doTrain

protected abstract void doTrain(FeatureVector features,
                                String targetClass,
                                ContextMap context)
                         throws ProcessingException
Incorporates an item that is represented by a feature vector into the classification model.

Parameters:
features - the feature vector to consider
targetClass - the class of this feature vector
context - can be used to transport implementation-specific contextual information between the doClassify(FeatureVector, Set, ContextMap), doTrain(FeatureVector, String, ContextMap), and trainOnErrorHook(PredictionDistribution, FeatureVector, String, Set, ContextMap) methods
Throws:
ProcessingException - if an error occurs during training

doTrainOnError

protected boolean doTrainOnError(PredictionDistribution predDist,
                                 FeatureVector features,
                                 String targetClass,
                                 Set candidateClasses,
                                 ContextMap context)
                          throws ProcessingException
The core of the trainOnError(FeatureVector, String, Set) method. Generally there is no need for subclasses to modify this method.

Parameters:
predDist - the prediction distribution returned by classify(FeatureVector, Set)
features - the feature vector to consider
targetClass - the expected class of this feature vector; must be contained in the set of candidateClasses
candidateClasses - a set of classes that are allowed for this item (the actual targetClass must be one of them)
context - can be used to transport implementation-specific contextual information between the doClassify(FeatureVector, Set, ContextMap), doTrain(FeatureVector, String, ContextMap), and trainOnErrorHook(PredictionDistribution, FeatureVector, String, Set, ContextMap) methods
Returns:
the result of the shouldTrain(String, PredictionDistribution, ContextMap) method
Throws:
ProcessingException - if an error occurs during training

getAllClasses

public Set<String> getAllClasses()
Returns the set of all valid classes. Each target or candidate class must be contained in this set.

Returns:
an immutable set containing all valid class names
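Returning an immutable set keeps callers from changing the classes the classifier validates against. One way to obtain this behaviour in plain Java, shown as a sketch with the hypothetical class `ClassSetHolder`: copy the constructor argument defensively and wrap the copy with Collections.unmodifiableSet.

```java
import java.util.*;

class ClassSetHolder {
    private final Set<String> allClasses;

    ClassSetHolder(Set<String> allValidClasses) {
        // Defensive copy: later changes to the caller's set do not leak in.
        this.allClasses = Collections.unmodifiableSet(new LinkedHashSet<>(allValidClasses));
    }

    Set<String> getAllClasses() {
        return allClasses; // mutators such as add() throw UnsupportedOperationException
    }
}
```

On Java 10 and later, `Set.copyOf` achieves the same in one call, but it rejects null elements.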

getConfig

public TiesConfiguration getConfig()
Returns the configuration used by this instance.

Returns:
the used configuration

reset

public abstract void reset()
                    throws ProcessingException
Resets the classifier, completely deleting the prediction model.

Throws:
ProcessingException - if an error occurs during reset

shouldTrain

protected boolean shouldTrain(String targetClass,
                              PredictionDistribution predDist,
                              ContextMap context)
Invoked by trainOnError(FeatureVector, String, Set) to decide whether to train an instance. The default behavior is to train if the best prediction was wrong or didn't yield a positive probability ("train only errors"). Subclasses can override this method to add their own behavior, e.g. reinforcement training (thick threshold heuristic).

Parameters:
targetClass - the expected class of this feature vector; must be contained in the set of candidateClasses
predDist - the prediction distribution returned by doClassify(FeatureVector, Set, ContextMap)
context - can be used to transport implementation-specific contextual information between the doClassify(FeatureVector, Set, ContextMap), doTrain(FeatureVector, String, ContextMap), and trainOnErrorHook(PredictionDistribution, FeatureVector, String, Set, ContextMap) methods
Returns:
whether to train this instance
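The default rule and an override in the spirit of reinforcement training can be written side by side. This is a sketch under stated assumptions: the prediction is reduced to its best class and probability (the hypothetical `Best` class stands in for the result of PredictionDistribution.best()), and the override's `minProbability` margin is an illustrative reading of a thick threshold, not the TIES implementation.

```java
// Hypothetical stand-in for the best prediction of a PredictionDistribution.
class Best {
    final String type;
    final double probability;

    Best(String type, double probability) {
        this.type = type;
        this.probability = probability;
    }
}

class DefaultTrainingRule {
    // Default: train if the best prediction is wrong or lacks a positive probability.
    boolean shouldTrain(String targetClass, Best best) {
        return !targetClass.equals(best.type) || best.probability <= 0.0;
    }
}

// Illustrative reinforcement variant: also train correct but low-confidence predictions.
class ThickThresholdRule extends DefaultTrainingRule {
    private final double minProbability; // hypothetical margin parameter

    ThickThresholdRule(double minProbability) {
        this.minProbability = minProbability;
    }

    @Override
    boolean shouldTrain(String targetClass, Best best) {
        return super.shouldTrain(targetClass, best) || best.probability < minProbability;
    }
}
```

Calling super.shouldTrain first keeps the default "train only errors" behaviour intact, so the override can only add training, never suppress it.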

toElement

public ObjectElement toElement()
Stores all relevant fields of this object in an XML element for serialization. An equivalent object can be created by calling ObjectElement.createObject(org.dom4j.Element, Class) on the created element. Subclasses of TrainableClassifier should extend this method and the corresponding constructor from Element to ensure (de)serialization works as expected.

Specified by:
toElement in interface XMLStorable
Returns:
the created XML element
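toElement and the constructor from Element form a round trip: whatever one writes, the other must read back, and it must reject a description that is incomplete. A minimal sketch of that contract using the JDK's org.w3c.dom API rather than the dom4j types TIES works with; the element name `classifier`, the attribute names `classes` and `train-all`, and the class `SerializableClassifierState` are hypothetical, since the actual values of ELEMENT_MAIN, ATTRIB_CLASSES and ATTRIB_TRAIN_ALL are not given here.

```java
import java.util.*;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

class SerializableClassifierState {
    final Set<String> classes;
    final boolean trainAll;

    SerializableClassifierState(Set<String> classes, boolean trainAll) {
        this.classes = new TreeSet<>(classes);
        this.trainAll = trainAll;
    }

    // Counterpart of the constructor from Element: rejects incomplete descriptions.
    SerializableClassifierState(Element element) throws InstantiationException {
        String classList = element.getAttribute("classes");      // "" if absent
        String trainAllValue = element.getAttribute("train-all"); // "" if absent
        if (classList.isEmpty() || trainAllValue.isEmpty()) {
            throw new InstantiationException("Missing classifier attributes");
        }
        this.classes = new TreeSet<>(Arrays.asList(classList.split(" ")));
        this.trainAll = Boolean.parseBoolean(trainAllValue);
    }

    // Counterpart of toElement(); joining with spaces assumes class names contain no spaces.
    Element toElement(Document doc) {
        Element element = doc.createElement("classifier");
        element.setAttribute("classes", String.join(" ", classes));
        element.setAttribute("train-all", Boolean.toString(trainAll));
        return element;
    }

    static Document newDocument() {
        try {
            return DocumentBuilderFactory.newInstance().newDocumentBuilder().newDocument();
        } catch (ParserConfigurationException e) {
            throw new IllegalStateException("No default XML parser available", e);
        }
    }
}
```

A subclass following the recommendation above would call the superclass versions of both methods and add its own attributes alongside these.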

toString

public String toString()
Returns a string representation of this object.

Overrides:
toString in class Object
Returns:
a textual representation

train

public final void train(FeatureVector features,
                        String targetClass)
                 throws IllegalArgumentException,
                        ProcessingException
Incorporates an item that is represented by a feature vector into the classification model. Delegates to the abstract doTrain(FeatureVector, String, ContextMap) method.

Parameters:
features - the feature vector to consider
targetClass - the class of this feature vector
Throws:
IllegalArgumentException - if the target class is not in the set of valid classes
ProcessingException - if an error occurs during training

trainOnError

public final PredictionDistribution trainOnError(FeatureVector features,
                                                 String targetClass,
                                                 Set candidateClasses)
                                          throws ProcessingException
Handles error-driven learning ("train only errors"): the specified feature vector is trained into the model only if the predicted class for the feature vector differs from the specified target class. If the prediction was correct, the model is not changed.

Parameters:
features - the feature vector to consider
targetClass - the expected class of this feature vector; must be contained in the set of candidateClasses
candidateClasses - a set of classes that are allowed for this item (the actual targetClass must be one of them)
Returns:
the original prediction distribution if the best prediction was wrong, i.e. if training was necessary; or null if no training was necessary (the prediction was already correct)
Throws:
ProcessingException - if an error occurs during training
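Because trainOnError returns null exactly when no training was needed, a caller can count errors per pass and stop once a pass over the training data is error-free. A sketch of such a loop against a hypothetical one-method interface `ErrorDrivenLearner` that mirrors the trainOnError signature with simplified types; `TrainingLoop` is likewise hypothetical, and the pass limit is an assumption added so the loop terminates on data the model cannot fit.

```java
import java.util.*;

// Hypothetical interface mirroring trainOnError with simplified types.
interface ErrorDrivenLearner {
    /** Returns the original prediction if training was needed, or null if it was already correct. */
    Object trainOnError(List<String> features, String targetClass, Set<String> candidates);
}

final class TrainingLoop {
    private TrainingLoop() {}

    // Repeats passes until one pass makes no errors or maxPasses is reached; returns passes used.
    static int trainUntilNoErrors(ErrorDrivenLearner learner, List<List<String>> items,
                                  List<String> targets, Set<String> candidates, int maxPasses) {
        for (int pass = 1; pass <= maxPasses; pass++) {
            int errors = 0;
            for (int i = 0; i < items.size(); i++) {
                // A non-null result means the model was wrong and has just been trained.
                if (learner.trainOnError(items.get(i), targets.get(i), candidates) != null) {
                    errors++;
                }
            }
            if (errors == 0) {
                return pass;
            }
        }
        return maxPasses;
    }
}
```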

trainOnErrorHook

protected boolean trainOnErrorHook(PredictionDistribution predDist,
                                   FeatureVector features,
                                   String targetClass,
                                   Set candidateClasses,
                                   ContextMap context)
                            throws ProcessingException
Subclasses can implement this hook for more refined error-driven learning. It is called from the trainOnError(FeatureVector, String, Set) method after classifying. This method can do any necessary training itself and return true to signal that no further action is necessary. This implementation is just a placeholder that always returns false.

Parameters:
predDist - the prediction distribution returned by classify(FeatureVector, Set)
features - the feature vector to consider
targetClass - the expected class of this feature vector; must be contained in the set of candidateClasses
candidateClasses - a set of classes that are allowed for this item (the actual targetClass must be one of them)
context - can be used to transport implementation-specific contextual information between the doClassify(FeatureVector, Set, ContextMap), doTrain(FeatureVector, String, ContextMap), and trainOnErrorHook(PredictionDistribution, FeatureVector, String, Set, ContextMap) methods
Returns:
this implementation always returns false; subclasses can return true to signal that any error-driven learning was already handled
Throws:
ProcessingException - if an error occurs during training
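A subclass that wants to handle training itself can override the hook, perform its own update, and return true so that the default "train only errors" step is skipped. The sketch below uses a stripped-down, hypothetical base class `HookedClassifier` with simplified types; what TIES returns from trainOnError when the hook returns true is not stated above, so returning the prediction in that case is an assumption. The hypothetical `AlwaysTrainClassifier` uses the hook to turn error-driven learning back into brute-force training.

```java
import java.util.*;

// Stripped-down stand-in for TrainableClassifier, reduced to the hook mechanics.
abstract class HookedClassifier {
    protected abstract String doClassify(List<String> features, Set<String> candidates);

    protected abstract void doTrain(List<String> features, String targetClass);

    // Placeholder, like the documented default: never handles training itself.
    protected boolean trainOnErrorHook(String predicted, List<String> features,
                                       String targetClass, Set<String> candidates) {
        return false;
    }

    // Returns the prediction if training happened, null otherwise.
    public final String trainOnError(List<String> features, String targetClass,
                                     Set<String> candidates) {
        String predicted = doClassify(features, candidates);
        if (trainOnErrorHook(predicted, features, targetClass, candidates)) {
            return predicted; // assumption: a handled hook counts as "trained"
        }
        if (!targetClass.equals(predicted)) {
            doTrain(features, targetClass);
            return predicted;
        }
        return null;
    }
}

// Hypothetical subclass: the hook trains on every item, i.e. brute-force training.
class AlwaysTrainClassifier extends HookedClassifier {
    int updates = 0;

    @Override
    protected String doClassify(List<String> features, Set<String> candidates) {
        return candidates.iterator().next(); // toy: a fixed guess
    }

    @Override
    protected void doTrain(List<String> features, String targetClass) {
        updates++;
    }

    @Override
    protected boolean trainOnErrorHook(String predicted, List<String> features,
                                       String targetClass, Set<String> candidates) {
        doTrain(features, targetClass); // train whether or not the guess was right
        return true;                    // signal that error-driven learning was handled
    }
}
```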


Copyright © 2003-2007 Christian Siefkes. All Rights Reserved.