package de.fu_berlin.ties.classify;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import org.apache.commons.lang.builder.ToStringBuilder;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.feature.FeatureTransformer;
import de.fu_berlin.ties.classify.feature.FeatureVector;

/***
 * This classifier converts a multi-class classification task into several
 * binary (two-class) classification tasks. It wraps several instances of
 * another classifier that are used to perform the binary classifications and
 * combines their results.
 *
 * <p>The first class from the set of classes passed to the constructor is
 * considered the "background" class, while all further members are
 * considered "foreground" classes.
 *
 * <p>Instances of this class are thread-safe if and only if instances of the
 * wrapped classifier are.
 *
 * <p><em><strong>WARNING:</strong> The current implementation does
 * <strong>not</strong> query the {@link
 * de.fu_berlin.ties.classify.TrainableClassifier#shouldTrain(String,
 * PredictionDistribution, ContextMap) shouldTrain} method of inner classifiers.
 * Because of this, classifiers overriding <code>shouldTrain</code> might not
 * work correctly within this classifier.</em>
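 *
 * <p>A minimal construction sketch (the class names and the
 * <code>innerSpec</code>/<code>conf</code> arguments are hypothetical and
 * shown only for illustration; training and classification then go through
 * the methods inherited from {@link TrainableClassifier}):
 *
 * <pre>
 * // the first class becomes the "background" class,
 * // all further classes become "foreground" classes
 * Set&lt;String&gt; classes = new java.util.LinkedHashSet&lt;String&gt;();
 * classes.add("background");
 * classes.add("person");
 * classes.add("location");
 *
 * // innerSpec and conf are assumed to be provided by the application;
 * // no feature transformer and no run directory are used in this sketch
 * TrainableClassifier classifier = new MultiBinaryClassifier(
 *     classes, null, null, innerSpec, conf);
 * </pre>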
 *
 * @author Christian Siefkes
 * @version $Revision: 1.19 $, $Date: 2004/12/09 18:09:14 $, $Author: siefkes $
 */
public class MultiBinaryClassifier extends TrainableClassifier {

    /***
     * Key prefix used to store the contexts of inner classifiers.
     */
    private static final String PREFIX_CONTEXT = "context-";

    /***
     * Key prefix used to store the prediction distributions returned by inner
     * classifiers.
     */
    private static final String PREFIX_DIST = "dist-";

    /***
     * The "background" class of this classifier.
     */
    private final String backgroundClass;

    /***
     * Maps from the names of the foreground classes to the binary classifiers
     * used to decide between the respective foreground class and the
     * background class.
     */
    private final Map<String, TrainableClassifier> innerClassifiers =
        new HashMap<String, TrainableClassifier>();

    /***
     * Creates a new instance.
     *
     * @param allValidClasses the set of all valid classes; the first member
     * of this set is considered the "background" class, all further members
     * are considered "foreground" classes
     * @param trans the last transformer in the transformer chain to use, or
     * <code>null</code> if no feature transformers should be used
     * @param runDirectory optional run directory passed to inner classifiers
     * of the {@link ExternalClassifier} type
     * @param innerSpec the specification used to initialize the inner
     * classifiers, passed to the
     * {@link TrainableClassifier#createClassifier(Set, File,
     * FeatureTransformer, String[], TiesConfiguration)} factory method
     * @param conf used to configure this instance and the inner classifiers
     * @throws IllegalArgumentException if there are fewer than three classes
     * (in which case you should use the inner classifier directly since there
     * is no need to wrap several instances)
     * @throws ProcessingException if an error occurred while creating this
     * classifier or one of the wrapped classifiers
     */
    public MultiBinaryClassifier(final Set<String> allValidClasses,
            final FeatureTransformer trans,
            final File runDirectory,
            final String[] innerSpec,
            final TiesConfiguration conf)
            throws IllegalArgumentException, ProcessingException {
        super(allValidClasses, trans, conf);

        if (allValidClasses.size() < 3) {
            throw new IllegalArgumentException(
                "MultiBinaryClassifier requires at least 3 classes, but got "
                + "only " + allValidClasses.size());
        }

        // the first class is used as the "background" class
        final Iterator<String> classIter = allValidClasses.iterator();
        backgroundClass = classIter.next();

        String foregroundClass;
        TrainableClassifier innerClassifier;
        Set<String> innerSet;

        // create one binary classifier for each "foreground" class
        while (classIter.hasNext()) {
            foregroundClass = classIter.next();
            innerSet = createBinarySet(foregroundClass);

            // the inner classifiers do not get a feature transformer of their
            // own -- the transformer chain is managed by this (outer) instance
            innerClassifier = TrainableClassifier.createClassifier(innerSet,
                runDirectory, null, innerSpec, conf);
            innerClassifiers.put(foregroundClass, innerClassifier);
        }
    }

    /***
     * Helper method that creates a set containing the two classes of a
     * binary classifier.
     *
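     * <p>For example (hypothetical class names): for the background class
     * <code>"background"</code> and the foreground class <code>"person"</code>
     * this would yield the set {"background", "person"}, iterated in
     * alphabetic order.
     *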
     * @param foregroundClass the "foreground" class to use
     * @return a set containing the {@linkplain #getBackgroundClass()
     * "background" class} and the specified <code>foregroundClass</code>;
     * this implementation returns the two classes in alphabetic order
     */
    protected Set<String> createBinarySet(final String foregroundClass) {
        final SortedSet<String> result = new TreeSet<String>();
        result.add(backgroundClass);
        result.add(foregroundClass);
        return result;
    }

    /***
     * {@inheritDoc}
     * This implementation combines the foreground predictions of all
     * involved inner classifiers.
     *
     * <p>If the {@linkplain #getBackgroundClass() background class} is part of
     * the <code>candidateClasses</code>, the classifier whose background
     * probability is closest to 0.5 determines the probability of the
     * background class. In this way all classes will be sorted the right way
     * (foreground classes with a probability higher than 0.5 before the
     * background class, those with a lower probability after it). pR values
     * are <em>not</em> considered for this purpose, so you should be careful
     * when combining classifiers that mainly rely on the pR value in a
     * multi-binary setup.
     *
     * <p><em>The probability estimates returned by each classifier are used
     * "as is", so the result will <strong>not</strong> be a real probability
     * distribution because the sum of all probabilities will be more than 1.
     * If you want to work on a real probability distribution you have to
     * normalize it yourself.</em>
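     *
     * <p>As a hypothetical illustration (class names and probabilities are
     * made up): with background class <code>BG</code> and foreground classes
     * <code>A</code> and <code>B</code>, assume the inner classifier for
     * <code>A</code> returns P(A) = 0.8, P(BG) = 0.2 and the one for
     * <code>B</code> returns P(B) = 0.3, P(BG) = 0.7. Since 0.7 is closer to
     * 0.5 than 0.2, the background class is reported with probability 0.7,
     * yielding the order A (0.8), BG (0.7), B (0.3). The probabilities sum to
     * 1.8, so the caller has to normalize them to obtain a real distribution.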
     */
    protected PredictionDistribution doClassify(final FeatureVector features,
            final Set candidateClasses,
            final ContextMap context)
            throws ProcessingException {
        final List<Prediction> unnormalizedPreds =
            new ArrayList<Prediction>(candidateClasses.size());
        final Iterator classIter = candidateClasses.iterator();
        String currentClass;
        TrainableClassifier innerClassifier;
        PredictionDistribution currentDist;
        Prediction currentPred;
        ContextMap currentContext;
        Iterator predIter;

        // the background prediction whose probability is closest to 0.5,
        // and its distance from 0.5
        Prediction backgroundMostNearToFifty = null;
        double lowestBackgroundDist = Double.MAX_VALUE;
        double currentBackgroundDist;

        // delegate to the binary classifier of each candidate foreground class
        while (classIter.hasNext()) {
            currentClass = (String) classIter.next();

            if (!backgroundClass.equals(currentClass)) {
                currentContext = new ContextMap();
                innerClassifier = innerClassifiers.get(currentClass);
                currentDist = innerClassifier.doClassify(features,
                    innerClassifier.getAllClasses(), currentContext);

                // store the context and distribution of the inner classifier
                // so they can be re-used for training
                context.put(PREFIX_CONTEXT + currentClass, currentContext);
                context.put(PREFIX_DIST + currentClass, currentDist);

                predIter = currentDist.iterator();

                while (predIter.hasNext()) {
                    currentPred = (Prediction) predIter.next();

                    if (backgroundClass.equals(currentPred.getType())) {
                        currentBackgroundDist = Math.abs(
                            currentPred.getProbability().getProb() - 0.5);

                        if (currentBackgroundDist < lowestBackgroundDist) {
                            // remember the background prediction that is
                            // closest to 0.5 (cf. the method comment)
                            backgroundMostNearToFifty = currentPred;
                            lowestBackgroundDist = currentBackgroundDist;
                        }
                    } else {
                        // foreground predictions are used "as is"
                        unnormalizedPreds.add(currentPred);
                    }
                }
            }
        }

        // the background class gets the probability of the background
        // prediction that is closest to 0.5, if it is a candidate at all
        if ((backgroundMostNearToFifty != null)
                && candidateClasses.contains(backgroundClass)) {
            unnormalizedPreds.add(backgroundMostNearToFifty);
        }

        final PredictionDistribution result = new PredictionDistribution();
        predIter = unnormalizedPreds.iterator();

        // copy the collected predictions into the result distribution,
        // without normalizing them (cf. the method comment)
        while (predIter.hasNext()) {
            currentPred = (Prediction) predIter.next();
            result.add(currentPred);
        }

        return result;
    }

    /***
     * {@inheritDoc}
     * This implementation trains each inner classifier, using the target
     * class for the responsible classifier and the background class for all
     * others.
     */
    protected void doTrain(final FeatureVector features,
            final String targetClass,
            final ContextMap context)
            throws ProcessingException {
        final Iterator classIter = innerClassifiers.keySet().iterator();
        String currentClass;
        TrainableClassifier classifier;
        String classToTrain;
        String contextKey;
        ContextMap currentContext;

        // train each inner classifier
        while (classIter.hasNext()) {
            currentClass = (String) classIter.next();
            classifier = innerClassifiers.get(currentClass);

            if (currentClass.equals(targetClass)) {
                // this is the classifier responsible for the target class
                classToTrain = targetClass;
            } else {
                // all other classifiers are trained on the background class
                classToTrain = backgroundClass;
            }

            // re-use the context stored by doClassify, if available
            contextKey = PREFIX_CONTEXT + currentClass;
            if (context.containsKey(contextKey)) {
                currentContext = (ContextMap) context.get(contextKey);
            } else {
                currentContext = new ContextMap();
            }

            classifier.doTrain(features, classToTrain, currentContext);
        }
    }

    /***
     * Returns the "background" class of this classifier.
     *
     * @return the background class
     */
    public String getBackgroundClass() {
        return backgroundClass;
    }

    /***
     * Returns a string representation of this object.
     *
     * @return a textual representation
     */
    public String toString() {
        return new ToStringBuilder(this)
            .appendSuper(super.toString())
            .append("background class", backgroundClass)
            .append("inner classifiers", innerClassifiers.values())
            .toString();
    }

    /***
     * {@inheritDoc}
     * This implementation delegates to the inner classifiers of all predicted
     * foreground classes, re-using the contexts and prediction distributions
     * stored during classification; it returns <code>true</code> if any of
     * the inner classifiers returned <code>true</code>.
     */
    protected boolean trainOnErrorHook(final PredictionDistribution predDist,
            final FeatureVector features, final String targetClass,
            final Set candidateClasses, final ContextMap context)
            throws ProcessingException {
        boolean result = false;
        final Iterator predIter = predDist.iterator();
        Prediction currentPred;
        String currentClass;
        TrainableClassifier classifier;
        ContextMap currentContext;
        PredictionDistribution currentDist;
        String classToTrain;

        // check each prediction of the combined distribution
        while (predIter.hasNext()) {
            currentPred = (Prediction) predIter.next();
            currentClass = currentPred.getType();

            if (!backgroundClass.equals(currentClass)) {
                classifier = innerClassifiers.get(currentClass);

                // re-use the context and distribution stored by doClassify
                currentContext = (ContextMap)
                    context.get(PREFIX_CONTEXT + currentClass);
                currentDist = (PredictionDistribution)
                    context.get(PREFIX_DIST + currentClass);

                if (currentClass.equals(targetClass)) {
                    // this is the classifier responsible for the target class
                    classToTrain = targetClass;
                } else {
                    // all other classifiers are trained on the background class
                    classToTrain = backgroundClass;
                }

                result = classifier.trainOnErrorHook(currentDist, features,
                    classToTrain, classifier.getAllClasses(), currentContext)
                    || result;
            }
        }
        return result;
    }

    /***
     * {@inheritDoc}
     * This implementation resets all inner classifiers.
     */
    public void reset() throws ProcessingException {
        final Iterator innerIter = innerClassifiers.values().iterator();
        TrainableClassifier classifier;

        // reset each of the wrapped classifiers
        while (innerIter.hasNext()) {
            classifier = (TrainableClassifier) innerIter.next();
            classifier.reset();
        }
    }

}