/*
 * Copyright (C) 2004-2006 Christian Siefkes <christian@siefkes.net>.
 * Development of this software is supported by the German Research Society,
 * Berlin-Brandenburg Graduate School in Distributed Information Systems
 * (DFG grant no. GRK 316).
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, visit
 * http://www.gnu.org/licenses/gpl.html or write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */
package de.fu_berlin.ties.classify;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import org.apache.commons.lang.builder.ToStringBuilder;
import org.dom4j.Element;
import org.dom4j.QName;
import org.dom4j.tree.DefaultElement;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.feature.FeatureTransformer;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.io.ObjectElement;
import de.fu_berlin.ties.xml.dom.DOMUtils;

/***
 * This classifier converts a multi-class classification task into several
 * binary (two-class) classification tasks. It wraps several instances of
 * another classifier that are used to perform the binary classifications and
 * combines their results.
 *
 * <p>The first class from the set of classes passed to the constructor is
 * considered the "background" class, while all further members are
 * considered "foreground" classes.
 *
 * <p>Instances of this class are thread-safe if and only if instances of the
 * wrapped classifier are.
 *
 * <p><em><strong>WARNING:</strong> The current implementation does
 * <strong>not</strong> query the {@link
 * de.fu_berlin.ties.classify.TrainableClassifier#shouldTrain(String,
 * PredictionDistribution, ContextMap) shouldTrain} method of inner classifiers.
 * Because of this, classifiers overriding <code>shouldTrain</code> might not
 * work correctly within this classifier.</em>
 *
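 * <p>A minimal construction sketch (the class names and the inner classifier
 * specification below are illustrative placeholders, and the
 * {@link TiesConfiguration} instance <code>conf</code> is assumed to be
 * provided by the surrounding application):
 *
 * <pre>
 * // insertion order matters: the first class becomes the background class
 * Set&lt;String&gt; classes = new java.util.LinkedHashSet&lt;String&gt;();
 * classes.add("other");     // background class
 * classes.add("person");    // first foreground class
 * classes.add("location");  // second foreground class
 *
 * // one inner binary classifier is created per foreground class
 * TrainableClassifier classifier = new MultiBinaryClassifier(
 *     classes, null,              // no feature transformer
 *     null,                       // no run directory
 *     new String[] {"..."},       // inner classifier spec (placeholder)
 *     conf);
 * </pre>
 *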
 * @author Christian Siefkes
 * @version $Revision: 1.26 $, $Date: 2006/10/21 16:03:54 $, $Author: siefkes $
 */
public class MultiBinaryClassifier extends TrainableClassifier {

    /***
     * Element name used for XML serialization.
     */
    static final QName ELEMENT_INNER = DOMUtils.defaultName("inner");

    /***
     * Attribute name used for XML serialization.
     */
    static final QName ATTRIB_FOR = DOMUtils.defaultName("for");

    /***
     * Attribute name used for XML serialization.
     */
    private static final QName ATTRIB_BACKGROUND =
        DOMUtils.defaultName("background-class");

    /***
     * Key prefix used to store the contexts of inner classifiers.
     */
    private static final String PREFIX_CONTEXT = "context-";

    /***
     * Key prefix used to store the prediction distributions returned by inner
     * classifiers.
     */
    private static final String PREFIX_DIST = "dist-";


    /***
     * The "background" class of this classifier.
     */
    private final String backgroundClass;

    /***
     * Maps from the names of the foreground classes to the binary classifiers
     * used to decide between the respective foreground class and the
     * background class.
     */
    private final Map<String, TrainableClassifier> innerClassifiers =
        new HashMap<String, TrainableClassifier>();


    /***
     * Creates a new instance from an XML element, fulfilling the
     * recommendation of the {@link de.fu_berlin.ties.io.XMLStorable} interface.
     *
     * @param element the XML element containing the serialized representation
     * @throws InstantiationException if the given element does not contain
     * a valid classifier description
     */
    public MultiBinaryClassifier(final Element element)
    throws InstantiationException {
        // delegate to superclass + init background class
        super(element);
        backgroundClass = element.attributeValue(ATTRIB_BACKGROUND);

        // initialize inner classifiers
        final Iterator innerIter = element.elementIterator(ELEMENT_INNER);
        Element innerElement;

        while (innerIter.hasNext()) {
            innerElement = (Element) innerIter.next();
            innerClassifiers.put(innerElement.attributeValue(ATTRIB_FOR),
                (TrainableClassifier) ObjectElement.createObject(
                    innerElement.element(TrainableClassifier.ELEMENT_MAIN)));
        }

        // check that there is a classifier for each foreground class
        if (!(getAllClasses().size() - 1 == innerClassifiers.size())) {
            throw new InstantiationException("Serialization error: Found "
                    + innerClassifiers.size()
                    + " inner classifiers but there are "
                    + (getAllClasses().size() - 1) + " foreground classes");
        }

        final Iterator classIter = getAllClasses().iterator();
        String className;

        while (classIter.hasNext()) {
            className = (String) classIter.next();
            if (!(innerClassifiers.containsKey(className)
                    || backgroundClass.equals(className))) {
                throw new InstantiationException("Serialization error: "
                    + "No inner classifier exists for the foreground class "
                    + className);
            }
        }
    }

    /***
     * Creates a new instance.
     *
     * @param allValidClasses the set of all valid classes; the first member
     * of this set is considered the "background" class, all further members
     * are considered "foreground" classes
     * @param trans the last transformer in the transformer chain to use, or
     * <code>null</code> if no feature transformers should be used
     * @param runDirectory optional run directory passed to inner classifiers
     * of the {@link ExternalClassifier} type
     * @param innerSpec the specification used to initialize the inner
     * classifiers, passed to the
     * {@link TrainableClassifier#createClassifier(Set, File,
     * FeatureTransformer, String[], TiesConfiguration)} factory method
     * @param conf used to configure this instance and the inner classifiers
     * @throws IllegalArgumentException if there are fewer than three classes
     * (in which case you should use the inner classifier directly since there
     * is no need to wrap several instances)
     * @throws ProcessingException if an error occurred while creating this
     * classifier or one of the wrapped classifiers
     */
    public MultiBinaryClassifier(final Set<String> allValidClasses,
                                 final FeatureTransformer trans,
                                 final File runDirectory,
                                 final String[] innerSpec,
                                 final TiesConfiguration conf)
    throws IllegalArgumentException, ProcessingException {
        super(allValidClasses, trans, conf);

        if (allValidClasses.size() < 3) {
            throw new IllegalArgumentException(
                "MultiBinaryClassifier requires at least 3 classes instead of "
                + allValidClasses.size());
        }

        // store background class (we checked that there are enough classes)
        final Iterator<String> classIter = allValidClasses.iterator();
        backgroundClass = classIter.next();

        // create inner classifiers via factory method
        String foregroundClass;
        TrainableClassifier innerClassifier;
        Set<String> innerSet;

        while (classIter.hasNext()) {
            // each inner classifier decides between a foreground class and the
            // background class
            foregroundClass = classIter.next();
            innerSet = createBinarySet(foregroundClass);

            // transformer is set to null because features shouldn't be
            // transformed twice
            innerClassifier = TrainableClassifier.createClassifier(innerSet,
                runDirectory, null, innerSpec, conf);
            innerClassifiers.put(foregroundClass, innerClassifier);
        }
    }

    /***
     * Helper method that creates a set containing the two classes of a
     * binary classifier.
     *
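     * <p>For example, with an (illustrative) background class
     * <code>"other"</code>, <code>createBinarySet("person")</code> yields a
     * set containing exactly <code>"other"</code> and <code>"person"</code>.
     *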
     * @param foregroundClass the "foreground" class to use
     * @return a set containing the {@linkplain #getBackgroundClass()
     * "background" class} and the specified <code>foregroundClass</code>;
     * this implementation returns the two classes in alphabetic order
     */
    protected Set<String> createBinarySet(final String foregroundClass) {
        final SortedSet<String> result = new TreeSet<String>();
        result.add(backgroundClass);
        result.add(foregroundClass);
        return result;
    }

    /***
     * {@inheritDoc}
     */
    public void destroy() throws ProcessingException {
        final Iterator innerIter = innerClassifiers.values().iterator();
        TrainableClassifier classifier;

        // delegate to destroy methods of all inner classifiers
        while (innerIter.hasNext()) {
            classifier = (TrainableClassifier) innerIter.next();
            classifier.destroy();
        }
    }

    /***
     * {@inheritDoc}
     * This implementation combines the predictions for the foreground classes
     * of all involved inner classifiers.
     *
     * <p>If the {@linkplain #getBackgroundClass() background class} is part of
     * the <code>candidateClasses</code>, the classifier whose background
     * probability is closest to 0.5 determines the probability of the
     * background class. In this way all classes will be sorted the right way
     * (foreground classes with a probability higher than 0.5 before the
     * background class, those with a lower probability after it). pR values
     * are <em>not</em> considered for this purpose, so you should be careful
     * when combining classifiers that mainly rely on the pR value in a
     * multi-binary setup.
     *
     * <p><em>The probability estimates returned by each classifier are used
     * "as is", so the result will <strong>not</strong> be a real probability
     * distribution because the sum of all probabilities will be more than 1.
     * If you want to work with a real probability distribution, you have to
     * normalize it yourself.</em>
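     *
     * <p>A minimal normalization sketch (not part of this classifier; it only
     * assumes the accessors already used in the code below, namely
     * <code>PredictionDistribution.iterator()</code> and
     * <code>Prediction.getProbability().getProb()</code>, applied to a
     * returned distribution <code>dist</code>):
     *
     * <pre>
     * // compute the normalizing constant over the raw estimates
     * double sum = 0.0;
     * for (Iterator it = dist.iterator(); it.hasNext();) {
     *     sum += ((Prediction) it.next()).getProbability().getProb();
     * }
     * // dividing each raw estimate by "sum" yields a proper distribution
     * </pre>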
     */
    protected PredictionDistribution doClassify(final FeatureVector features,
                                                final Set candidateClasses,
                                                final ContextMap context)
    throws ProcessingException {
        final List<Prediction> unnormalizedPreds =
            new ArrayList<Prediction>(candidateClasses.size());
        final Iterator classIter = candidateClasses.iterator();
        String currentClass;
        TrainableClassifier innerClassifier;
        PredictionDistribution currentDist;
        Prediction currentPred;
        ContextMap currentContext;
        Iterator predIter;
//        double summedProb = 0.0;

        // for determining which prediction to use for the background class
        Prediction backgroundMostNearToFifty = null;
        double lowestBackgroundDist = Double.MAX_VALUE;
        double currentBackgroundDist;

        // delegate to doClassify methods of all foreground candidates
        while (classIter.hasNext()) {
            currentClass = (String) classIter.next();

            if (!backgroundClass.equals(currentClass)) {
                // classify all foreground candidates
                currentContext = new ContextMap();
                innerClassifier = innerClassifiers.get(currentClass);
                currentDist = innerClassifier.doClassify(features,
                    innerClassifier.getAllClasses(), currentContext);

                // store prediction distribution and context map of this
                // classifier in my own context
                context.put(PREFIX_CONTEXT + currentClass, currentContext);
                context.put(PREFIX_DIST + currentClass, currentDist);

                predIter = currentDist.iterator();

                // iterate prediction distribution of this classifier
                while (predIter.hasNext()) {
                    currentPred = (Prediction) predIter.next();

                    // prediction for background class
                    if (backgroundClass.equals(currentPred.getType())) {
                        // calculate distance to 0.5 (absolute value)
                        currentBackgroundDist = Math.abs(
                                currentPred.getProbability().getProb() - 0.5);

                        if (currentBackgroundDist < lowestBackgroundDist) {
/*                            if (backgroundMostNearToFifty != null) {
                                Util.LOG.debug("Using background prob. "
                                    + currentPred + " instead of "
                                    + backgroundMostNearToFifty
                                    + " because it's closer to 0.5");
                            } */

                            // prob. is nearer to 0.5 than other background
                            // predictions
                            backgroundMostNearToFifty = currentPred;
                            lowestBackgroundDist = currentBackgroundDist;
                        }
                    } else {
                        // prediction for a foreground class
                        unnormalizedPreds.add(currentPred);
//                        summedProb += currentPred.getProbability();
                    }
                }
            }
        }

        if ((backgroundMostNearToFifty != null)
                && candidateClasses.contains(backgroundClass)) {
            // background class is a candidate:
            // add it using the prob. that is closest to 0.5
            unnormalizedPreds.add(backgroundMostNearToFifty);
//            summedProb += backgroundMostNearToFifty.getProbability();
        }

        final PredictionDistribution result = new PredictionDistribution();
        predIter = unnormalizedPreds.iterator();

        // build full distribution
        while (predIter.hasNext()) {
            currentPred = (Prediction) predIter.next();
            result.add(currentPred);

/*            // normalize probabilities so they sum to 1
            result.add(new Prediction(currentPred.getType(),
                currentPred.getSource(),
                currentPred.getProbability() / summedProb,
                currentPred.getPR(), currentPred.getEvalStatus())); */
        }

        //Util.LOG.debug("Combined probability distribution: "
        //    + result.toString());
        //Util.LOG.debug("Current context: " + context);

        return result;
    }

    /***
     * {@inheritDoc}
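     * This implementation trains every inner classifier on each call: the
     * classifier responsible for <code>targetClass</code> is trained on
     * <code>targetClass</code>, while all other inner classifiers are trained
     * on the {@linkplain #getBackgroundClass() background class}.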
     */
    protected void doTrain(final FeatureVector features,
                           final String targetClass,
                           final ContextMap context)
    throws ProcessingException {
        final Iterator classIter = innerClassifiers.keySet().iterator();
        String currentClass;
        TrainableClassifier classifier;
        String classToTrain;
        String contextKey;
        ContextMap currentContext;

        // delegate to doTrain methods of all classifiers
        while (classIter.hasNext()) {
            currentClass = (String) classIter.next();
            classifier = innerClassifiers.get(currentClass);

            if (currentClass.equals(targetClass)) {
                // classifier for the target class
                classToTrain = targetClass;
            } else {
                // other classifier: train background class
                classToTrain = backgroundClass;
            }

            // retrieve original context map, if stored
            contextKey = PREFIX_CONTEXT + currentClass;
            if (context.containsKey(contextKey)) {
                currentContext = (ContextMap) context.get(contextKey);
            } else {
                currentContext = new ContextMap();
            }

            classifier.doTrain(features, classToTrain, currentContext);
        }
    }

    /***
     * Returns the "background" class of this classifier.
     * @return the value of the attribute
     */
    public String getBackgroundClass() {
        return backgroundClass;
    }

    /***
     * {@inheritDoc}
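     * This implementation adds a <code>background-class</code> attribute and
     * one <code>&lt;inner for="..."&gt;</code> child element per foreground
     * class, each wrapping the serialization of the corresponding inner
     * classifier. A rough sketch of the resulting structure (the outer element
     * name and its further attributes are produced by the superclass and are
     * only placeholders here, as are the class names):
     *
     * <pre>
     * &lt;outer-element ... background-class="other"&gt;
     *   &lt;inner for="person"&gt;...&lt;/inner&gt;
     *   &lt;inner for="location"&gt;...&lt;/inner&gt;
     * &lt;/outer-element&gt;
     * </pre>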
     */
    public ObjectElement toElement() {
        // delegate to superclass + add background class attribute
        final ObjectElement result = super.toElement();
        result.addAttribute(ATTRIB_BACKGROUND, backgroundClass);

        // add each inner classifier within <inner> element
        Element innerElement;
        Map.Entry<String, TrainableClassifier> inner;

        for (Iterator<Map.Entry<String, TrainableClassifier>> iter =
            innerClassifiers.entrySet().iterator();
                iter.hasNext();) {
            inner = iter.next();
            innerElement = new DefaultElement(ELEMENT_INNER);
            result.add(innerElement);
            innerElement.addAttribute(ATTRIB_FOR, inner.getKey());
            innerElement.add(inner.getValue().toElement());
        }

        return result;
    }

    /***
     * Returns a string representation of this object.
     *
     * @return a textual representation
     */
    public String toString() {
        return new ToStringBuilder(this)
            .appendSuper(super.toString())
            .append("background class", backgroundClass)
            .append("inner classifiers", innerClassifiers.values())
            .toString();
    }

    /***
     * {@inheritDoc}
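     * This implementation delegates to the <code>trainOnErrorHook</code>
     * methods of the inner classifiers responsible for the foreground classes
     * in the given prediction distribution, re-using the contexts and
     * distributions stored by {@link #doClassify(FeatureVector, Set,
     * ContextMap) doClassify}, and returns <code>true</code> if at least one
     * of them returned <code>true</code>.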
     */
    protected boolean trainOnErrorHook(final PredictionDistribution predDist,
            final FeatureVector features, final String targetClass,
            final Set candidateClasses, final ContextMap context)
    throws ProcessingException {
        boolean result = false;
        final Iterator predIter = predDist.iterator();
        Prediction currentPred;
        String currentClass;
        TrainableClassifier classifier;
        ContextMap currentContext;
        PredictionDistribution currentDist;
        String classToTrain;

        // delegate to all involved classifiers and OR the results
        // (i.e. return true if at least one of them returns true)
        while (predIter.hasNext()) {
            currentPred = (Prediction) predIter.next();
            currentClass = currentPred.getType();

            if (!backgroundClass.equals(currentClass)) {
                classifier = innerClassifiers.get(currentClass);

                // retrieve original context map + prediction distribution
                currentContext = (ContextMap)
                    context.get(PREFIX_CONTEXT + currentClass);
                currentDist = (PredictionDistribution)
                    context.get(PREFIX_DIST + currentClass);

                if (currentClass.equals(targetClass)) {
                    // classifier for the target class
                    classToTrain = targetClass;
                } else {
                    // other classifier: train background class
                    classToTrain = backgroundClass;
                }

                result = classifier.trainOnErrorHook(currentDist, features,
                    classToTrain, classifier.getAllClasses(), currentContext)
                    || result;
            }
        }
        return result;
    }

    /***
     * {@inheritDoc}
     */
    public void reset() throws ProcessingException {
        final Iterator innerIter = innerClassifiers.values().iterator();
        TrainableClassifier classifier;

        // delegate to reset methods of all inner classifiers
        while (innerIter.hasNext()) {
            classifier = (TrainableClassifier) innerIter.next();
            classifier.reset();
        }
    }

}