1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package de.fu_berlin.ties.classify;
23
24 import java.io.File;
25 import java.util.ArrayList;
26 import java.util.HashMap;
27 import java.util.Iterator;
28 import java.util.List;
29 import java.util.Map;
30 import java.util.Set;
31 import java.util.SortedSet;
32 import java.util.TreeSet;
33
34 import org.apache.commons.lang.builder.ToStringBuilder;
35 import org.dom4j.Element;
36 import org.dom4j.QName;
37 import org.dom4j.tree.DefaultElement;
38
39 import de.fu_berlin.ties.ContextMap;
40 import de.fu_berlin.ties.ProcessingException;
41 import de.fu_berlin.ties.TiesConfiguration;
42 import de.fu_berlin.ties.classify.feature.FeatureTransformer;
43 import de.fu_berlin.ties.classify.feature.FeatureVector;
44 import de.fu_berlin.ties.io.ObjectElement;
45 import de.fu_berlin.ties.xml.dom.DOMUtils;
46
47 /***
48 * This classifier converts an multi-class classification task into a several
49 * binary (two-class) classification task. It wraps several instances of
50 * another classifier that are used to perform the binary classifications and
51 * combines their results.
52 *
53 * <p>The first class from the set of classes passed to the constructor is
54 * considered as the "background" class, while all further members are
55 * considered as "foreground" classes.
56 *
57 * <p>Instances of this class are thread-safe if and only if instances of the
58 * wrapped classifier are.
59 *
60 * <p><em><strong>WARNING:</strong> The current implementation does
61 * <strong>not</strong> query the {@link
62 * de.fu_berlin.ties.classify.TrainableClassifier#shouldTrain(String,
63 * PredictionDistribution, ContextMap) shouldTrain} method of inner classifiers.
64 * Because of this, classifiers overwriting <code>shouldTrain</code> might not
65 * work correctly within this classifier.</em>
66 *
67 * @author Christian Siefkes
68 * @version $Revision: 1.26 $, $Date: 2006/10/21 16:03:54 $, $Author: siefkes $
69 */
70 public class MultiBinaryClassifier extends TrainableClassifier {
71
72 /***
73 * Element name used for XML serialization.
74 */
75 static final QName ELEMENT_INNER = DOMUtils.defaultName("inner");
76
77 /***
78 * Attribute name used for XML serialization.
79 */
80 static final QName ATTRIB_FOR = DOMUtils.defaultName("for");
81
82 /***
83 * Attribute name used for XML serialization.
84 */
85 private static final QName ATTRIB_BACKGROUND =
86 DOMUtils.defaultName("background-class");
87
88 /***
89 * Key prefix used to store the contexts of inner classifiers.
90 */
91 private static final String PREFIX_CONTEXT = "context-";
92
93 /***
94 * Key prefix used to store the prediction distributions returned by inner
95 * classifiers.
96 */
97 private static final String PREFIX_DIST = "dist-";
98
99
100 /***
101 * The "background" class of this classifier.
102 */
103 private final String backgroundClass;
104
105 /***
106 * Maps from the names of the foreground classes to the binary classifiers
107 * used to decide between this foreground class and the background class.
108 */
109 private final Map<String, TrainableClassifier> innerClassifiers =
110 new HashMap<String, TrainableClassifier>();
111
112
113 /***
114 * Creates a new instance from an XML element, fulfilling the
115 * recommandation of the {@link de.fu_berlin.ties.io.XMLStorable} interface.
116 *
117 * @param element the XML element containing the serialized representation
118 * @throws InstantiationException if the given element does not contain
119 * a valid classifier description
120 */
121 public MultiBinaryClassifier(final Element element)
122 throws InstantiationException {
123
124 super(element);
125 backgroundClass = element.attributeValue(ATTRIB_BACKGROUND);
126
127
128 final Iterator innerIter = element.elementIterator(ELEMENT_INNER);
129 Element innerElement;
130
131 while (innerIter.hasNext()) {
132 innerElement = (Element) innerIter.next();
133 innerClassifiers.put(innerElement.attributeValue(ATTRIB_FOR),
134 (TrainableClassifier) ObjectElement.createObject(
135 innerElement.element(TrainableClassifier.ELEMENT_MAIN)));
136 }
137
138
139 if (!(getAllClasses().size() - 1 == innerClassifiers.size())) {
140 throw new InstantiationException("Serialization error: Found "
141 + innerClassifiers.size()
142 + " inner classifiers but there are "
143 + (getAllClasses().size() - 1) + " foreground classes");
144 }
145
146 final Iterator classIter = getAllClasses().iterator();
147 String className;
148
149 while (classIter.hasNext()) {
150 className = (String) classIter.next();
151 if (!(innerClassifiers.containsKey(className)
152 || backgroundClass.equals(className))) {
153 throw new InstantiationException("Serialization error: "
154 + "No inner classifier exists for the foreground class "
155 + className);
156 }
157 }
158 }
159
160 /***
161 * Creates a new instance.
162 *
163 * @param allValidClasses the set of all valid classes; the first member
164 * of this set is considered as the "background" class, all further members
165 * are considered as "foreground" classes
166 * @param trans the last transformer in the transformer chain to use, or
167 * <code>null</code> if no feature transformers should be used
168 * @param runDirectory optional run directory passed to inner classifiers
169 * of the {@link ExternalClassifier} type
170 * @param innerSpec the specification used to initialize the inner
171 * classifiers, passed to the
172 * {@link TrainableClassifier#createClassifier(Set, File,
173 * FeatureTransformer, String[], TiesConfiguration)} factory method
174 * @param conf used to configure this instance and the inner classifiers
175 * @throws IllegalArgumentException if there are fewer than three classes
176 * (in which case you should use the inner classifier directly since there
177 * is no need to wrap several instances)
178 * @throws ProcessingException if an error occurred while creating this
179 * classifier or one of the wrapped classifiers
180 */
181 public MultiBinaryClassifier(final Set<String> allValidClasses,
182 final FeatureTransformer trans,
183 final File runDirectory,
184 final String[] innerSpec,
185 final TiesConfiguration conf)
186 throws IllegalArgumentException, ProcessingException {
187 super(allValidClasses, trans, conf);
188
189 if (allValidClasses.size() < 3) {
190 throw new IllegalArgumentException(
191 "MultiBinaryClassifier requires at least 3 classes instead of "
192 + allValidClasses.size());
193 }
194
195
196 final Iterator<String> classIter = allValidClasses.iterator();
197 backgroundClass = classIter.next();
198
199
200 String foregroundClass;
201 TrainableClassifier innerClassifier;
202 Set<String> innerSet;
203
204 while (classIter.hasNext()) {
205
206
207 foregroundClass = classIter.next();
208 innerSet = createBinarySet(foregroundClass);
209
210
211
212 innerClassifier = TrainableClassifier.createClassifier(innerSet,
213 runDirectory, null, innerSpec, conf);
214 innerClassifiers.put(foregroundClass, innerClassifier);
215 }
216 }
217
218 /***
219 * Helper method that creates a set containing the two classes of a
220 * binary classifier.
221 *
222 * @param foregroundClass the "foreground" class to use
223 * @return a set containing the {@linkplain #getBackgroundClass()
224 * "background" class} and the specified <code>foregroundClass</code>;
225 * this implementation returns the two classes in alphabetic order
226 */
227 protected Set<String> createBinarySet(final String foregroundClass) {
228 final SortedSet<String> result = new TreeSet<String>();
229 result.add(backgroundClass);
230 result.add(foregroundClass);
231 return result;
232 }
233
234 /***
235 * {@inheritDoc}
236 */
237 public void destroy() throws ProcessingException {
238 final Iterator innerIter = innerClassifiers.values().iterator();
239 TrainableClassifier classifier;
240
241
242 while (innerIter.hasNext()) {
243 classifier = (TrainableClassifier) innerIter.next();
244 classifier.destroy();
245 }
246 }
247
248 /***
249 * {@inheritDoc}
250 * This implementation combines the predictions for the foreground
251 * of all involved inner classifiers.
252 *
253 * <p>If the {@linkplain #getBackgroundClass() background class} is part of
254 * the <code>candidateClasses</code>, the classifier whose background
255 * probability is closest to 0.5 determines the probability of the
256 * background class. In this way all classes will be sorted the right way
257 * (foreground classes with a higher than 0.5 before the background class,
258 * those with a lower probability after it). pR values are <em>not</em>
259 * considered for this purpose, so you should be careful when combination
260 * classifiers that mainly rely on the pR value in a multi-binary setup.
261 *
262 * <p><em>The probability estimates returned by each classifier are used
263 * "as is", so the result will <strong>not</strong> be a real probability
264 * distribution because sum of all probabilities will be more than 1.
265 * If you want to work on a real probability distribution you have
266 * normalize it yourself.</em>
267 */
268 protected PredictionDistribution doClassify(final FeatureVector features,
269 final Set candidateClasses,
270 final ContextMap context)
271 throws ProcessingException {
272 final List<Prediction> unnormalizedPreds =
273 new ArrayList<Prediction>(candidateClasses.size());
274 final Iterator classIter = candidateClasses.iterator();
275 String currentClass;
276 TrainableClassifier innerClassifier;
277 PredictionDistribution currentDist;
278 Prediction currentPred;
279 ContextMap currentContext;
280 Iterator predIter;
281
282
283
284 Prediction backgroundMostNearToFifty = null;
285 double lowestBackgroundDist = Double.MAX_VALUE;
286 double currentBackgroundDist;
287
288
289 while (classIter.hasNext()) {
290 currentClass = (String) classIter.next();
291
292 if (!backgroundClass.equals(currentClass)) {
293
294 currentContext = new ContextMap();
295 innerClassifier = innerClassifiers.get(currentClass);
296 currentDist = innerClassifier.doClassify(features,
297 innerClassifier.getAllClasses(), currentContext);
298
299
300
301 context.put(PREFIX_CONTEXT + currentClass, currentContext);
302 context.put(PREFIX_DIST + currentClass, currentDist);
303
304 predIter = currentDist.iterator();
305
306
307 while (predIter.hasNext()) {
308 currentPred = (Prediction) predIter.next();
309
310
311 if (backgroundClass.equals(currentPred.getType())) {
312
313 currentBackgroundDist = Math.abs(
314 currentPred.getProbability().getProb() - 0.5);
315
316 if (currentBackgroundDist < lowestBackgroundDist) {
317
318
319
320
321
322
323
324
325
326 backgroundMostNearToFifty = currentPred;
327 lowestBackgroundDist = currentBackgroundDist;
328 }
329 } else {
330
331 unnormalizedPreds.add(currentPred);
332
333 }
334 }
335 }
336 }
337
338 if ((backgroundMostNearToFifty != null)
339 && candidateClasses.contains(backgroundClass)) {
340
341
342 unnormalizedPreds.add(backgroundMostNearToFifty);
343
344 }
345
346 final PredictionDistribution result = new PredictionDistribution();
347 predIter = unnormalizedPreds.iterator();
348
349
350 while (predIter.hasNext()) {
351 currentPred = (Prediction) predIter.next();
352 result.add(currentPred);
353
354
355
356
357
358
359 }
360
361
362
363
364
365 return result;
366 }
367
368 /***
369 * {@inheritDoc}
370 */
371 protected void doTrain(final FeatureVector features,
372 final String targetClass,
373 final ContextMap context)
374 throws ProcessingException {
375 final Iterator classIter = innerClassifiers.keySet().iterator();
376 String currentClass;
377 TrainableClassifier classifier;
378 String classToTrain;
379 String contextKey;
380 ContextMap currentContext;
381
382
383 while (classIter.hasNext()) {
384 currentClass = (String) classIter.next();
385 classifier = innerClassifiers.get(currentClass);
386
387 if (currentClass.equals(targetClass)) {
388
389 classToTrain = targetClass;
390 } else {
391
392 classToTrain = backgroundClass;
393 }
394
395
396 contextKey = PREFIX_CONTEXT + currentClass;
397 if (context.containsKey(contextKey)) {
398 currentContext = (ContextMap) context.get(contextKey);
399 } else {
400 currentContext = new ContextMap();
401 }
402
403 classifier.doTrain(features, classToTrain, currentContext);
404 }
405 }
406
407 /***
408 * Returns the "background" class of this classifier.
409 * @return the value of the attribute
410 */
411 public String getBackgroundClass() {
412 return backgroundClass;
413 }
414
415 /***
416 * {@inheritDoc}
417 */
418 public ObjectElement toElement() {
419
420 final ObjectElement result = super.toElement();
421 result.addAttribute(ATTRIB_BACKGROUND, backgroundClass);
422
423
424 Element innerElement;
425 Map.Entry<String, TrainableClassifier> inner;
426
427 for (Iterator<Map.Entry<String, TrainableClassifier>> iter =
428 innerClassifiers.entrySet().iterator();
429 iter.hasNext();) {
430 inner = iter.next();
431 innerElement = new DefaultElement(ELEMENT_INNER);
432 result.add(innerElement);
433 innerElement.addAttribute(ATTRIB_FOR, inner.getKey());
434 innerElement.add(inner.getValue().toElement());
435 }
436
437 return result;
438 }
439
440 /***
441 * Returns a string representation of this object.
442 *
443 * @return a textual representation
444 */
445 public String toString() {
446 return new ToStringBuilder(this)
447 .appendSuper(super.toString())
448 .append("background class", backgroundClass)
449 .append("inner classifiers", innerClassifiers.values())
450 .toString();
451 }
452
453 /***
454 * {@inheritDoc}
455 */
456 protected boolean trainOnErrorHook(final PredictionDistribution predDist,
457 final FeatureVector features, final String targetClass,
458 final Set candidateClasses, final ContextMap context)
459 throws ProcessingException {
460 boolean result = false;
461 final Iterator predIter = predDist.iterator();
462 Prediction currentPred;
463 String currentClass;
464 TrainableClassifier classifier;
465 ContextMap currentContext;
466 PredictionDistribution currentDist;
467 String classToTrain;
468
469
470
471 while (predIter.hasNext()) {
472 currentPred = (Prediction) predIter.next();
473 currentClass = currentPred.getType();
474
475 if (!backgroundClass.equals(currentClass)) {
476 classifier = innerClassifiers.get(currentClass);
477
478
479 currentContext = (ContextMap)
480 context.get(PREFIX_CONTEXT + currentClass);
481 currentDist = (PredictionDistribution)
482 context.get(PREFIX_DIST + currentClass);
483
484 if (currentClass.equals(targetClass)) {
485
486 classToTrain = targetClass;
487 } else {
488
489 classToTrain = backgroundClass;
490 }
491
492 result = classifier.trainOnErrorHook(currentDist, features,
493 classToTrain, classifier.getAllClasses(), currentContext)
494 || result;
495 }
496 }
497 return result;
498 }
499
500 /***
501 * {@inheritDoc}
502 */
503 public void reset() throws ProcessingException {
504 final Iterator innerIter = innerClassifiers.values().iterator();
505 TrainableClassifier classifier;
506
507
508 while (innerIter.hasNext()) {
509 classifier = (TrainableClassifier) innerIter.next();
510 classifier.reset();
511 }
512 }
513
514 }