1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package de.fu_berlin.ties.classify;
23
24 import java.io.File;
25 import java.util.ArrayList;
26 import java.util.HashMap;
27 import java.util.Iterator;
28 import java.util.List;
29 import java.util.Map;
30 import java.util.Set;
31 import java.util.SortedSet;
32 import java.util.TreeSet;
33
34 import org.apache.commons.lang.builder.ToStringBuilder;
35
36 import de.fu_berlin.ties.ContextMap;
37 import de.fu_berlin.ties.ProcessingException;
38 import de.fu_berlin.ties.TiesConfiguration;
39 import de.fu_berlin.ties.classify.feature.FeatureTransformer;
40 import de.fu_berlin.ties.classify.feature.FeatureVector;
41
42 /***
 * This classifier converts a multi-class classification task into several
 * binary (two-class) classification tasks. It wraps several instances of
 * another classifier that are used to perform the binary classifications and
 * combines their results.
47 *
 * <p>When used with <em>N</em> classes, this classifier creates and wraps
 * <em>N</em> inner binary classifiers. Each inner classifier decides
 * between members (positive instances) and non-members (negative instances) of
 * a specified class. The inner classifier returning the highest
 * {@linkplain de.fu_berlin.ties.classify.Prediction#getProbability()
 * probability} for its (positive) class wins, even if the negative class of
 * that classifier would still be higher. This is the classical
 * "one-against-the-rest" (or "one-against-all") setup.
56 *
57 * <p>This classifier differs from
58 * {@link de.fu_berlin.ties.classify.MultiBinaryClassifier} in not using a
59 * "background class" and using <em>N</em> inner classifiers, while
60 * <code>MultiBinaryClassifier</code> uses <em>N-1</em>.
61 *
62 * <p>Instances of this class are thread-safe if and only if instances of the
63 * wrapped classifier are.
64 *
65 * <p><em><strong>WARNING:</strong> The current implementation does
66 * <strong>not</strong> query the {@link
67 * de.fu_berlin.ties.classify.TrainableClassifier#shouldTrain(String,
68 * PredictionDistribution, ContextMap) shouldTrain} method of inner classifiers.
69 * Because of this, classifiers overwriting <code>shouldTrain</code> might not
70 * work correctly within this classifier.</em>
71 *
72 * @author Christian Siefkes
73 * @version $Revision: 1.2 $, $Date: 2004/11/19 14:04:19 $, $Author: siefkes $
74 */
75 public class OneAgainstTheRestClassifier extends TrainableClassifier {
76
77 /***
78 * Key prefix used to store the contexts of inner classifiers.
79 */
80 private static final String PREFIX_CONTEXT = "context-";
81
82 /***
83 * Key prefix used to store the prediction distributions returned by inner
84 * classifiers.
85 */
86 private static final String PREFIX_DIST = "dist-";
87
88 /***
89 * The suffix appended to negative classes.
90 */
91 private static final String NEG_SUFFIX = "-neg";
92
93 /***
94 * Maps from the names of the classes to the binary classifiers
95 * used to decide between positive and negative instances of this class.
96 */
97 private final Map<String, TrainableClassifier> innerClassifiers =
98 new HashMap<String, TrainableClassifier>();
99
100 /***
101 * Creates a new instance.
102 *
103 * @param allValidClasses the set of all valid classes
104 * @param trans the last transformer in the transformer chain to use, or
105 * <code>null</code> if no feature transformers should be used
106 * @param runDirectory optional run directory passed to inner classifiers
107 * of the {@link ExternalClassifier} type
108 * @param innerSpec the specification used to initialize the inner
109 * classifiers, passed to the
110 * {@link TrainableClassifier#createClassifier(Set, File,
111 * FeatureTransformer, String[], TiesConfiguration)} factory method
112 * @param conf used to configure this instance and the inner classifiers
113 * @throws IllegalArgumentException if there are fewer than three classes
114 * (in which case you should use the inner classifier directly since there
115 * is no need to wrap several instances)
116 * @throws ProcessingException if an error occurred while creating this
117 * classifier or one of the wrapped classifiers
118 */
119 public OneAgainstTheRestClassifier(final Set<String> allValidClasses,
120 final FeatureTransformer trans,
121 final File runDirectory,
122 final String[] innerSpec,
123 final TiesConfiguration conf)
124 throws IllegalArgumentException, ProcessingException {
125 super(allValidClasses, trans, conf);
126
127 if (allValidClasses.size() < 3) {
128 throw new IllegalArgumentException("OneAgainstTheRestClassifier "
129 + "requires at least 3 classes instead of "
130 + allValidClasses.size());
131 }
132
133
134 final Iterator<String> classIter= allValidClasses.iterator();
135 String baseClass;
136 TrainableClassifier innerClassifier;
137 Set<String> innerSet;
138
139 while (classIter.hasNext()) {
140
141
142 baseClass = classIter.next();
143 innerSet = createBinarySet(baseClass);
144
145
146
147 innerClassifier = TrainableClassifier.createClassifier(innerSet,
148 runDirectory, null, innerSpec, conf);
149 innerClassifiers.put(baseClass, innerClassifier);
150 }
151 }
152
153 /***
154 * Helper method that creates a set containing the two classes of a
155 * binary classifier.
156 *
157 * @param baseClass the base class to use
158 * @return a set containing a positive and a negative class derived from the
159 * base class; this implementation returns the two classes in alphabetic
160 * order which means that the negative class comes first
161 */
162 protected Set<String> createBinarySet(final String baseClass) {
163 final SortedSet<String> result = new TreeSet<String>();
164 result.add(baseClass);
165 result.add(baseClass + NEG_SUFFIX);
166 return result;
167 }
168
169 /***
170 * {@inheritDoc}
171 * This implementation combines the predictions for the positive class
172 * of all involved inner classifiers.
173 *
174 * <p><em>The probability estimates returned by each classifier are used
175 * "as is", so the result will <strong>not</strong> be a real probability
176 * distribution because sum of all probabilities will be more than 1.
177 * If you want to work on a real probability distribution you have
178 * normalize it yourself.</em>
179 */
180 protected PredictionDistribution doClassify(final FeatureVector features,
181 final Set candidateClasses,
182 final ContextMap context)
183 throws ProcessingException {
184 final List<Prediction> unnormalizedPreds =
185 new ArrayList<Prediction>(candidateClasses.size());
186 final Iterator classIter = candidateClasses.iterator();
187 String currentClass;
188 TrainableClassifier innerClassifier;
189 PredictionDistribution currentDist;
190 Prediction currentPred;
191 String predClass;
192 ContextMap currentContext;
193 Iterator predIter;
194
195
196
197 while (classIter.hasNext()) {
198 currentClass = (String) classIter.next();
199
200
201 currentContext = new ContextMap();
202 innerClassifier = innerClassifiers.get(currentClass);
203 currentDist = innerClassifier.doClassify(features,
204 innerClassifier.getAllClasses(), currentContext);
205
206
207
208 context.put(PREFIX_CONTEXT + currentClass, currentContext);
209 context.put(PREFIX_DIST + currentClass, currentDist);
210
211 predIter = currentDist.iterator();
212
213
214 while (predIter.hasNext()) {
215 currentPred = (Prediction) predIter.next();
216 predClass = currentPred.getType();
217
218 if (predClass.equals(currentClass)) {
219
220 unnormalizedPreds.add(currentPred);
221
222 } else if (!predClass.equals(currentClass + NEG_SUFFIX)) {
223
224 throw new RuntimeException("Implementation error: "
225 + "unexpected class name " + predClass + " used by "
226 + currentClass + " classifier");
227 }
228 }
229 }
230
231 final PredictionDistribution result = new PredictionDistribution();
232 predIter = unnormalizedPreds.iterator();
233
234
235 while (predIter.hasNext()) {
236 currentPred = (Prediction) predIter.next();
237 result.add(currentPred);
238
239
240
241
242
243
244 }
245
246
247
248
249
250 return result;
251 }
252
253 /***
254 * {@inheritDoc}
255 */
256 protected void doTrain(final FeatureVector features,
257 final String targetClass,
258 final ContextMap context)
259 throws ProcessingException {
260 final Iterator classIter = innerClassifiers.keySet().iterator();
261 String currentClass;
262 TrainableClassifier classifier;
263 String classToTrain;
264 String contextKey;
265 ContextMap currentContext;
266
267
268 while (classIter.hasNext()) {
269 currentClass = (String) classIter.next();
270 classifier = innerClassifiers.get(currentClass);
271
272 if (currentClass.equals(targetClass)) {
273
274 classToTrain = targetClass;
275 } else {
276
277 classToTrain = currentClass + NEG_SUFFIX;
278 }
279
280
281 contextKey = PREFIX_CONTEXT + currentClass;
282 if (context.containsKey(contextKey)) {
283 currentContext = (ContextMap) context.get(contextKey);
284 } else {
285 currentContext = new ContextMap();
286 }
287
288 classifier.doTrain(features, classToTrain, currentContext);
289 }
290 }
291
292 /***
293 * Returns a string representation of this object.
294 *
295 * @return a textual representation
296 */
297 public String toString() {
298 return new ToStringBuilder(this)
299 .appendSuper(super.toString())
300 .append("inner classifiers", innerClassifiers.values())
301 .toString();
302 }
303
304 /***
305 * {@inheritDoc}
306 */
307 protected boolean trainOnErrorHook(final PredictionDistribution predDist,
308 final FeatureVector features, final String targetClass,
309 final Set candidateClasses, final ContextMap context)
310 throws ProcessingException {
311 boolean result = false;
312 final Iterator predIter = predDist.iterator();
313 Prediction currentPred;
314 String currentClass;
315 TrainableClassifier classifier;
316 ContextMap currentContext;
317 PredictionDistribution currentDist;
318 String classToTrain;
319
320
321
322 while (predIter.hasNext()) {
323 currentPred = (Prediction) predIter.next();
324 currentClass = currentPred.getType();
325 classifier = innerClassifiers.get(currentClass);
326
327
328 currentContext = (ContextMap)
329 context.get(PREFIX_CONTEXT + currentClass);
330 currentDist = (PredictionDistribution)
331 context.get(PREFIX_DIST + currentClass);
332
333 if (currentClass.equals(targetClass)) {
334
335 classToTrain = targetClass;
336 } else {
337
338 classToTrain = currentClass + NEG_SUFFIX;
339 }
340
341 result = classifier.trainOnErrorHook(currentDist, features,
342 classToTrain, classifier.getAllClasses(), currentContext)
343 || result;
344 }
345 return result;
346 }
347
348 /***
349 * {@inheritDoc}
350 */
351 public void reset() throws ProcessingException {
352 final Iterator innerIter = innerClassifiers.values().iterator();
353 TrainableClassifier classifier;
354
355
356 while (innerIter.hasNext()) {
357 classifier = (TrainableClassifier) innerIter.next();
358 classifier.reset();
359 }
360 }
361
362 }