1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package de.fu_berlin.ties.classify;
23
24 import java.io.File;
25 import java.util.Iterator;
26 import java.util.List;
27 import java.util.Set;
28
29 import org.apache.commons.lang.ArrayUtils;
30 import org.apache.commons.lang.builder.ToStringBuilder;
31 import org.dom4j.Element;
32 import org.dom4j.QName;
33 import org.dom4j.tree.DefaultElement;
34
35 import de.fu_berlin.ties.ContextMap;
36 import de.fu_berlin.ties.ProcessingException;
37 import de.fu_berlin.ties.TiesConfiguration;
38 import de.fu_berlin.ties.classify.feature.FeatureTransformer;
39 import de.fu_berlin.ties.classify.feature.FeatureVector;
40 import de.fu_berlin.ties.io.ObjectElement;
41 import de.fu_berlin.ties.util.Util;
42 import de.fu_berlin.ties.xml.dom.DOMUtils;
43
44 /***
45 * A tie classifier combines several layers of classifiers. If the probabilities
46 * of the two best predictions of a layer are close to each other, the next
47 * layer is invoked to resolve this "tie".
48 *
 * <p><b>This classifier supports <em>only</em> error-driven training since it
 * is necessary to TOE train each classifier to decide whether to train the next
51 * one. Thus you always have to use the
52 * {@link #trainOnError(FeatureVector, String, Set)} method. Trying to
53 * call the {@link
54 * de.fu_berlin.ties.classify.TrainableClassifier#train(FeatureVector, String)}
55 * method instead will result in an
56 * {@link java.lang.UnsupportedOperationException}.</b>
57 *
58 * <p>Instances of this class are thread-safe if and only if instances of the
59 * wrapped classifier are.
60 *
61 * @author Christian Siefkes
62 * @version $Revision: 1.9 $, $Date: 2006/10/21 16:03:55 $, $Author: siefkes $
63 */
64 public class TieClassifier extends TrainableClassifier {
65
66 /***
67 * Attribute name used for XML serialization.
68 */
69 static final QName ATTRIB_TIE_THRESHOLD =
70 DOMUtils.defaultName("tieThreshold");
71
72 /***
73 * Key used to store the contexts of the inner classifiers.
74 */
75 private static final String KEY_INNER_CONTEXTS = "inner-context";
76
77 /***
78 * Key used to store the prediction distributions returned by the inner
79 * classifiers.
80 */
81 private static final String KEY_INNER_DISTS = "inner-dist";
82
83
84 /***
85 * The array of inner classifiers managed by this instance.
86 */
87 private final TrainableClassifier[] inner;
88
89 /***
 * The next layer is invoked if the relative probability of the second best
 * prediction is above or equal to this threshold (in the 0 to 1 range).
92 */
93 private final double tieThreshold;
94
95
/**
 * Creates a new instance from an XML element, fulfilling the
 * recommendation of the {@link de.fu_berlin.ties.io.XMLStorable} interface.
 *
 * <p>The tie threshold is read from the <code>tieThreshold</code> attribute
 * of the element. The inner classifiers are restored, in document order,
 * from the children of the <code>MultiBinaryClassifier.ELEMENT_INNER</code>
 * child element.
 *
 * @param element the XML element containing the serialized representation
 * @throws InstantiationException if the given element does not contain
 * a valid classifier description; in particular, if the container element
 * holding the inner classifiers is missing or has no children
 * @throws IllegalArgumentException if the stored tie threshold is rejected
 * by the threshold check
 */
public TieClassifier(final Element element) throws InstantiationException {
    super(element);

    // Restore and validate the tie threshold
    final double threshold =
        Util.asDouble(element.attributeValue(ATTRIB_TIE_THRESHOLD));
    checkTieThreshold(threshold);
    tieThreshold = threshold;

    // A missing container element is reported in the same way as an empty
    // one, instead of failing with a NullPointerException
    final Element innerContainer =
        element.element(MultiBinaryClassifier.ELEMENT_INNER);
    final List innerElements =
        (innerContainer != null) ? innerContainer.elements() : null;

    if ((innerElements == null) || innerElements.isEmpty()) {
        throw new InstantiationException(
            "TieClassifier: no inner classifiers found");
    }

    // Deserialize one inner classifier per child element
    inner = new TrainableClassifier[innerElements.size()];
    int layer = 0;

    for (final Object child : innerElements) {
        inner[layer] = (TrainableClassifier) ObjectElement.createObject(
            (Element) child);
        layer++;
    }
}
129
/**
 * Creates a new instance, reading the number of layers and the tie
 * threshold from the given configuration.
 *
 * @param allValidClasses the set of all valid classes
 * @param trans the last transformer in the transformer chain to use, or
 * <code>null</code> if no feature transformers should be used
 * @param runDirectory optional run directory passed to inner classifiers
 * of the {@link ExternalClassifier} type
 * @param innerSpec the specification used to initialize the inner
 * classifiers, passed to the
 * {@link TrainableClassifier#createClassifier(Set, File,
 * FeatureTransformer, String[], TiesConfiguration)} factory method
 * @param conf used to configure this instance and the inner classifiers;
 * the keys <code>classifier.tie.layers</code> and
 * <code>classifier.tie.threshold</code> are read from it
 * @throws ProcessingException if an error occurred while creating this
 * classifier or one of the wrapped classifiers
 */
public TieClassifier(final Set<String> allValidClasses,
        final FeatureTransformer trans, final File runDirectory,
        final String[] innerSpec, final TiesConfiguration conf)
throws ProcessingException {
    // Delegate to the full constructor with the configured settings
    this(allValidClasses, trans, runDirectory, innerSpec,
        conf.getInt("classifier.tie.layers"),
        conf.getDouble("classifier.tie.threshold"), conf);
}
155
156 /***
157 * Creates a new instance.
158 *
159 * @param allValidClasses the set of all valid classes
160 * @param trans the last transformer in the transformer chain to use, or
161 * <code>null</code> if no feature transformers should be used
162 * @param runDirectory optional run directory passed to inner classifiers
163 * of the {@link ExternalClassifier} type
164 * @param innerSpec the specification used to initialize the inner
165 * classifiers, passed to the
166 * {@link TrainableClassifier#createClassifier(Set, File,
167 * FeatureTransformer, String[], TiesConfiguration)} factory method
168 * @param layers the number of layers to use, must be at least 1
169 * @param threshold the next layer is invoked if the relative probability
170 * of the second best prediction as above or equal to this threshold;
171 * must be a number between 0 and 1
172 * @param conf used to configure this instance as well as the inner
173 * classifiers
174 * @throws ProcessingException if an error occurred while creating this
175 * classifier or one of the wrapped classifiers
176 */
177 public TieClassifier(final Set<String> allValidClasses,
178 final FeatureTransformer trans, final File runDirectory,
179 final String[] innerSpec, final int layers,
180 final double threshold, final TiesConfiguration conf)
181 throws ProcessingException {
182 super(allValidClasses, trans, conf);
183
184
185 if (layers < 1) {
186 throw new IllegalArgumentException(
187 "TieClassifier requires at least 1 layer instead of "
188 + layers);
189 }
190 checkTieThreshold(threshold);
191 tieThreshold = threshold;
192
193
194 inner = new TrainableClassifier[layers];
195 for (int i = 0; i < inner.length; i++) {
196
197
198 inner[i] = TrainableClassifier.createClassifier(allValidClasses,
199 runDirectory, null, innerSpec, conf);
200 }
201 }
202
203 /***
204 * Helper method that checks whether the tie threshold is valid.
205 *
206 * @param threshold the tie threshold to check
207 * @throws IllegalArgumentException if the threshold is less than 0 or
208 * larger than 1
209 */
210 private void checkTieThreshold(final double threshold)
211 throws IllegalArgumentException {
212 if (threshold < 0.0 || threshold > 1.0) {
213 throw new IllegalArgumentException(
214 "Tie threshold must be in the [0, 1] range: "
215 + threshold);
216 }
217 }
218
219 /***
220 * {@inheritDoc}
221 */
222 public void destroy() throws ProcessingException {
223
224 for (int i = 0; i < inner.length; i++) {
225 inner[i].destroy();
226 }
227 }
228
229 /***
230 * {@inheritDoc}
231 */
232 protected PredictionDistribution doClassify(final FeatureVector features,
233 final Set candidateClasses, final ContextMap context)
234 throws ProcessingException {
235
236 final PredictionDistribution[] innerDists =
237 new PredictionDistribution[inner.length];
238 final ContextMap[] innerContexts = new ContextMap[inner.length];
239
240 PredictionDistribution innerDist = null;
241 Iterator<Prediction> predIter;
242 ContextMap innerContext;
243 int i = 0;
244 double bestProb, secondBestProb;
245 boolean tieBetweenProbs = true;
246
247
248
249 while ((i < inner.length) && tieBetweenProbs) {
250
251 innerContext = new ContextMap();
252 innerDist = inner[i].doClassify(features, candidateClasses,
253 innerContext);
254 innerContexts[i] = innerContext;
255 innerDists[i] = innerDist;
256
257
258
259
260 predIter = innerDist.iterator();
261 bestProb = predIter.next().getProbability().getProb();
262 secondBestProb = predIter.next().getProbability().getProb();
263
264 if (secondBestProb >= bestProb * tieThreshold) {
265 tieBetweenProbs = true;
266 Util.LOG.debug("Layer " + i + " of TieClassifier: will invoke "
267 + "next layer (if exists) since probability of 2nd best"
268 + " prediction (" + secondBestProb
269 + ") >= best prediction (" + bestProb
270 + ") * tie threshold (" + tieThreshold + ")");
271 } else {
272 tieBetweenProbs = false;
273 }
274
275 i++;
276 }
277
278
279 context.put(KEY_INNER_CONTEXTS, innerContexts);
280 context.put(KEY_INNER_DISTS, innerDists);
281
282
283 return innerDist;
284 }
285
286 /***
287 * <b>This classifier supports <em>only</em> error-driven training, so you
288 * always have to use the {@link #trainOnError(FeatureVector, String, Set)}
289 * method instead of this one. Trying to call this method instead will
290 * result in an{@link java.lang.UnsupportedOperationException}.</b>
291 *
292 * @param features ignored by this method
293 * @param targetClass ignored by this method
294 * @param context ignored by this method
295 * @throws UnsupportedOperationException always thrown by this method;
296 * use {@link #trainOnError(FeatureVector, String, Set)} instead
297 */
298 protected void doTrain(final FeatureVector features,
299 final String targetClass, final ContextMap context)
300 throws UnsupportedOperationException {
301
302
303 throw new UnsupportedOperationException("TieClassifier supports only "
304 + "error-driven training -- call trainOnError instead of train");
305 }
306
307 /***
308 * {@inheritDoc}
309 */
310 protected boolean doTrainOnError(final PredictionDistribution predDist,
311 final FeatureVector features, final String targetClass,
312 final Set candidateClasses, final ContextMap context)
313 throws ProcessingException {
314
315 final PredictionDistribution[] innerDists =
316 (PredictionDistribution[]) context.get(KEY_INNER_DISTS);
317 final ContextMap[] innerContexts =
318 (ContextMap[]) context.get(KEY_INNER_CONTEXTS);
319
320 boolean innerShouldTrain = true;
321
322
323 for (int innerIndex = 0; (innerIndex < innerDists.length)
324 && (innerContexts[innerIndex] != null); innerIndex++) {
325 innerShouldTrain = inner[innerIndex].doTrainOnError(
326 innerDists[innerIndex], features, targetClass,
327 candidateClasses, innerContexts[innerIndex]);
328 }
329
330
331
332 return innerShouldTrain;
333 }
334
335 /***
336 * {@inheritDoc}
337 */
338 public void reset() throws ProcessingException {
339
340 for (int i = 0; i < inner.length; i++) {
341 inner[i].reset();
342 }
343 }
344
345 /***
346 * {@inheritDoc}
347 */
348 protected boolean shouldTrain(final String targetClass,
349 final PredictionDistribution predDist, final ContextMap context) {
350
351 throw new UnsupportedOperationException("TieClassifier: "
352 + "shouldTrain is not required and thus not supported");
353 }
354
355 /***
356 * {@inheritDoc}
357 */
358 public ObjectElement toElement() {
359
360 final ObjectElement result = super.toElement();
361 result.addAttribute(ATTRIB_TIE_THRESHOLD,
362 Double.toString(tieThreshold));
363
364
365 final Element innerElement =
366 new DefaultElement(MultiBinaryClassifier.ELEMENT_INNER);
367 result.add(innerElement);
368
369 for (int i = 0; i < inner.length; i++) {
370 innerElement.add(inner[i].toElement());
371 }
372
373 return result;
374 }
375
376 /***
377 * Returns a string representation of this object.
378 *
379 * @return a textual representation
380 */
381 public String toString() {
382 return new ToStringBuilder(this)
383 .appendSuper(super.toString())
384 .append("inner classifiers", ArrayUtils.toString(inner))
385 .append("tie threshold", tieThreshold)
386 .toString();
387 }
388
389 /***
390 * {@inheritDoc}
391 */
392 protected boolean trainOnErrorHook(final PredictionDistribution predDist,
393 final FeatureVector features, final String targetClass,
394 final Set candidateClasses, final ContextMap context)
395 throws ProcessingException {
396
397
398
399
400 final PredictionDistribution[] innerDists =
401 (PredictionDistribution[]) context.get(KEY_INNER_DISTS);
402 final ContextMap[] innerContexts =
403 (ContextMap[]) context.get(KEY_INNER_CONTEXTS);
404
405 boolean result = false;
406
407
408
409
410 for (int innerIndex = 0; (innerIndex < innerDists.length)
411 && (innerContexts[innerIndex] != null); innerIndex++) {
412 result = inner[innerIndex].trainOnErrorHook(
413 innerDists[innerIndex], features, targetClass,
414 candidateClasses, innerContexts[innerIndex]) || result;
415 }
416
417 return result;
418 }
419
420 }