View Javadoc

1   /*
2    * Copyright (C) 2004 Christian Siefkes <christian@siefkes.net>.
3    * Development of this software is supported by the German Research Society,
4    * Berlin-Brandenburg Graduate School in Distributed Information Systems
5    * (DFG grant no. GRK 316).
6    *
7    * This library is free software; you can redistribute it and/or
8    * modify it under the terms of the GNU Lesser General Public
9    * License as published by the Free Software Foundation; either
10   * version 2.1 of the License, or (at your option) any later version.
11   *
12   * This library is distributed in the hope that it will be useful,
13   * but WITHOUT ANY WARRANTY; without even the implied warranty of
14   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15   * Lesser General Public License for more details.
16   *
17   * You should have received a copy of the GNU Lesser General Public
18   * License along with this library; if not, visit
19   * http://www.gnu.org/licenses/lgpl.html or write to the Free Software
20   * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
21   */
22  package de.fu_berlin.ties.classify;
23  
24  import java.io.File;
25  import java.util.ArrayList;
26  import java.util.HashMap;
27  import java.util.Iterator;
28  import java.util.List;
29  import java.util.Map;
30  import java.util.Set;
31  import java.util.SortedSet;
32  import java.util.TreeSet;
33  
34  import org.apache.commons.lang.builder.ToStringBuilder;
35  
36  import de.fu_berlin.ties.ContextMap;
37  import de.fu_berlin.ties.ProcessingException;
38  import de.fu_berlin.ties.TiesConfiguration;
39  import de.fu_berlin.ties.classify.feature.FeatureTransformer;
40  import de.fu_berlin.ties.classify.feature.FeatureVector;
41  
42  /***
43   * This classifier converts an multi-class classification task into a several
44   * binary (two-class) classification task. It wraps several instances of
45   * another classifier that are used to perform the binary classifications and
46   * combines their results.
47   *
48   * <p>When used with <em>N</em> classes, this classifier created and wrapped
49   * <em>N</em> inner binary classifiers. Each inner classifier classifiers
50   * between members (positive instances) and non-members (negative instances) of
51   * a specified class. The inner classifier returning the highest
52   * {@linkplain de.fu_berlin.ties.classify.Prediction#getProbability()
53   * probability} for its (positive) class wins, even if the negative class of
54   * that classifier would be still higher. This is the classical
55   * "one-against-the-rest" (or "one-against-all") setup.
56   *
57   * <p>This classifier differs from
58   * {@link de.fu_berlin.ties.classify.MultiBinaryClassifier} in not using a
59   * "background class" and using <em>N</em> inner classifiers, while
60   * <code>MultiBinaryClassifier</code> uses <em>N-1</em>.
61   *
62   * <p>Instances of this class are thread-safe if and only if instances of the
63   * wrapped classifier are.
64   *
65   * <p><em><strong>WARNING:</strong> The current implementation does
66   * <strong>not</strong> query the {@link
67   * de.fu_berlin.ties.classify.TrainableClassifier#shouldTrain(String,
68   * PredictionDistribution, ContextMap) shouldTrain} method of inner classifiers.
69   * Because of this, classifiers overwriting <code>shouldTrain</code> might not
70   * work correctly within this classifier.</em>
71   *
72   * @author Christian Siefkes
73   * @version $Revision: 1.2 $, $Date: 2004/11/19 14:04:19 $, $Author: siefkes $
74   */
75  public class OneAgainstTheRestClassifier extends TrainableClassifier {
76  
77      /***
78       * Key prefix used to store the contexts of inner classifiers.
79       */
80      private static final String PREFIX_CONTEXT = "context-";
81  
82      /***
83       * Key prefix used to store the prediction distributions returned by inner
84       * classifiers.
85       */
86      private static final String PREFIX_DIST = "dist-";
87  
88      /***
89       * The suffix appended to negative classes.
90       */
91      private static final String NEG_SUFFIX = "-neg";
92  
93      /***
94       * Maps from the names of the classes to the binary classifiers
95       * used to decide between positive and negative instances of this class.
96       */
97      private final Map<String, TrainableClassifier> innerClassifiers =
98          new HashMap<String, TrainableClassifier>();
99  
100     /***
101      * Creates a new instance.
102      *
103      * @param allValidClasses the set of all valid classes
104      * @param trans the last transformer in the transformer chain to use, or
105      * <code>null</code> if no feature transformers should be used
106      * @param runDirectory optional run directory passed to inner classifiers
107      * of the {@link ExternalClassifier} type
108      * @param innerSpec the specification used to initialize the inner
109      * classifiers, passed to the
110      * {@link TrainableClassifier#createClassifier(Set, File,
111      * FeatureTransformer, String[], TiesConfiguration)} factory method
112      * @param conf used to configure this instance and the inner classifiers
113      * @throws IllegalArgumentException if there are fewer than three classes
114      * (in which case you should use the inner classifier directly since there
115      * is no need to wrap several instances)
116      * @throws ProcessingException if an error occurred while creating this
117      * classifier or one of the wrapped classifiers
118      */
119     public OneAgainstTheRestClassifier(final Set<String> allValidClasses,
120                                  final FeatureTransformer trans,
121                                  final File runDirectory,
122                                  final String[] innerSpec,
123                                  final TiesConfiguration conf)
124     throws IllegalArgumentException, ProcessingException {
125         super(allValidClasses, trans, conf);
126 
127         if (allValidClasses.size() < 3) {
128             throw new IllegalArgumentException("OneAgainstTheRestClassifier "
129                     + "requires at least 3 classes instead of "
130                     + allValidClasses.size());
131         }
132 
133         // create inner classifiers via factory method
134         final Iterator<String> classIter= allValidClasses.iterator();
135         String baseClass;
136         TrainableClassifier innerClassifier;
137         Set<String> innerSet;
138 
139         while (classIter.hasNext()) {
140             // each inner classifier decides between a positive class and a
141             // negative class
142             baseClass = classIter.next();
143             innerSet = createBinarySet(baseClass);
144 
145             // transformer is set to null because features shouldn't be
146             // transformed twice
147             innerClassifier = TrainableClassifier.createClassifier(innerSet,
148                 runDirectory, null, innerSpec, conf);
149             innerClassifiers.put(baseClass, innerClassifier);
150         }
151     }
152 
153     /***
154      * Helper method that creates a set containing the two classes of a
155      * binary classifier.
156      *
157      * @param baseClass the base class to use
158      * @return a set containing a positive and a negative class derived from the
159      * base class; this implementation returns the two classes in alphabetic
160      * order which means that the negative class comes first
161      */
162     protected Set<String> createBinarySet(final String baseClass) {
163         final SortedSet<String> result = new TreeSet<String>();
164         result.add(baseClass);
165         result.add(baseClass + NEG_SUFFIX);
166         return result;
167     }
168 
169     /***
170      * {@inheritDoc}
171      * This implementation combines the predictions for the positive class
172      * of all involved inner classifiers.
173      *
174      * <p><em>The probability estimates returned by each classifier are used
175      * "as is", so the result will <strong>not</strong> be a real probability
176      * distribution because sum of all probabilities will be more than 1.
177      * If you want to work on a real probability distribution you have
178      * normalize it yourself.</em>
179      */
180     protected PredictionDistribution doClassify(final FeatureVector features,
181                                                 final Set candidateClasses,
182                                                 final ContextMap context)
183     throws ProcessingException {
184         final List<Prediction> unnormalizedPreds =
185             new ArrayList<Prediction>(candidateClasses.size());
186         final Iterator classIter = candidateClasses.iterator();
187         String currentClass;
188         TrainableClassifier innerClassifier;
189         PredictionDistribution currentDist;
190         Prediction currentPred;
191         String predClass;
192         ContextMap currentContext;
193         Iterator predIter;
194 //        double summedProb = 0.0;
195 
196         // delegate to doClassify methods of all candidat classes
197         while (classIter.hasNext()) {
198             currentClass = (String) classIter.next();
199 
200             // classify all candidat classeses
201             currentContext = new ContextMap();
202             innerClassifier = innerClassifiers.get(currentClass);
203             currentDist = innerClassifier.doClassify(features,
204                 innerClassifier.getAllClasses(), currentContext);
205 
206             // store prediction distribution and context map of this
207             // classifier in my own context
208             context.put(PREFIX_CONTEXT + currentClass, currentContext);
209             context.put(PREFIX_DIST + currentClass, currentDist);
210 
211             predIter = currentDist.iterator();
212 
213             // iterate prediction distribution of this classifier
214             while (predIter.hasNext()) {
215                 currentPred = (Prediction) predIter.next();
216                 predClass = currentPred.getType();
217 
218                 if (predClass.equals(currentClass)) {
219                     // use predictions of positive classes only
220                     unnormalizedPreds.add(currentPred);
221 //                  summedProb += currentPred.getProbability();                    
222                 } else if (!predClass.equals(currentClass + NEG_SUFFIX)) {
223                     // just to make sure that everything is as it should be
224                     throw new RuntimeException("Implementation error: "
225                             + "unexpected class name " + predClass + " used by "
226                             + currentClass + " classifier");
227                 }
228             }
229         }
230 
231         final PredictionDistribution result = new PredictionDistribution();
232         predIter = unnormalizedPreds.iterator();
233 
234         // build full distribution for positive classes
235         while (predIter.hasNext()) {
236             currentPred = (Prediction) predIter.next();
237             result.add(currentPred);
238 
239 /*            // normalize probabilities so they sum to 1
240             result.add(new Prediction(currentPred.getType(),
241                 currentPred.getSource(),
242                 currentPred.getProbability() / summedProb,
243                 currentPred.getPR(), currentPred.getEvalStatus())); */
244         }
245 
246         //Util.LOG.debug("Combined probability distribution: "
247         //        + result.toString());
248         //Util.LOG.debug("Current context: " + context);
249 
250         return result;
251     }
252 
253     /***
254      * {@inheritDoc}
255      */
256     protected void doTrain(final FeatureVector features,
257                            final String targetClass,
258                            final ContextMap context)
259     throws ProcessingException {
260         final Iterator classIter = innerClassifiers.keySet().iterator();
261         String currentClass;
262         TrainableClassifier classifier;
263         String classToTrain;
264         String contextKey;
265         ContextMap currentContext;
266 
267         // delegate to doTrain methods of all classifiers
268         while (classIter.hasNext()) {
269             currentClass = (String) classIter.next();
270             classifier = innerClassifiers.get(currentClass);
271 
272             if (currentClass.equals(targetClass)) {
273                 // classifier for the target class
274                 classToTrain = targetClass;
275             } else {
276                 // other classifier: train negative class
277                 classToTrain = currentClass + NEG_SUFFIX;
278             }
279 
280             // retrieve original context map, if stored
281             contextKey = PREFIX_CONTEXT + currentClass;
282             if (context.containsKey(contextKey)) {
283                 currentContext = (ContextMap) context.get(contextKey);
284             } else {
285                 currentContext = new ContextMap();
286             }
287 
288             classifier.doTrain(features, classToTrain, currentContext);
289         }
290     }
291 
292     /***
293      * Returns a string representation of this object.
294      *
295      * @return a textual representation
296      */
297     public String toString() {
298         return new ToStringBuilder(this)
299             .appendSuper(super.toString())
300             .append("inner classifiers", innerClassifiers.values())
301             .toString();
302     }
303 
304     /***
305      * {@inheritDoc}
306      */
307     protected boolean trainOnErrorHook(final PredictionDistribution predDist,
308             final FeatureVector features, final String targetClass,
309             final Set candidateClasses, final ContextMap context)
310     throws ProcessingException {
311         boolean result = false;
312         final Iterator predIter = predDist.iterator();
313         Prediction currentPred;
314         String currentClass;
315         TrainableClassifier classifier;
316         ContextMap currentContext;
317         PredictionDistribution currentDist;
318         String classToTrain;
319 
320         // delegate to all involved classifiers, returning ORing of results
321         // (i.e. return true if at least one of them returns true)
322         while (predIter.hasNext()) {
323             currentPred = (Prediction) predIter.next();
324             currentClass = currentPred.getType();
325             classifier = innerClassifiers.get(currentClass);
326 
327             // retrieve original context map + prediction distribution
328             currentContext = (ContextMap)
329                 context.get(PREFIX_CONTEXT + currentClass);
330             currentDist = (PredictionDistribution)
331                 context.get(PREFIX_DIST + currentClass);
332 
333             if (currentClass.equals(targetClass)) {
334                 // classifier for the target class
335                 classToTrain = targetClass;
336             } else {
337                 // other classifier: train negative class
338                 classToTrain = currentClass + NEG_SUFFIX;
339             }
340 
341             result = classifier.trainOnErrorHook(currentDist, features,
342                 classToTrain, classifier.getAllClasses(), currentContext)
343                 || result;
344         }
345         return result;
346     }
347 
348     /***
349      * {@inheritDoc}
350      */
351     public void reset() throws ProcessingException {
352         final Iterator innerIter = innerClassifiers.values().iterator();
353         TrainableClassifier classifier;
354 
355         // delegate to reset methods of all inner classifiers
356         while (innerIter.hasNext()) {
357             classifier = (TrainableClassifier) innerIter.next();
358             classifier.reset();
359         }
360     }
361 
362 }