/*
 * Copyright (C) 2004-2006 Christian Siefkes <christian@siefkes.net>.
 * Development of this software is supported by the German Research Society,
 * Berlin-Brandenburg Graduate School in Distributed Information Systems
 * (DFG grant no. GRK 316).
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, visit
 * http://www.gnu.org/licenses/gpl.html or write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */
package de.fu_berlin.ties.classify;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import org.apache.commons.lang.builder.ToStringBuilder;
import org.dom4j.Element;
import org.dom4j.tree.DefaultElement;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.feature.FeatureTransformer;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.io.ObjectElement;

/***
 * This classifier converts a multi-class classification task into several
 * binary (two-class) classification tasks. It wraps several instances of
 * another classifier that are used to perform the binary classifications and
 * combines their results.
 *
 * <p>When used with <em>N</em> classes, this classifier creates and wraps
 * <em>N</em> inner binary classifiers. Each inner classifier decides between
 * members (positive instances) and non-members (negative instances) of
 * a specified class. The inner classifier returning the highest
 * {@linkplain de.fu_berlin.ties.classify.Prediction#getProbability()
 * probability} for its (positive) class wins, even if the probability of
 * that classifier's negative class is still higher. This is the classical
 * "one-against-the-rest" (or "one-against-all") setup.
 *
 * <p>This classifier differs from
 * {@link de.fu_berlin.ties.classify.MultiBinaryClassifier} in not using a
 * "background class" and in using <em>N</em> inner classifiers, while
 * <code>MultiBinaryClassifier</code> uses <em>N-1</em>.
 *
 * <p>Instances of this class are thread-safe if and only if instances of the
 * wrapped classifier are.
 *
 * <p><em><strong>WARNING:</strong> The current implementation does
 * <strong>not</strong> query the {@link
 * de.fu_berlin.ties.classify.TrainableClassifier#shouldTrain(String,
 * PredictionDistribution, ContextMap) shouldTrain} method of inner classifiers.
 * Because of this, classifiers overriding <code>shouldTrain</code> might not
 * work correctly within this classifier.</em>
 *
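 * <p>A minimal construction sketch (the class names, the
 * <code>innerSpec</code> array, and the <code>conf</code> instance are
 * hypothetical placeholders; see the constructor documentation for the
 * exact parameter semantics):
 *
 * <pre>
 * Set&lt;String&gt; classes = new TreeSet&lt;String&gt;();
 * classes.add("person");
 * classes.add("location");
 * classes.add("organization");
 * // wraps three inner binary classifiers: person vs. person-neg,
 * // location vs. location-neg, and organization vs. organization-neg
 * TrainableClassifier classifier = new OneAgainstTheRestClassifier(
 *     classes, null, null, innerSpec, conf);
 * </pre>
 *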
 * @author Christian Siefkes
 * @version $Revision: 1.9 $, $Date: 2006/10/21 16:03:54 $, $Author: siefkes $
 */
public class OneAgainstTheRestClassifier extends TrainableClassifier {

    /***
     * Key prefix used to store the contexts of inner classifiers.
     */
    private static final String PREFIX_CONTEXT = "context-";

    /***
     * Key prefix used to store the prediction distributions returned by inner
     * classifiers.
     */
    private static final String PREFIX_DIST = "dist-";

    /***
     * The suffix appended to negative classes.
     */
    private static final String NEG_SUFFIX = "-neg";


    /***
     * Maps the name of each class to the binary classifier used to decide
     * between positive and negative instances of that class.
     */
    private final Map<String, TrainableClassifier> innerClassifiers =
        new HashMap<String, TrainableClassifier>();


    /***
     * Creates a new instance from an XML element, fulfilling the
     * recommendation of the {@link de.fu_berlin.ties.io.XMLStorable} interface.
     *
     * @param element the XML element containing the serialized representation
     * @throws InstantiationException if the given element does not contain
     * a valid classifier description
     */
    public OneAgainstTheRestClassifier(final Element element)
    throws InstantiationException {
        // delegate to superclass
        super(element);

        // initialize inner classifiers
        final Iterator innerIter =
            element.elementIterator(MultiBinaryClassifier.ELEMENT_INNER);
        Element innerElement;

        while (innerIter.hasNext()) {
            innerElement = (Element) innerIter.next();
            innerClassifiers.put(
                innerElement.attributeValue(MultiBinaryClassifier.ATTRIB_FOR),
                (TrainableClassifier) ObjectElement.createObject(
                    innerElement.element(TrainableClassifier.ELEMENT_MAIN)));
        }

        // check that there is a classifier for each class
        if (getAllClasses().size() != innerClassifiers.size()) {
            throw new InstantiationException("Serialization error: Found "
                    + innerClassifiers.size()
                    + " inner classifiers but there are "
                    + getAllClasses().size() + " classes");
        }

        final Iterator classIter = getAllClasses().iterator();
        String className;

        while (classIter.hasNext()) {
            className = (String) classIter.next();
            if (!innerClassifiers.containsKey(className)) {
                throw new InstantiationException("Serialization error: "
                    + "No inner classifier exists for the class " + className);
            }
        }
    }

    /***
     * Creates a new instance.
     *
     * @param allValidClasses the set of all valid classes
     * @param trans the last transformer in the transformer chain to use, or
     * <code>null</code> if no feature transformers should be used
     * @param runDirectory optional run directory passed to inner classifiers
     * of the {@link ExternalClassifier} type
     * @param innerSpec the specification used to initialize the inner
     * classifiers, passed to the
     * {@link TrainableClassifier#createClassifier(Set, File,
     * FeatureTransformer, String[], TiesConfiguration)} factory method
     * @param conf used to configure this instance and the inner classifiers
     * @throws IllegalArgumentException if there are fewer than three classes
     * (in which case you should use the inner classifier directly since there
     * is no need to wrap several instances)
     * @throws ProcessingException if an error occurred while creating this
     * classifier or one of the wrapped classifiers
     */
    public OneAgainstTheRestClassifier(final Set<String> allValidClasses,
                                       final FeatureTransformer trans,
                                       final File runDirectory,
                                       final String[] innerSpec,
                                       final TiesConfiguration conf)
    throws IllegalArgumentException, ProcessingException {
        super(allValidClasses, trans, conf);

        if (allValidClasses.size() < 3) {
            throw new IllegalArgumentException("OneAgainstTheRestClassifier "
                    + "requires at least 3 classes instead of "
                    + allValidClasses.size());
        }

        // create inner classifiers via factory method
        final Iterator<String> classIter = allValidClasses.iterator();
        String baseClass;
        TrainableClassifier innerClassifier;
        Set<String> innerSet;

        while (classIter.hasNext()) {
            // each inner classifier decides between a positive class and a
            // negative class
            baseClass = classIter.next();
            innerSet = createBinarySet(baseClass);

            // transformer is set to null because features shouldn't be
            // transformed twice
            innerClassifier = TrainableClassifier.createClassifier(innerSet,
                runDirectory, null, innerSpec, conf);
            innerClassifiers.put(baseClass, innerClassifier);
        }
    }

    /***
     * Helper method that creates a set containing the two classes of a
     * binary classifier.
     *
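     * <p>For example (using a hypothetical class name), calling this method
     * with the base class <code>person</code> yields a set containing the
     * positive class <code>person</code> and the negative class
     * <code>person-neg</code>:
     *
     * <pre>
     * createBinarySet("person");  // {"person", "person-neg"}
     * </pre>
     *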
     * @param baseClass the base class to use
     * @return a set containing a positive and a negative class derived from the
     * base class; this implementation returns the two classes in alphabetic
     * order, which means that the positive (base) class comes first
     */
    protected Set<String> createBinarySet(final String baseClass) {
        final SortedSet<String> result = new TreeSet<String>();
        result.add(baseClass);
        result.add(baseClass + NEG_SUFFIX);
        return result;
    }

    /***
     * {@inheritDoc}
     */
    public void destroy() throws ProcessingException {
        final Iterator innerIter = innerClassifiers.values().iterator();
        TrainableClassifier classifier;

        // delegate to destroy methods of all inner classifiers
        while (innerIter.hasNext()) {
            classifier = (TrainableClassifier) innerIter.next();
            classifier.destroy();
        }
    }

    /***
     * {@inheritDoc}
     * This implementation combines the predictions for the positive class
     * of each involved inner classifier.
     *
     * <p><em>The probability estimates returned by each classifier are used
     * "as is", so the result will <strong>not</strong> be a real probability
     * distribution because the probabilities will generally not sum to 1.
     * If you want to work on a real probability distribution, you have to
     * normalize it yourself.</em>
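     *
     * <p><em>A minimal normalization sketch (the local variable names are
     * hypothetical; it relies only on the <code>Prediction</code> accessors
     * and constructor also used in the commented-out code below):</em>
     *
     * <pre>
     * double sum = 0.0;
     * for (Iterator it = dist.iterator(); it.hasNext();) {
     *     sum += ((Prediction) it.next()).getProbability();
     * }
     * PredictionDistribution normalized = new PredictionDistribution();
     * for (Iterator it = dist.iterator(); it.hasNext();) {
     *     Prediction p = (Prediction) it.next();
     *     // divide each probability by the sum so the result sums to 1
     *     normalized.add(new Prediction(p.getType(), p.getSource(),
     *         p.getProbability() / sum, p.getPR(), p.getEvalStatus()));
     * }
     * </pre>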
     */
    protected PredictionDistribution doClassify(final FeatureVector features,
                                                final Set candidateClasses,
                                                final ContextMap context)
    throws ProcessingException {
        final List<Prediction> unnormalizedPreds =
            new ArrayList<Prediction>(candidateClasses.size());
        final Iterator classIter = candidateClasses.iterator();
        String currentClass;
        TrainableClassifier innerClassifier;
        PredictionDistribution currentDist;
        Prediction currentPred;
        String predClass;
        ContextMap currentContext;
        Iterator predIter;
//        double summedProb = 0.0;

        // delegate to the inner classifier of each candidate class
        while (classIter.hasNext()) {
            currentClass = (String) classIter.next();

            // classify with the inner classifier for this candidate class
            currentContext = new ContextMap();
            innerClassifier = innerClassifiers.get(currentClass);
            currentDist = innerClassifier.doClassify(features,
                innerClassifier.getAllClasses(), currentContext);

            // store prediction distribution and context map of this
            // classifier in my own context
            context.put(PREFIX_CONTEXT + currentClass, currentContext);
            context.put(PREFIX_DIST + currentClass, currentDist);

            predIter = currentDist.iterator();

            // iterate prediction distribution of this classifier
            while (predIter.hasNext()) {
                currentPred = (Prediction) predIter.next();
                predClass = currentPred.getType();

                if (predClass.equals(currentClass)) {
                    // use predictions of positive classes only
                    unnormalizedPreds.add(currentPred);
//                  summedProb += currentPred.getProbability();
                } else if (!predClass.equals(currentClass + NEG_SUFFIX)) {
                    // just to make sure that everything is as it should be
                    throw new RuntimeException("Implementation error: "
                            + "unexpected class name " + predClass + " used by "
                            + currentClass + " classifier");
                }
            }
        }

        final PredictionDistribution result = new PredictionDistribution();
        predIter = unnormalizedPreds.iterator();

        // build full distribution for positive classes
        while (predIter.hasNext()) {
            currentPred = (Prediction) predIter.next();
            result.add(currentPred);

/*            // normalize probabilities so they sum to 1
            result.add(new Prediction(currentPred.getType(),
                currentPred.getSource(),
                currentPred.getProbability() / summedProb,
                currentPred.getPR(), currentPred.getEvalStatus())); */
        }

        //Util.LOG.debug("Combined probability distribution: "
        //        + result.toString());
        //Util.LOG.debug("Current context: " + context);

        return result;
    }

    /***
     * {@inheritDoc}
     */
    protected void doTrain(final FeatureVector features,
                           final String targetClass,
                           final ContextMap context)
    throws ProcessingException {
        final Iterator classIter = innerClassifiers.keySet().iterator();
        String currentClass;
        TrainableClassifier classifier;
        String classToTrain;
        String contextKey;
        ContextMap currentContext;

        // delegate to doTrain methods of all classifiers
        while (classIter.hasNext()) {
            currentClass = (String) classIter.next();
            classifier = innerClassifiers.get(currentClass);

            if (currentClass.equals(targetClass)) {
                // classifier for the target class
                classToTrain = targetClass;
            } else {
                // other classifier: train negative class
                classToTrain = currentClass + NEG_SUFFIX;
            }

            // retrieve original context map, if stored
            contextKey = PREFIX_CONTEXT + currentClass;
            if (context.containsKey(contextKey)) {
                currentContext = (ContextMap) context.get(contextKey);
            } else {
                currentContext = new ContextMap();
            }

            classifier.doTrain(features, classToTrain, currentContext);
        }
    }

    /***
     * {@inheritDoc}
     */
    public void reset() throws ProcessingException {
        final Iterator innerIter = innerClassifiers.values().iterator();
        TrainableClassifier classifier;

        // delegate to reset methods of all inner classifiers
        while (innerIter.hasNext()) {
            classifier = (TrainableClassifier) innerIter.next();
            classifier.reset();
        }
    }

    /***
     * {@inheritDoc}
     */
    public ObjectElement toElement() {
        // delegate to superclass
        final ObjectElement result = super.toElement();

        // add each inner classifier within <inner> element
        Element innerElement;
        Map.Entry<String, TrainableClassifier> inner;

        for (Iterator<Map.Entry<String, TrainableClassifier>> iter =
                innerClassifiers.entrySet().iterator();
                iter.hasNext();) {
            inner = iter.next();
            innerElement =
                new DefaultElement(MultiBinaryClassifier.ELEMENT_INNER);
            result.add(innerElement);
            innerElement.addAttribute(MultiBinaryClassifier.ATTRIB_FOR,
                    inner.getKey());
            innerElement.add(inner.getValue().toElement());
        }

        return result;
    }

    /***
     * Returns a string representation of this object.
     *
     * @return a textual representation
     */
    public String toString() {
        return new ToStringBuilder(this)
            .appendSuper(super.toString())
            .append("inner classifiers", innerClassifiers.values())
            .toString();
    }

    /***
     * {@inheritDoc}
     */
    protected boolean trainOnErrorHook(final PredictionDistribution predDist,
            final FeatureVector features, final String targetClass,
            final Set candidateClasses, final ContextMap context)
    throws ProcessingException {
        boolean result = false;
        final Iterator predIter = predDist.iterator();
        Prediction currentPred;
        String currentClass;
        TrainableClassifier classifier;
        ContextMap currentContext;
        PredictionDistribution currentDist;
        String classToTrain;

        // delegate to all involved classifiers, returning the OR of their
        // results (i.e. return true if at least one of them returns true)
        while (predIter.hasNext()) {
            currentPred = (Prediction) predIter.next();
            currentClass = currentPred.getType();
            classifier = innerClassifiers.get(currentClass);

            // retrieve original context map + prediction distribution
            currentContext = (ContextMap)
                context.get(PREFIX_CONTEXT + currentClass);
            currentDist = (PredictionDistribution)
                context.get(PREFIX_DIST + currentClass);

            if (currentClass.equals(targetClass)) {
                // classifier for the target class
                classToTrain = targetClass;
            } else {
                // other classifier: train negative class
                classToTrain = currentClass + NEG_SUFFIX;
            }

            result = classifier.trainOnErrorHook(currentDist, features,
                classToTrain, classifier.getAllClasses(), currentContext)
                || result;
        }
        return result;
    }

}