/*
 * Copyright (C) 2003-2006 Christian Siefkes <christian@siefkes.net>.
 * Development of this software is supported by the German Research Society,
 * Berlin-Brandenburg Graduate School in Distributed Information Systems
 * (DFG grant no. GRK 316).
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, visit
 * http://www.gnu.org/licenses/gpl.html or write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
 */
22  package de.fu_berlin.ties.classify;
23  
24  import java.io.File;
25  import java.io.IOException;
26  import java.util.HashSet;
27  import java.util.Iterator;
28  import java.util.Set;
29  import java.util.regex.Matcher;
30  import java.util.regex.Pattern;
31  
32  import org.apache.commons.lang.StringUtils;
33  import org.apache.commons.lang.builder.ToStringBuilder;
34  
35  import de.fu_berlin.ties.ContextMap;
36  import de.fu_berlin.ties.ProcessingException;
37  import de.fu_berlin.ties.TiesConfiguration;
38  import de.fu_berlin.ties.classify.feature.FeatureTransformer;
39  import de.fu_berlin.ties.classify.feature.FeatureVector;
40  import de.fu_berlin.ties.io.ObjectElement;
41  import de.fu_berlin.ties.util.ExternalCommand;
42  import de.fu_berlin.ties.util.Util;
43  
44  /***
45   * A proxy that provides a trainable classifier by communicating with an
46   * external (non-Java) program. Program name and command line options of the
47   * external classifier can be configured.
48   *
49   * <p>Instances of this class are thread-safe if and only if several instances
50   * of the external classifier can reliably run in parallel.
51   *
52   * @author Christian Siefkes
53   * @version $Revision: 1.35 $, $Date: 2006/10/21 16:03:54 $, $Author: siefkes $
54   */
55  public class ExternalClassifier extends TrainableClassifier {
56  
57      /***
58       * Configuration key: the directory to run the classifier in (optional,
59       * defaults to current working directory).
60       */
61      public static final String CONFIG_DIR = "classifier.ext.directory";
62  
63      /***
64       * Configuration key: Command name + arguments to call for
65       * classification (list of possible target classes will be second argument,
66       * feature vector will be provided as standard input).
67       */
68      private static final String CONFIG_CMD_CLASSIFY = "classifier.ext.classify";
69  
70      /***
71       * Configuration key: Command name + arguments to call for class
72       * initialization (class to initialize will be last arg).
73       */
74      private static final String CONFIG_CMD_INIT = "classifier.ext.init";
75  
76      /***
77       * Configuration key: Command name + arguments to call for resetting the
78       * classifier by deleting the prediction model (class to reset will be last
79       * arg).
80       */
81      private static final String CONFIG_CMD_RESET = "classifier.ext.reset";
82  
83      /***
84       * Configuration key: Command name + arguments to call for
85       * training (expected target class will be second argument, feature vector
86       * will be provided as standard input).
87       */
88      private static final String CONFIG_CMD_TRAIN = "classifier.ext.train";
89  
90      /***
91       * Configuration key: the suffix to append to classes for the classifier
92       * (optional).
93       */
94      private static final String CONFIG_CLASS_SUFFIX = "classifier.ext.suffix";
95  
96      /***
97       * Configuration key:  regular expression to extract the predicted class
98       * (group 1) and the probability (group 2) from the classifier's standard
99       * output; for all classes or at least for the best one.
100      */
101     private static final String CONFIG_REGEX = "classifier.ext.regex";
102 
103     /***
104      * External command called for classification.
105      */
106     private final ExternalCommand extClassifier;
107 
108     /***
109      * External command called for initialization; might be <code>null</code>.
110      */
111     private final ExternalCommand extInitializer;
112 
113     /***
114      * External command called for resetting the prediction model; might be
115      * <code>null</code>.
116      */
117     private final ExternalCommand extResetter;
118 
119     /***
120      * External command called for training.
121      */
122     private final ExternalCommand extTrainer;
123 
124     /***
125      * The suffix to append to classes for the classifier (might be null).
126      */
127     private final String classSuffix;
128 
129     /***
130      * The directory to run the classifier in (if null, the current working
131      * directory is used).
132      */
133     private final File workDir;
134 
135     /***
136      * Regular expression to extract the predicted class (group 1) and the
137      * probability (group 2) from the classifier's standard output.
138      */
139     private final Pattern predictionPattern;
140 
141     /***
142      * The classifier is trained if the pR is below this value as well as
143      * on errors ("thick threshold" heuristic). {@link Double#NaN} if not used.
144      */
145     private final double thickThresholdPR;
146 
147     /***
148      * The classifier is trained if the probability is below this value as well
149      * as on errors ("thick threshold" heuristic). {@link Double#NaN} if not
150      * used.
151      */
152     private final double thickThresholdProb;
153 
154     /***
155      * Whether this instance has been initialized.
156      */
157     private boolean initialized = false;
158 
159     /***
160      * Creates a new instance based on the
161      * {@linkplain TiesConfiguration#CONF standard configuration}.
162      *
163      * @param allValidClasses the set of all valid classes
164      * @throws ProcessingException if an I/O error occurs during initialization
165      */
166     public ExternalClassifier(final Set<String> allValidClasses)
167             throws ProcessingException {
168         this(allValidClasses, TiesConfiguration.CONF);
169     }
170 
171     /***
172      * Creates a new instance based on the provided configuration.
173      *
174      * @param allValidClasses the set of all valid classes
175      * @param config contains configuration properties
176      * @throws ProcessingException if an I/O error occurs during initialization
177      */
178     public ExternalClassifier(final Set<String> allValidClasses,
179             final TiesConfiguration config) throws ProcessingException {
180         this(allValidClasses, FeatureTransformer.createTransformer(config),
181             null, config);
182     }
183 
184     /***
185      * Creates a new instance based on the provided arguments.
186      *
187      * @param allValidClasses the set of all valid classes
188      * @param trans the last transformer in the transformer chain to use, or
189      * <code>null</code> if no feature transformers should be used
190      * @param runDirectory the directory to run the classifier in; used instead
191      * of the {@linkplain #CONFIG_DIR configured directory} if not
192      * <code>null</code>
193      * @param config contains configuration properties
194      * @throws ProcessingException if an I/O error occurs during initialization
195      */
196     public ExternalClassifier(final Set<String> allValidClasses,
197             final FeatureTransformer trans, final File runDirectory,
198             final TiesConfiguration config)
199             throws ProcessingException {
200         super(allValidClasses, trans, config);
201 
202         // configure instance
203         final String[] cmdClassify = config.getStringArray(CONFIG_CMD_CLASSIFY);
204         final String[] cmdInit = config.getStringArray(CONFIG_CMD_INIT);
205         final String[] cmdReset = config.getStringArray(CONFIG_CMD_RESET);
206         final String[] cmdTrain = config.getStringArray(CONFIG_CMD_TRAIN);
207         predictionPattern = Pattern.compile(config.getString(CONFIG_REGEX));
208 
209         // handle optional properties
210         classSuffix = config.getString(CONFIG_CLASS_SUFFIX, null);
211 
212         final String rawThresholdProb =
213             config.getString("classifier.ext.threshold.prob", null);
214         if (StringUtils.isNotEmpty(rawThresholdProb)) {
215             // we do it like this to cope with empty strings
216             thickThresholdProb = Util.asDouble(rawThresholdProb);
217         } else {
218             thickThresholdProb = Double.NaN;
219         }
220 
221         final String rawThresholdPR =
222             config.getString("classifier.ext.threshold.pR", null);
223         if (StringUtils.isNotEmpty(rawThresholdPR)) {
224             // we do it like this to cope with empty strings
225             thickThresholdPR = Util.asDouble(rawThresholdPR);
226         } else {
227             thickThresholdPR = Double.NaN;
228         }
229 
230         // set directory
231         if (runDirectory != null) {
232             workDir = runDirectory;
233         } else if (config.containsKey(CONFIG_DIR)) {
234             workDir = new File(config.getString(CONFIG_DIR));
235         } else {
236             workDir = null;
237         }
238 
239         // create wrappers for external programs
240         extClassifier = new ExternalCommand(cmdClassify, workDir);
241         extTrainer = new ExternalCommand(cmdTrain, workDir);
242 
243         if (cmdInit.length > 0) {
244             extInitializer = new ExternalCommand(cmdInit, workDir);
245         } else {
246             extInitializer = null;
247         }
248 
249         if (cmdReset.length > 0) {
250             extResetter = new ExternalCommand(cmdReset, workDir);
251         } else {
252             extResetter = null;
253         }
254     }
255 
256     /***
257      * Helper method building a class name by appending the configured suffix,
258      * if any.
259      *
260      * @param baseName the base name of the class
261      * @return the complete class name
262      */
263     private String buildClassName(final String baseName) {
264         // append suffix, if exists
265         if (classSuffix != null) {
266             return baseName + classSuffix;
267         } else {
268             return baseName;
269         }
270     }
271 
272     /***
273      * Classifies an item that is represented by a feature vector by choosing
274      * the most probable class among a set of candidate classes.
275      *
276      * @param features the feature vector to consider
277      * @param candidateClasses an array of the classes that are allowed for
278      * this item
279      * @param context ignored by this implementation
280      * @return the result of the classification; you can call
281      * {@link PredictionDistribution#best()} to get the most probably class;
282      * this classifier returns only the best prediction, so
283      * {@link PredictionDistribution#size()} will be 1
284      * @throws ProcessingException if an I/O error occurs during communication
285      * with the external program
286      */
287     protected PredictionDistribution doClassify(final FeatureVector features,
288             final Set candidateClasses, final ContextMap context)
289     throws ProcessingException {
290         if (!initialized) {
291             initialize();
292         }
293 
294         // build list of candidate classes as last argument
295         final StringBuilder candidateBuffer = new StringBuilder();
296         final Iterator candidateIter = candidateClasses.iterator();
297 
298         while (candidateIter.hasNext()) {
299             candidateBuffer.append(
300                 buildClassName((String) candidateIter.next()));
301 
302             if (candidateIter.hasNext()) {
303                 // no whitespace required after the last candidate
304                 candidateBuffer.append(' ');
305             }
306         }
307         final String[] furtherArg = new String[] {candidateBuffer.toString()};
308         final String output;
309 
310         // run the external program on the flattened features
311         try {
312             output = extClassifier.execute(furtherArg, features.flatten());
313         } catch (IOException ioe) {
314             // wrap and rethrow exception
315             throw new ProcessingException("I/O error while classifying", ioe);
316         }
317 
318         // match output to regex to extract predicted class + probability + pR
319         final Matcher outputMatcher = predictionPattern.matcher(output);
320         final PredictionDistribution predDist = new PredictionDistribution();
321         final Set<String> addedClasses = new HashSet<String>();
322         Prediction pred;
323         String rawPredictedClass, predictedClass;
324         String rawProbability;
325         double probability;
326         String rawPR;
327 
328         // ensure regex matched okay
329         while (outputMatcher.find()) {
330 /*            throw new IllegalArgumentException(
331                     "No match found for extraction pattern '"
332                     + predictionPattern.pattern() + "' in classifier output: '"
333                     + output + "'"); */
334             if (outputMatcher.groupCount() < 2) {
335                 throw new IllegalArgumentException("Extraction pattern '"
336                     + predictionPattern.pattern()
337                     + "' should match 2 or 3 subgroups but matched only "
338                     + outputMatcher.groupCount()  + " in classifier output: '"
339                     + output + "'");
340             }
341 
342             // match and convert subgroups
343             rawPredictedClass = outputMatcher.group(1);
344 
345             if ((classSuffix != null)
346                     && rawPredictedClass.endsWith(classSuffix)) {
347                 // strip suffix from predicted class
348                 predictedClass = rawPredictedClass.substring(0,
349                         rawPredictedClass.length() - classSuffix.length());
350             } else {
351                 predictedClass = rawPredictedClass;
352             }
353 
354             rawProbability = outputMatcher.group(2);
355             probability = Util.asDouble(rawProbability);
356 
357             // pR is given in 3rd group (optional), otherwise NaN
358             final double pR;
359             if (outputMatcher.groupCount() >= 3) {
360                 rawPR = outputMatcher.group(3);
361                 pR = Util.asDouble(rawPR);
362             } else {
363                 pR = Double.NaN;
364             }
365 
366             // create and store prediction (they will be sorted automatically)
367             if (!addedClasses.contains(predictedClass)) {
368                 // avoid adding the same class multiple times
369                 addedClasses.add(predictedClass);
370                 pred = new Prediction(predictedClass,
371                         new Probability(probability, pR));
372                 predDist.add(pred);
373             }
374         }
375 
376         // check that we found at least one class
377         if (predDist.size() < 1) {
378             throw new IllegalArgumentException(
379                     "No match found for extraction pattern '"
380                     + predictionPattern.pattern() + "' in classifier output: '"
381                     + output + "'");
382         }
383 
384         //Util.LOG.debug(predDist + "; classifier output: " + output);
385 
386         return predDist;
387     }
388 
389     /***
390      * {@inheritDoc}
391      */
392     protected void doTrain(final FeatureVector features,
393             final String targetClass, final ContextMap context)
394     throws ProcessingException {
395         if (!initialized) {
396             initialize();
397         }
398 
399         // pass target class as last argument
400         final String[] furtherArg = new String[] {buildClassName(targetClass)};
401 
402         // run the external program on the flattened features
403         // (output can be ignored)
404         try {
405             extTrainer.execute(furtherArg, features.flatten());
406         } catch (IOException ioe) {
407             // wrap and rethrow exception
408             throw new ProcessingException("I/O error while training", ioe);
409         }
410     }
411 
412     /***
413      * Initializes all classes.
414      * @throws ProcessingException if an I/O error occurs during initialization
415      */
416     private void initialize() throws ProcessingException {
417         if (extInitializer != null) {
418             // initialize all classes
419             final Iterator classIter = getAllClasses().iterator();
420 
421             while (classIter.hasNext()) {
422                 init((String) classIter.next());
423             }
424         }
425 
426         initialized = true;
427     }
428 
429     /***
430      * Initializes a class. This method is called once for each class.
431      *
432      * @param cls the class to initialize
433      * @throws ProcessingException if an I/O error occurs during initialization
434      */
435     private void init(final String cls) throws ProcessingException {
436         // pass target class as last argument
437         final String[] furtherArg = new String[] {buildClassName(cls)};
438 
439         // run the external program (output can be ignored)
440         try {
441             extInitializer.execute(furtherArg);
442         } catch (IOException ioe) {
443             // wrap and rethrow exception
444             throw new ProcessingException("I/O error while initializing the "
445                     + cls + " class", ioe);
446         }
447     }
448 
449     /***
450      * {@inheritDoc}
451      */
452     public void reset() throws ProcessingException {
453         if (extResetter != null) {
454             // reset all classes
455             final Iterator classIter = getAllClasses().iterator();
456 
457             while (classIter.hasNext()) {
458                 resetClass((String) classIter.next());
459             }
460         }
461 
462         // re-initialization necessary
463         initialized = false;
464     }
465 
466     /***
467      * Resets a class.
468      *
469      * @param cls the class to reset
470      * @throws ProcessingException if an I/O error occurs during reset
471      */
472     private void resetClass(final String cls) throws ProcessingException {
473         // pass target class as last argument
474         final String[] furtherArg = new String[] {buildClassName(cls)};
475 
476         // run the external program (output can be ignored)
477         try {
478             extResetter.execute(furtherArg);
479         } catch (IOException ioe) {
480             // wrap and rethrow exception
481             throw new ProcessingException("I/O error while resetting the "
482                     + cls + " class", ioe);
483         }
484     }
485 
486     /***
487      * This implementation uses reinforcement training, if a thick threshold
488      * is configured.
489      * {@inheritDoc}
490      */
491     protected boolean shouldTrain(final String targetClass,
492             final PredictionDistribution predDist, final ContextMap context) {
493         final Prediction best = predDist.best();
494         final Probability bestProb = best.getProbability();
495 
496         if (super.shouldTrain(targetClass, predDist, context)) {
497             // normal train-on-error
498             return true;
499         } else if (!Double.isNaN(thickThresholdPR)
500                 && !Double.isNaN(bestProb.getPR())
501                 && bestProb.getPR() < thickThresholdPR) {
502             Util.LOG.debug("Reinforcement training because pR " 
503                     + bestProb.getPR() + " is below the pR threshold "
504                     + thickThresholdPR);
505             return true;
506         } else if (!Double.isNaN(thickThresholdProb)
507                 && bestProb.getProb() < thickThresholdProb) {
508             Util.LOG.debug("Reinforcement training because probability "
509                     + best.getProbability() + " is below the threshold "
510                     + thickThresholdProb);
511             return true;
512         } else {
513             // no need to train
514             return false;
515         }
516     }
517 
518     /***
519      * {@inheritDoc} Currently, this classifier does not support XML
520      * serialization, throwing an {@link UnsupportedOperationException} instead.
521      *
522      * @throws UnsupportedOperationException always thrown by this
523      * implementation
524      */
525     public ObjectElement toElement() throws UnsupportedOperationException {
526             throw new UnsupportedOperationException(
527                     "XML serialization is not supported by ExternalClassifier");
528     }
529 
530     /***
531      * Returns a string representation of this object.
532      *
533      * @return a textual representation
534      */
535     public String toString() {
536         final ToStringBuilder builder = new ToStringBuilder(this)
537             .appendSuper(super.toString())
538             .append("classify command", extClassifier)
539             .append("train command", extTrainer)
540             .append("init command", extInitializer)
541             .append("reset command", extResetter)
542             .append("class suffix", classSuffix)
543             .append("prediction pattern", predictionPattern.pattern());
544 
545         if (!Double.isNaN(thickThresholdProb)) {
546             builder.append("thick threshold prob.", thickThresholdProb);
547         }
548         if (!Double.isNaN(thickThresholdPR)) {
549             builder.append("thick threshold pR", thickThresholdPR);
550         }
551 
552         return builder.toString();
553     }
554 
555 }