View Javadoc

1   /*
2    * Copyright (C) 2004-2006 Christian Siefkes <christian@siefkes.net>.
3    * Development of this software is supported by the German Research Society,
4    * Berlin-Brandenburg Graduate School in Distributed Information Systems
5    * (DFG grant no. GRK 316).
6    *
7    * This program is free software; you can redistribute it and/or modify
8    * it under the terms of the GNU General Public License as published by
9    * the Free Software Foundation; either version 2 of the License, or
10   * (at your option) any later version.
11   *
12   * This program is distributed in the hope that it will be useful,
13   * but WITHOUT ANY WARRANTY; without even the implied warranty of
14   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15   * GNU General Public License for more details.
16   *
17   * You should have received a copy of the GNU General Public License
18   * along with this program; if not, visit
19   * http://www.gnu.org/licenses/gpl.html or write to the Free Software
20   * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
21   */
22  package de.fu_berlin.ties.classify;
23  
24  import java.io.File;
25  import java.util.Iterator;
26  import java.util.List;
27  import java.util.Set;
28  
29  import org.apache.commons.lang.ArrayUtils;
30  import org.apache.commons.lang.builder.ToStringBuilder;
31  import org.dom4j.Element;
32  import org.dom4j.QName;
33  import org.dom4j.tree.DefaultElement;
34  
35  import de.fu_berlin.ties.ContextMap;
36  import de.fu_berlin.ties.ProcessingException;
37  import de.fu_berlin.ties.TiesConfiguration;
38  import de.fu_berlin.ties.classify.feature.FeatureTransformer;
39  import de.fu_berlin.ties.classify.feature.FeatureVector;
40  import de.fu_berlin.ties.io.ObjectElement;
41  import de.fu_berlin.ties.util.Util;
42  import de.fu_berlin.ties.xml.dom.DOMUtils;
43  
44  /***
45   * A tie classifier combines several layers of classifiers. If the probabilities
46   * of the two best predictions of a layer are close to each other, the next
47   * layer is invoked to resolve this "tie".
48   *
49   * <p><b>This classifier supports <em>only</em> error-driven training since it
50   * necessary to TOE train each classifier to decide whether to train the next
51   * one. Thus you always have to use the
52   * {@link #trainOnError(FeatureVector, String, Set)} method. Trying to
53   * call the {@link
54   * de.fu_berlin.ties.classify.TrainableClassifier#train(FeatureVector, String)}
55   * method instead will result in an
56   * {@link java.lang.UnsupportedOperationException}.</b>
57   *
58   * <p>Instances of this class are thread-safe if and only if instances of the
59   * wrapped classifier are.
60   *
61   * @author Christian Siefkes
62   * @version $Revision: 1.9 $, $Date: 2006/10/21 16:03:55 $, $Author: siefkes $
63   */
64  public class TieClassifier extends TrainableClassifier {
65  
66      /***
67       * Attribute name used for XML serialization.
68       */
69      static final QName ATTRIB_TIE_THRESHOLD =
70          DOMUtils.defaultName("tieThreshold");
71  
72      /***
73       * Key used to store the contexts of the inner classifiers.
74       */
75      private static final String KEY_INNER_CONTEXTS = "inner-context";
76  
77      /***
78       * Key used to store the prediction distributions returned by the inner
79       * classifiers.
80       */
81      private static final String KEY_INNER_DISTS = "inner-dist";
82  
83  
84      /***
85       * The array of inner classifiers managed by this instance.
86       */
87      private final TrainableClassifier[] inner;
88  
89      /***
90       * The next layer is invoked if the relative probability of the second best
91       * prediction as above or equal to this threshold (in the 0 to 1 range).
92       */
93      private final double tieThreshold;
94  
95  
96      /***
97       * Creates a new instance from an XML element, fulfilling the
98       * recommandation of the {@link de.fu_berlin.ties.io.XMLStorable} interface.
99       *
100      * @param element the XML element containing the serialized representation
101      * @throws InstantiationException if the given element does not contain
102      * a valid classifier description
103      */
104     public TieClassifier(final Element element) throws InstantiationException {
105         // delegate to superclass + check & init threshold
106         super(element);
107         final double threshold =
108             Util.asDouble(element.attributeValue(ATTRIB_TIE_THRESHOLD));
109         checkTieThreshold(threshold);
110         tieThreshold = threshold;
111 
112         // initialize inner classifiers
113         List innerElements =
114             element.element(MultiBinaryClassifier.ELEMENT_INNER).elements();
115 
116         if (!innerElements.isEmpty()) {
117             inner = new TrainableClassifier[innerElements.size()];
118             final Iterator innerIter = innerElements.iterator();
119 
120             for (int i = 0; i < innerElements.size(); i++) {
121                 inner[i] = (TrainableClassifier) ObjectElement.createObject(
122                         (Element) innerIter.next());
123             }
124         } else {
125             throw new InstantiationException(
126                     "TieClassifier: no inner classifiers found");
127         }
128     }
129 
130     /***
131      * Creates a new instance.
132      *
133      * @param allValidClasses the set of all valid classes
134      * @param trans the last transformer in the transformer chain to use, or
135      * <code>null</code> if no feature transformers should be used
136      * @param runDirectory optional run directory passed to inner classifiers
137      * of the {@link ExternalClassifier} type
138      * @param innerSpec the specification used to initialize the inner
139      * classifiers, passed to the
140      * {@link TrainableClassifier#createClassifier(Set, File,
141      * FeatureTransformer, String[], TiesConfiguration)} factory method
142      * @param conf used to configure this instance and the inner classifiers
143      * @throws ProcessingException if an error occurred while creating this
144      * classifier or one of the wrapped classifiers
145      */
146     public TieClassifier(final Set<String> allValidClasses,
147             final FeatureTransformer trans, final File runDirectory,
148             final String[] innerSpec, final TiesConfiguration conf)
149     throws ProcessingException {
150         this(allValidClasses, trans, runDirectory, innerSpec,
151                 conf.getInt("classifier.tie.layers"),
152                 conf.getDouble("classifier.tie.threshold"),
153                 conf);
154     }
155 
156     /***
157      * Creates a new instance.
158      *
159      * @param allValidClasses the set of all valid classes
160      * @param trans the last transformer in the transformer chain to use, or
161      * <code>null</code> if no feature transformers should be used
162      * @param runDirectory optional run directory passed to inner classifiers
163      * of the {@link ExternalClassifier} type
164      * @param innerSpec the specification used to initialize the inner
165      * classifiers, passed to the
166      * {@link TrainableClassifier#createClassifier(Set, File,
167      * FeatureTransformer, String[], TiesConfiguration)} factory method
168      * @param layers the number of layers to use, must be at least 1
169      * @param threshold the next layer is invoked if the relative probability
170      * of the second best prediction as above or equal to this threshold;
171      * must be a number between 0 and 1
172      * @param conf used to configure this instance as well as the inner
173      * classifiers
174      * @throws ProcessingException if an error occurred while creating this
175      * classifier or one of the wrapped classifiers
176      */
177     public TieClassifier(final Set<String> allValidClasses,
178             final FeatureTransformer trans, final File runDirectory,
179             final String[] innerSpec, final int layers,
180             final double threshold, final TiesConfiguration conf)
181     throws ProcessingException {
182         super(allValidClasses, trans, conf);
183 
184         // check arguments and store threshold
185         if (layers < 1) {
186             throw new IllegalArgumentException(
187                 "TieClassifier requires at least 1 layer instead of "
188                     + layers);
189         }
190         checkTieThreshold(threshold);
191         tieThreshold = threshold;
192 
193         // init classifiers for each layer
194         inner = new TrainableClassifier[layers];
195         for (int i = 0; i < inner.length; i++) {
196             // Transformer is set to null because features shouldn't be
197             // transformed twice
198             inner[i] = TrainableClassifier.createClassifier(allValidClasses,
199                 runDirectory, null, innerSpec, conf);
200         }
201     }
202 
203     /***
204      * Helper method that checks whether the tie threshold is valid.
205      *
206      * @param threshold the tie threshold to check
207      * @throws IllegalArgumentException if the threshold is less than 0 or
208      * larger than 1
209      */
210     private void checkTieThreshold(final double threshold)
211     throws IllegalArgumentException {
212         if (threshold < 0.0 || threshold > 1.0) {
213             throw new IllegalArgumentException(
214                     "Tie threshold must be in the [0, 1] range: "
215                     + threshold);
216         }
217     }
218 
219     /***
220      * {@inheritDoc}
221      */
222     public void destroy() throws ProcessingException {
223         // destroy all layers
224         for (int i = 0; i < inner.length; i++) {
225             inner[i].destroy();
226         }
227     }
228 
229     /***
230      * {@inheritDoc}
231      */
232     protected PredictionDistribution doClassify(final FeatureVector features,
233             final Set candidateClasses, final ContextMap context)
234             throws ProcessingException {
235         // Used to populate the context
236         final PredictionDistribution[] innerDists =
237             new PredictionDistribution[inner.length];
238         final ContextMap[] innerContexts = new ContextMap[inner.length];
239 
240         PredictionDistribution innerDist = null;
241         Iterator<Prediction> predIter;
242         ContextMap innerContext;
243         int i = 0;
244         double bestProb, secondBestProb;
245         boolean tieBetweenProbs = true;
246 
247         // invoke inner classifier until 2nd best prediction is below the
248         // threshold or all layers are exhausted
249         while ((i < inner.length) && tieBetweenProbs) {
250             // invoke i-th inner classifier
251             innerContext = new ContextMap();
252             innerDist = inner[i].doClassify(features, candidateClasses,
253                     innerContext);
254             innerContexts[i] = innerContext;
255             innerDists[i] = innerDist;
256 
257             // check if two best probs are near to each other.
258             // We assume that there are at least 2 classes -- otherwise there
259             // would be no need to classify
260             predIter = innerDist.iterator();
261             bestProb = predIter.next().getProbability().getProb();
262             secondBestProb = predIter.next().getProbability().getProb();
263 
264             if (secondBestProb >= bestProb * tieThreshold) {
265                 tieBetweenProbs = true;
266                 Util.LOG.debug("Layer " + i + " of TieClassifier: will invoke "
267                         + "next layer (if exists) since probability of 2nd best"
268                         + " prediction (" + secondBestProb
269                         + ")  >= best prediction (" + bestProb
270                         + ") * tie threshold (" + tieThreshold + ")");
271             } else {
272                 tieBetweenProbs = false;
273             }
274 
275             i++;
276         }
277 
278         // store distributions, and contexts
279         context.put(KEY_INNER_CONTEXTS, innerContexts);
280         context.put(KEY_INNER_DISTS, innerDists);
281 
282         // return decision of last invoked layer
283         return innerDist;
284     }
285 
286     /***
287      * <b>This classifier supports <em>only</em> error-driven training, so you
288      * always have to use the {@link #trainOnError(FeatureVector, String, Set)}
289      * method instead of this one. Trying to call this method instead will
290      * result in an{@link java.lang.UnsupportedOperationException}.</b>
291      *
292      * @param features ignored by this method
293      * @param targetClass ignored by this method
294      * @param context ignored by this method
295      * @throws UnsupportedOperationException always thrown by this method;
296      * use {@link #trainOnError(FeatureVector, String, Set)} instead
297      */
298     protected void doTrain(final FeatureVector features,
299             final String targetClass, final ContextMap context)
300     throws UnsupportedOperationException {
301         // we cannot support this method because we need to TOE train each
302         // classifier to decide whether to train the next one
303         throw new UnsupportedOperationException("TieClassifier supports only "
304             + "error-driven training -- call trainOnError instead of train");
305     }
306 
307     /***
308      * {@inheritDoc}
309      */
310     protected boolean doTrainOnError(final PredictionDistribution predDist,
311             final FeatureVector features, final String targetClass,
312             final Set candidateClasses, final ContextMap context)
313             throws ProcessingException {
314         // retrieve objects stored in context
315         final PredictionDistribution[] innerDists =
316             (PredictionDistribution[]) context.get(KEY_INNER_DISTS);
317         final ContextMap[] innerContexts =
318             (ContextMap[]) context.get(KEY_INNER_CONTEXTS);
319 
320         boolean innerShouldTrain = true;
321 
322         // delegate to doTrainOnError method of each invoked classifier
323         for (int innerIndex = 0; (innerIndex < innerDists.length)
324                 && (innerContexts[innerIndex] != null); innerIndex++) {
325             innerShouldTrain = inner[innerIndex].doTrainOnError(
326                     innerDists[innerIndex], features, targetClass,
327                     candidateClasses, innerContexts[innerIndex]);
328         }
329 
330         // return shouldTrain statement of last inner classifier invoked during
331         // classification
332         return innerShouldTrain;
333     }
334 
335     /***
336      * {@inheritDoc}
337      */
338     public void reset() throws ProcessingException {
339         // reset all layers
340         for (int i = 0; i < inner.length; i++) {
341             inner[i].reset();
342         }
343     }
344 
345     /***
346      * {@inheritDoc}
347      */
348     protected boolean shouldTrain(final String targetClass,
349             final PredictionDistribution predDist, final ContextMap context) {
350         // method should never be called since is not used by doTrainOnError
351         throw new UnsupportedOperationException("TieClassifier: "
352                 + "shouldTrain is not required and thus not supported");
353     }
354 
355     /***
356      * {@inheritDoc}
357      */
358     public ObjectElement toElement() {
359         // delegate to superclass + add tie tresheshold attribute
360         final ObjectElement result = super.toElement();
361         result.addAttribute(ATTRIB_TIE_THRESHOLD,
362                 Double.toString(tieThreshold));
363 
364         // add all inner classifiers within a single <inner> element
365         final Element innerElement =
366             new DefaultElement(MultiBinaryClassifier.ELEMENT_INNER);
367         result.add(innerElement);
368 
369         for (int i = 0; i < inner.length; i++) {
370             innerElement.add(inner[i].toElement());
371         }
372 
373         return result;
374     }
375 
376     /***
377      * Returns a string representation of this object.
378      *
379      * @return a textual representation
380      */
381     public String toString() {
382         return new ToStringBuilder(this)
383             .appendSuper(super.toString())
384             .append("inner classifiers", ArrayUtils.toString(inner))
385             .append("tie threshold", tieThreshold)
386             .toString();
387     }
388 
389     /***
390      * {@inheritDoc}
391      */
392     protected boolean trainOnErrorHook(final PredictionDistribution predDist,
393             final FeatureVector features, final String targetClass,
394             final Set candidateClasses, final ContextMap context)
395     throws ProcessingException {
396         // method shouldn't be called in normal operation, but might if
397         // we're inside another classifier (e.g. MultiBinary)
398 
399         // retrieve objects stored in context
400         final PredictionDistribution[] innerDists =
401             (PredictionDistribution[]) context.get(KEY_INNER_DISTS);
402         final ContextMap[] innerContexts =
403             (ContextMap[]) context.get(KEY_INNER_CONTEXTS);
404 
405         boolean result = false;
406 
407         // delegate to doTrainOnError method of each invoked classifier,
408         // returning ORing of results (i.e. return true if at least one of them
409         // returns true)
410         for (int innerIndex = 0; (innerIndex < innerDists.length)
411                 && (innerContexts[innerIndex] != null); innerIndex++) {
412             result = inner[innerIndex].trainOnErrorHook(
413                     innerDists[innerIndex], features, targetClass,
414                     candidateClasses, innerContexts[innerIndex]) || result;
415         }
416 
417         return result;
418     }
419 
420 }