View Javadoc

1   /*
2    * Copyright (C) 2004 Christian Siefkes <christian@siefkes.net>.
3    * Development of this software is supported by the German Research Society,
4    * Berlin-Brandenburg Graduate School in Distributed Information Systems
5    * (DFG grant no. GRK 316).
6    *
7    * This library is free software; you can redistribute it and/or
8    * modify it under the terms of the GNU Lesser General Public
9    * License as published by the Free Software Foundation; either
10   * version 2.1 of the License, or (at your option) any later version.
11   *
12   * This library is distributed in the hope that it will be useful,
13   * but WITHOUT ANY WARRANTY; without even the implied warranty of
14   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15   * Lesser General Public License for more details.
16   *
17   * You should have received a copy of the GNU Lesser General Public
18   * License along with this library; if not, visit
19   * http://www.gnu.org/licenses/lgpl.html or write to the Free Software
20   * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
21   */
22  package de.fu_berlin.ties.classify;
23  
24  import java.io.File;
25  import java.util.ArrayList;
26  import java.util.HashMap;
27  import java.util.Iterator;
28  import java.util.List;
29  import java.util.Map;
30  import java.util.Set;
31  import java.util.SortedSet;
32  import java.util.TreeSet;
33  
34  import org.apache.commons.lang.builder.ToStringBuilder;
35  
36  import de.fu_berlin.ties.ContextMap;
37  import de.fu_berlin.ties.ProcessingException;
38  import de.fu_berlin.ties.TiesConfiguration;
39  import de.fu_berlin.ties.classify.feature.FeatureTransformer;
40  import de.fu_berlin.ties.classify.feature.FeatureVector;
41  
42  /***
43   * This classifier converts an multi-class classification task into a several
44   * binary (two-class) classification task. It wraps several instances of
45   * another classifier that are used to perform the binary classifications and
46   * combines their results.
47   *
48   * <p>The first class from the set of classes passed to the constructor is
49   * considered as the "background" class, while all further members are
50   * considered as "foreground" classes.
51   *
52   * <p>Instances of this class are thread-safe if and only if instances of the
53   * wrapped classifier are.
54   *
55   * <p><em><strong>WARNING:</strong> The current implementation does
56   * <strong>not</strong> query the {@link
57   * de.fu_berlin.ties.classify.TrainableClassifier#shouldTrain(String,
58   * PredictionDistribution, ContextMap) shouldTrain} method of inner classifiers.
59   * Because of this, classifiers overwriting <code>shouldTrain</code> might not
60   * work correctly within this classifier.</em>
61   *
62   * @author Christian Siefkes
63   * @version $Revision: 1.19 $, $Date: 2004/12/09 18:09:14 $, $Author: siefkes $
64   */
65  public class MultiBinaryClassifier extends TrainableClassifier {
66  
67      /***
68       * Key prefix used to store the contexts of inner classifiers.
69       */
70      private static final String PREFIX_CONTEXT = "context-";
71  
72      /***
73       * Key prefix used to store the prediction distributions returned by inner
74       * classifiers.
75       */
76      private static final String PREFIX_DIST = "dist-";
77  
78      /***
79       * The "background" class of this classifier.
80       */
81      private final String backgroundClass;
82  
83      /***
84       * Maps from the names of the foreground classes to the binary classifiers
85       * used to decide between this foreground class and the background class.
86       */
87      private final Map<String, TrainableClassifier> innerClassifiers =
88          new HashMap<String, TrainableClassifier>();
89  
90      /***
91       * Creates a new instance.
92       *
93       * @param allValidClasses the set of all valid classes; the first member
94       * of this set is considered as the "background" class, all further members
95       * are considered as "foreground" classes
96       * @param trans the last transformer in the transformer chain to use, or
97       * <code>null</code> if no feature transformers should be used
98       * @param runDirectory optional run directory passed to inner classifiers
99       * of the {@link ExternalClassifier} type
100      * @param innerSpec the specification used to initialize the inner
101      * classifiers, passed to the
102      * {@link TrainableClassifier#createClassifier(Set, File,
103      * FeatureTransformer, String[], TiesConfiguration)} factory method
104      * @param conf used to configure this instance and the inner classifiers
105      * @throws IllegalArgumentException if there are fewer than three classes
106      * (in which case you should use the inner classifier directly since there
107      * is no need to wrap several instances)
108      * @throws ProcessingException if an error occurred while creating this
109      * classifier or one of the wrapped classifiers
110      */
111     public MultiBinaryClassifier(final Set<String> allValidClasses,
112                                  final FeatureTransformer trans,
113                                  final File runDirectory,
114                                  final String[] innerSpec,
115                                  final TiesConfiguration conf)
116     throws IllegalArgumentException, ProcessingException {
117         super(allValidClasses, trans, conf);
118 
119         if (allValidClasses.size() < 3) {
120             throw new IllegalArgumentException(
121                 "MultiBinaryClassifier requires at least 3 classes instead of "
122                 + allValidClasses.size());
123         }
124 
125         // store background class (we checked that there are enough classes)
126         final Iterator<String> classIter= allValidClasses.iterator();
127         backgroundClass = classIter.next();
128 
129         // create inner classifiers via factory method
130         String foregroundClass;
131         TrainableClassifier innerClassifier;
132         Set<String> innerSet;
133 
134         while (classIter.hasNext()) {
135             // each inner classifier decides between a foreground class and the
136             // background class
137             foregroundClass = classIter.next();
138             innerSet = createBinarySet(foregroundClass);
139 
140             // transformer is set to null because features shouldn't be
141             // transformed twice
142             innerClassifier = TrainableClassifier.createClassifier(innerSet,
143                 runDirectory, null, innerSpec, conf);
144             innerClassifiers.put(foregroundClass, innerClassifier);
145         }
146     }
147 
148     /***
149      * Helper method that creates a set containing the two classes of a
150      * binary classifier.
151      *
152      * @param foregroundClass the "foreground" class to use
153      * @return a set containing the {@linkplain #getBackgroundClass()
154      * "background" class} and the specified <code>foregroundClass</code>;
155      * this implementation returns the two classes in alphabetic order
156      */
157     protected Set<String> createBinarySet(final String foregroundClass) {
158         final SortedSet<String> result = new TreeSet<String>();
159         result.add(backgroundClass);
160         result.add(foregroundClass);
161         return result;
162     }
163 
164     /***
165      * {@inheritDoc}
166      * This implementation combines the predictions for the foreground
167      * of all involved inner classifiers.
168      *
169      * <p>If the {@linkplain #getBackgroundClass() background class} is part of
170      * the <code>candidateClasses</code>, the classifier whose background
171      * probability is closest to 0.5 determines the probability of the
172      * background class. In this way all classes will be sorted the right way
173      * (foreground classes with a higher than 0.5 before the background class,
174      * those with a lower probability after it). pR values are <em>not</em>
175      * considered for this purpose, so you should be careful when combination
176      * classifiers that mainly rely on the pR value in a multi-binary setup.
177      *
178      * <p><em>The probability estimates returned by each classifier are used
179      * "as is", so the result will <strong>not</strong> be a real probability
180      * distribution because sum of all probabilities will be more than 1.
181      * If you want to work on a real probability distribution you have
182      * normalize it yourself.</em>
183      */
184     protected PredictionDistribution doClassify(final FeatureVector features,
185                                                 final Set candidateClasses,
186                                                 final ContextMap context)
187     throws ProcessingException {
188         final List<Prediction> unnormalizedPreds =
189             new ArrayList<Prediction>(candidateClasses.size());
190         final Iterator classIter = candidateClasses.iterator();
191         String currentClass;
192         TrainableClassifier innerClassifier;
193         PredictionDistribution currentDist;
194         Prediction currentPred;
195         ContextMap currentContext;
196         Iterator predIter;
197 //        double summedProb = 0.0;
198 
199         // for determining which prediction to use for the background class
200         Prediction backgroundMostNearToFifty = null;
201         double lowestBackgroundDist = Double.MAX_VALUE;
202         double currentBackgroundDist;
203 
204         // delegate to doClassify methods of all foreground candidates
205         while (classIter.hasNext()) {
206             currentClass = (String) classIter.next();
207 
208             if (!backgroundClass.equals(currentClass)) {
209                 // classify all foreground candidates
210                 currentContext = new ContextMap();
211                 innerClassifier = innerClassifiers.get(currentClass);
212                 currentDist = innerClassifier.doClassify(features,
213                     innerClassifier.getAllClasses(), currentContext);
214 
215                 // store prediction distribution and context map of this
216                 // classifier in my own context
217                 context.put(PREFIX_CONTEXT + currentClass, currentContext);
218                 context.put(PREFIX_DIST + currentClass, currentDist);
219 
220                 predIter = currentDist.iterator();
221 
222                 // iterate prediction distribution of this classifier
223                 while (predIter.hasNext()) {
224                     currentPred = (Prediction) predIter.next();
225 
226                     // prediction for background class
227                     if (backgroundClass.equals(currentPred.getType())) {
228                         // calculate distance to 0.5 (absolute value)
229                         currentBackgroundDist = Math.abs(
230                                 currentPred.getProbability().getProb()- 0.5);
231 
232                         if (currentBackgroundDist < lowestBackgroundDist) {
233 /*                            if (backgroundMostNearToFifty != null) {
234                                 Util.LOG.debug("Using background prob. "
235                                     + currentPred + " instead of "
236                                     + backgroundMostNearToFifty
237                                     + " because it's closer to 0.5");
238                             } */
239 
240                             // prob. is nearer to 0.5 than other background
241                             // predictions
242                             backgroundMostNearToFifty = currentPred;
243                             lowestBackgroundDist = currentBackgroundDist;
244                         }
245                     } else {
246                         // prediction for a foreground class
247                         unnormalizedPreds.add(currentPred);
248 //                        summedProb += currentPred.getProbability();
249                     }
250                 }
251             }
252         }
253 
254         if ((backgroundMostNearToFifty != null)
255                 && candidateClasses.contains(backgroundClass)) {
256             // background class is a candidate:
257             // add it using the prob. that is most near to 0.5
258             unnormalizedPreds.add(backgroundMostNearToFifty);
259 //            summedProb += backgroundMostNearToFifty.getProbability();
260         }
261 
262         final PredictionDistribution result = new PredictionDistribution();
263         predIter = unnormalizedPreds.iterator();
264 
265         // build full distribution
266         while (predIter.hasNext()) {
267             currentPred = (Prediction) predIter.next();
268             result.add(currentPred);
269 
270 /*            // normalize probabilities so they sum to 1
271             result.add(new Prediction(currentPred.getType(),
272                 currentPred.getSource(),
273                 currentPred.getProbability() / summedProb,
274                 currentPred.getPR(), currentPred.getEvalStatus())); */
275         }
276 
277         //Util.LOG.debug("Combined probability distribution: "
278         //    + result.toString());
279         //Util.LOG.debug("Current context: " + context);
280 
281         return result;
282     }
283 
284     /***
285      * {@inheritDoc}
286      */
287     protected void doTrain(final FeatureVector features,
288                            final String targetClass,
289                            final ContextMap context)
290     throws ProcessingException {
291         final Iterator classIter = innerClassifiers.keySet().iterator();
292         String currentClass;
293         TrainableClassifier classifier;
294         String classToTrain;
295         String contextKey;
296         ContextMap currentContext;
297 
298         // delegate to doTrain methods of all classifiers
299         while (classIter.hasNext()) {
300             currentClass = (String) classIter.next();
301             classifier = innerClassifiers.get(currentClass);
302 
303             if (currentClass.equals(targetClass)) {
304                 // classifier for the target class
305                 classToTrain = targetClass;
306             } else {
307                 // other classifier: train background class
308                 classToTrain = backgroundClass;
309             }
310 
311             // retrieve original context map, if stored
312             contextKey = PREFIX_CONTEXT + currentClass;
313             if (context.containsKey(contextKey)) {
314                 currentContext = (ContextMap) context.get(contextKey);
315             } else {
316                 currentContext = new ContextMap();
317             }
318 
319             classifier.doTrain(features, classToTrain, currentContext);
320         }
321     }
322 
323     /***
324      * Returns the "background" class of this classifier.
325      * @return the value of the attribute
326      */
327     public String getBackgroundClass() {
328         return backgroundClass;
329     }
330 
331     /***
332      * Returns a string representation of this object.
333      *
334      * @return a textual representation
335      */
336     public String toString() {
337         return new ToStringBuilder(this)
338             .appendSuper(super.toString())
339             .append("background class", backgroundClass)
340             .append("inner classifiers", innerClassifiers.values())
341             .toString();
342     }
343 
344     /***
345      * {@inheritDoc}
346      */
347     protected boolean trainOnErrorHook(final PredictionDistribution predDist,
348             final FeatureVector features, final String targetClass,
349             final Set candidateClasses, final ContextMap context)
350     throws ProcessingException {
351         boolean result = false;
352         final Iterator predIter = predDist.iterator();
353         Prediction currentPred;
354         String currentClass;
355         TrainableClassifier classifier;
356         ContextMap currentContext;
357         PredictionDistribution currentDist;
358         String classToTrain;
359 
360         // delegate to all involved classifiers, returning ORing of results
361         // (i.e. return true if at least one of them returns true)
362         while (predIter.hasNext()) {
363             currentPred = (Prediction) predIter.next();
364             currentClass = currentPred.getType();
365 
366             if (!backgroundClass.equals(currentClass)) {
367                 classifier = innerClassifiers.get(currentClass);
368 
369                 // retrieve original context map + prediction distribution
370                 currentContext = (ContextMap)
371                     context.get(PREFIX_CONTEXT + currentClass);
372                 currentDist = (PredictionDistribution)
373                     context.get(PREFIX_DIST + currentClass);
374 
375                 if (currentClass.equals(targetClass)) {
376                     // classifier for the target class
377                     classToTrain = targetClass;
378                 } else {
379                     // other classifier: train background class
380                     classToTrain = backgroundClass;
381                 }
382 
383                 result = classifier.trainOnErrorHook(currentDist, features,
384                     classToTrain, classifier.getAllClasses(), currentContext)
385                     || result;
386             }
387         }
388         return result;
389     }
390 
391     /***
392      * {@inheritDoc}
393      */
394     public void reset() throws ProcessingException {
395         final Iterator innerIter = innerClassifiers.values().iterator();
396         TrainableClassifier classifier;
397 
398         // delegate to reset methods of all inner classifiers
399         while (innerIter.hasNext()) {
400             classifier = (TrainableClassifier) innerIter.next();
401             classifier.reset();
402         }
403     }
404 
405 }