package de.fu_berlin.ties.classify;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

import org.apache.commons.lang.builder.ToStringBuilder;
import org.dom4j.Element;
import org.dom4j.tree.DefaultElement;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.feature.FeatureTransformer;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.io.ObjectElement;
/***
 * This classifier converts a multi-class classification task into several
 * binary (two-class) classification tasks. It wraps several instances of
 * another classifier that are used to perform the binary classifications and
 * combines their results.
 *
 * <p>When used with <em>N</em> classes, this classifier creates and wraps
 * <em>N</em> inner binary classifiers. Each inner classifier decides
 * between members (positive instances) and non-members (negative instances) of
 * a specified class. The inner classifier returning the highest
 * {@linkplain de.fu_berlin.ties.classify.Prediction#getProbability()
 * probability} for its (positive) class wins, even if the probability of that
 * classifier's negative class is still higher. This is the classical
 * "one-against-the-rest" (or "one-against-all") setup.
 *
 * <p>This classifier differs from
 * {@link de.fu_berlin.ties.classify.MultiBinaryClassifier} in that it does not
 * use a "background class" and uses <em>N</em> inner classifiers, while
 * <code>MultiBinaryClassifier</code> uses <em>N-1</em>.
 *
 * <p>Instances of this class are thread-safe if and only if instances of the
 * wrapped classifier are.
 *
 * <p><em><strong>WARNING:</strong> The current implementation does
 * <strong>not</strong> query the {@link
 * de.fu_berlin.ties.classify.TrainableClassifier#shouldTrain(String,
 * PredictionDistribution, ContextMap) shouldTrain} method of inner classifiers.
 * Because of this, classifiers overriding <code>shouldTrain</code> might not
 * work correctly within this classifier.</em>
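 *
 * <p>Construction sketch, assuming a suitable {@link TiesConfiguration}
 * instance <code>conf</code> and an inner classifier specification
 * <code>spec</code> accepted by the
 * {@link TrainableClassifier#createClassifier(Set, File, FeatureTransformer,
 * String[], TiesConfiguration)} factory method; the class names are only
 * examples. One inner binary classifier is created per class:
 *
 * <pre>
 * Set&lt;String&gt; classes = new TreeSet&lt;String&gt;();
 * classes.add("person");
 * classes.add("location");
 * classes.add("other");
 *
 * // no feature transformer, no run directory
 * TrainableClassifier classifier = new OneAgainstTheRestClassifier(
 *     classes, null, null, spec, conf);
 * </pre>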
 *
 * @author Christian Siefkes
 * @version $Revision: 1.9 $, $Date: 2006/10/21 16:03:54 $, $Author: siefkes $
 */
public class OneAgainstTheRestClassifier extends TrainableClassifier {

    /***
     * Key prefix used to store the contexts of inner classifiers.
     */
    private static final String PREFIX_CONTEXT = "context-";

    /***
     * Key prefix used to store the prediction distributions returned by inner
     * classifiers.
     */
    private static final String PREFIX_DIST = "dist-";

    /***
     * The suffix appended to negative classes.
     */
    private static final String NEG_SUFFIX = "-neg";

    /***
     * Maps from the names of the classes to the binary classifiers
     * used to decide between positive and negative instances of the
     * respective class.
     */
    private final Map<String, TrainableClassifier> innerClassifiers =
        new HashMap<String, TrainableClassifier>();

    /***
     * Creates a new instance from an XML element, fulfilling the
     * recommendation of the {@link de.fu_berlin.ties.io.XMLStorable}
     * interface.
     *
     * @param element the XML element containing the serialized representation
     * @throws InstantiationException if the given element does not contain
     * a valid classifier description
     */
    public OneAgainstTheRestClassifier(final Element element)
            throws InstantiationException {
        super(element);

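        // recreate the inner classifiers from the serialized inner elements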
        final Iterator innerIter =
            element.elementIterator(MultiBinaryClassifier.ELEMENT_INNER);
        Element innerElement;

        while (innerIter.hasNext()) {
            innerElement = (Element) innerIter.next();
            innerClassifiers.put(
                innerElement.attributeValue(MultiBinaryClassifier.ATTRIB_FOR),
                (TrainableClassifier) ObjectElement.createObject(
                    innerElement.element(TrainableClassifier.ELEMENT_MAIN)));
        }

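        // sanity check: there must be exactly one inner classifier per class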
        if (getAllClasses().size() != innerClassifiers.size()) {
            throw new InstantiationException("Serialization error: Found "
                + innerClassifiers.size()
                + " inner classifiers but there are "
                + getAllClasses().size() + " classes");
        }

        final Iterator classIter = getAllClasses().iterator();
        String className;

        while (classIter.hasNext()) {
            className = (String) classIter.next();
            if (!innerClassifiers.containsKey(className)) {
                throw new InstantiationException("Serialization error: "
                    + "No inner classifier exists for the class " + className);
            }
        }
    }

    /***
     * Creates a new instance.
     *
     * @param allValidClasses the set of all valid classes
     * @param trans the last transformer in the transformer chain to use, or
     * <code>null</code> if no feature transformers should be used
     * @param runDirectory optional run directory passed to inner classifiers
     * of the {@link ExternalClassifier} type
     * @param innerSpec the specification used to initialize the inner
     * classifiers, passed to the
     * {@link TrainableClassifier#createClassifier(Set, File,
     * FeatureTransformer, String[], TiesConfiguration)} factory method
     * @param conf used to configure this instance and the inner classifiers
     * @throws IllegalArgumentException if there are fewer than three classes
     * (in which case you should use the inner classifier directly since there
     * is no need to wrap several instances)
     * @throws ProcessingException if an error occurred while creating this
     * classifier or one of the wrapped classifiers
     */
    public OneAgainstTheRestClassifier(final Set<String> allValidClasses,
            final FeatureTransformer trans, final File runDirectory,
            final String[] innerSpec, final TiesConfiguration conf)
            throws IllegalArgumentException, ProcessingException {
        super(allValidClasses, trans, conf);

        if (allValidClasses.size() < 3) {
            throw new IllegalArgumentException("OneAgainstTheRestClassifier "
                + "requires at least 3 classes instead of "
                + allValidClasses.size());
        }

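        // create one binary (positive vs. negative) inner classifier per class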
        final Iterator<String> classIter = allValidClasses.iterator();
        String baseClass;
        TrainableClassifier innerClassifier;
        Set<String> innerSet;

        while (classIter.hasNext()) {
            baseClass = classIter.next();
            innerSet = createBinarySet(baseClass);

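            // create the inner classifier via the factory method; no separate
            // feature transformer is passed to it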
            innerClassifier = TrainableClassifier.createClassifier(innerSet,
                runDirectory, null, innerSpec, conf);
            innerClassifiers.put(baseClass, innerClassifier);
        }
    }

    /***
     * Helper method that creates a set containing the two classes of a
     * binary classifier.
     *
     * @param baseClass the base class to use
     * @return a set containing a positive and a negative class derived from
     * the base class; this implementation returns the two classes in
     * alphabetical order, which means that the positive (base) class comes
     * first
     */
    protected Set<String> createBinarySet(final String baseClass) {
        final SortedSet<String> result = new TreeSet<String>();
        result.add(baseClass);
        result.add(baseClass + NEG_SUFFIX);
        return result;
    }

    /***
     * {@inheritDoc}
     */
    public void destroy() throws ProcessingException {
        final Iterator innerIter = innerClassifiers.values().iterator();
        TrainableClassifier classifier;

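        // destroy each wrapped inner classifier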
        while (innerIter.hasNext()) {
            classifier = (TrainableClassifier) innerIter.next();
            classifier.destroy();
        }
    }

    /***
     * {@inheritDoc}
     * This implementation combines the predictions for the positive class
     * of all involved inner classifiers.
     *
     * <p><em>The probability estimates returned by each classifier are used
     * "as is", so the result will <strong>not</strong> be a real probability
     * distribution because the probabilities will generally not sum to 1.
     * If you want to work on a real probability distribution, you have to
     * normalize it yourself.</em>
     */
    protected PredictionDistribution doClassify(final FeatureVector features,
            final Set candidateClasses, final ContextMap context)
            throws ProcessingException {
        final List<Prediction> unnormalizedPreds =
            new ArrayList<Prediction>(candidateClasses.size());
        final Iterator classIter = candidateClasses.iterator();
        String currentClass;
        TrainableClassifier innerClassifier;
        PredictionDistribution currentDist;
        Prediction currentPred;
        String predClass;
        ContextMap currentContext;
        Iterator predIter;

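        // query each candidate class's inner classifier in turn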
        while (classIter.hasNext()) {
            currentClass = (String) classIter.next();

            currentContext = new ContextMap();
            innerClassifier = innerClassifiers.get(currentClass);
            currentDist = innerClassifier.doClassify(features,
                innerClassifier.getAllClasses(), currentContext);

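            // store the context and distribution for later use during training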
            context.put(PREFIX_CONTEXT + currentClass, currentContext);
            context.put(PREFIX_DIST + currentClass, currentDist);

            predIter = currentDist.iterator();

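            // keep only the prediction for the positive class of this
            // classifier, ignoring the one for its negative class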
            while (predIter.hasNext()) {
                currentPred = (Prediction) predIter.next();
                predClass = currentPred.getType();

                if (predClass.equals(currentClass)) {
                    unnormalizedPreds.add(currentPred);
                } else if (!predClass.equals(currentClass + NEG_SUFFIX)) {
                    throw new RuntimeException("Implementation error: "
                        + "unexpected class name " + predClass + " used by "
                        + currentClass + " classifier");
                }
            }
        }

        final PredictionDistribution result = new PredictionDistribution();
        predIter = unnormalizedPreds.iterator();

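        // copy the collected predictions into the result distribution;
        // the probabilities are left unnormalized (see above)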
        while (predIter.hasNext()) {
            currentPred = (Prediction) predIter.next();
            result.add(currentPred);
        }

        return result;
    }

    /***
     * {@inheritDoc}
     */
    protected void doTrain(final FeatureVector features,
            final String targetClass, final ContextMap context)
            throws ProcessingException {
        final Iterator classIter = innerClassifiers.keySet().iterator();
        String currentClass;
        TrainableClassifier classifier;
        String classToTrain;
        String contextKey;
        ContextMap currentContext;

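        // train each inner classifier: the one for the target class on its
        // positive class, all others on their negative class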
        while (classIter.hasNext()) {
            currentClass = (String) classIter.next();
            classifier = innerClassifiers.get(currentClass);

            if (currentClass.equals(targetClass)) {
                classToTrain = targetClass;
            } else {
                classToTrain = currentClass + NEG_SUFFIX;
            }

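            // re-use the context stored during classification, if available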
            contextKey = PREFIX_CONTEXT + currentClass;
            if (context.containsKey(contextKey)) {
                currentContext = (ContextMap) context.get(contextKey);
            } else {
                currentContext = new ContextMap();
            }

            classifier.doTrain(features, classToTrain, currentContext);
        }
    }

    /***
     * {@inheritDoc}
     */
    public void reset() throws ProcessingException {
        final Iterator innerIter = innerClassifiers.values().iterator();
        TrainableClassifier classifier;

        while (innerIter.hasNext()) {
            classifier = (TrainableClassifier) innerIter.next();
            classifier.reset();
        }
    }

    /***
     * {@inheritDoc}
     */
    public ObjectElement toElement() {
        final ObjectElement result = super.toElement();

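        // append one inner element per class, carrying the serialized
        // inner classifier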
        Element innerElement;
        Map.Entry<String, TrainableClassifier> inner;

        for (Iterator<Map.Entry<String, TrainableClassifier>> iter =
                innerClassifiers.entrySet().iterator(); iter.hasNext();) {
            inner = iter.next();
            innerElement =
                new DefaultElement(MultiBinaryClassifier.ELEMENT_INNER);
            result.add(innerElement);
            innerElement.addAttribute(MultiBinaryClassifier.ATTRIB_FOR,
                inner.getKey());
            innerElement.add(inner.getValue().toElement());
        }

        return result;
    }

    /***
     * Returns a string representation of this object.
     *
     * @return a textual representation
     */
    public String toString() {
        return new ToStringBuilder(this)
            .appendSuper(super.toString())
            .append("inner classifiers", innerClassifiers.values())
            .toString();
    }

    /***
     * {@inheritDoc}
     */
    protected boolean trainOnErrorHook(final PredictionDistribution predDist,
            final FeatureVector features, final String targetClass,
            final Set candidateClasses, final ContextMap context)
            throws ProcessingException {
        boolean result = false;
        final Iterator predIter = predDist.iterator();
        Prediction currentPred;
        String currentClass;
        TrainableClassifier classifier;
        ContextMap currentContext;
        PredictionDistribution currentDist;
        String classToTrain;

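        // delegate error-driven training to each inner classifier, re-using
        // the contexts and distributions stored during classification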
        while (predIter.hasNext()) {
            currentPred = (Prediction) predIter.next();
            currentClass = currentPred.getType();
            classifier = innerClassifiers.get(currentClass);

            currentContext = (ContextMap)
                context.get(PREFIX_CONTEXT + currentClass);
            currentDist = (PredictionDistribution)
                context.get(PREFIX_DIST + currentClass);

            if (currentClass.equals(targetClass)) {
                classToTrain = targetClass;
            } else {
                classToTrain = currentClass + NEG_SUFFIX;
            }

            result = classifier.trainOnErrorHook(currentDist, features,
                classToTrain, classifier.getAllClasses(), currentContext)
                || result;
        }
        return result;
    }

}