View Javadoc

1   /*
2    * Copyright (C) 2003-2004 Christian Siefkes <christian@siefkes.net>.
3    * Development of this software is supported by the German Research Society,
4    * Berlin-Brandenburg Graduate School in Distributed Information Systems
5    * (DFG grant no. GRK 316).
6    *
7    * This library is free software; you can redistribute it and/or
8    * modify it under the terms of the GNU Lesser General Public
9    * License as published by the Free Software Foundation; either
10   * version 2.1 of the License, or (at your option) any later version.
11   *
12   * This library is distributed in the hope that it will be useful,
13   * but WITHOUT ANY WARRANTY; without even the implied warranty of
14   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15   * Lesser General Public License for more details.
16   *
17   * You should have received a copy of the GNU Lesser General Public
18   * License along with this library; if not, visit
19   * http://www.gnu.org/licenses/lgpl.html or write to the Free Software
20   * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
21   */
22  package de.fu_berlin.ties.classify;
23  
24  import java.io.File;
25  import java.io.IOException;
26  import java.util.HashSet;
27  import java.util.Iterator;
28  import java.util.Set;
29  import java.util.regex.Matcher;
30  import java.util.regex.Pattern;
31  
32  import org.apache.commons.lang.StringUtils;
33  import org.apache.commons.lang.builder.ToStringBuilder;
34  
35  import de.fu_berlin.ties.ContextMap;
36  import de.fu_berlin.ties.ProcessingException;
37  import de.fu_berlin.ties.TiesConfiguration;
38  import de.fu_berlin.ties.classify.feature.FeatureTransformer;
39  import de.fu_berlin.ties.classify.feature.FeatureVector;
40  import de.fu_berlin.ties.util.ExternalCommand;
41  import de.fu_berlin.ties.util.Util;
42  
/**
 * A proxy that provides a trainable classifier by communicating with an
 * external (non-Java) program. Program name and command line options of the
 * external classifier can be configured.
 *
 * <p>Instances of this class are thread-safe if and only if several instances
 * of the external classifier can reliably run in parallel.
 *
 * @author Christian Siefkes
 * @version $Revision: 1.29 $, $Date: 2004/11/19 14:04:19 $, $Author: siefkes $
 */
54  public class ExternalClassifier extends TrainableClassifier {
55  
56      /***
57       * Configuration key: the directory to run the classifier in (optional,
58       * defaults to current working directory).
59       */
60      public static final String CONFIG_DIR = "classifier.ext.directory";
61  
62      /***
63       * Configuration key: Command name + arguments to call for
64       * classification (list of possible target classes will be second argument,
65       * feature vector will be provided as standard input).
66       */
67      private static final String CONFIG_CMD_CLASSIFY = "classifier.ext.classify";
68  
69      /***
70       * Configuration key: Command name + arguments to call for class
71       * initialization (class to initialize will be last arg).
72       */
73      private static final String CONFIG_CMD_INIT = "classifier.ext.init";
74  
75      /***
76       * Configuration key: Command name + arguments to call for resetting the
77       * classifier by deleting the prediction model (class to reset will be last
78       * arg).
79       */
80      private static final String CONFIG_CMD_RESET = "classifier.ext.reset";
81  
82      /***
83       * Configuration key: Command name + arguments to call for
84       * training (expected target class will be second argument, feature vector
85       * will be provided as standard input).
86       */
87      private static final String CONFIG_CMD_TRAIN = "classifier.ext.train";
88  
89      /***
90       * Configuration key: the suffix to append to classes for the classifier
91       * (optional).
92       */
93      private static final String CONFIG_CLASS_SUFFIX = "classifier.ext.suffix";
94  
95      /***
96       * Configuration key:  regular expression to extract the predicted class
97       * (group 1) and the probability (group 2) from the classifier's standard
98       * output; for all classes or at least for the best one.
99       */
100     private static final String CONFIG_REGEX = "classifier.ext.regex";
101 
102     /***
103      * External command called for classification.
104      */
105     private final ExternalCommand extClassifier;
106 
107     /***
108      * External command called for initialization; might be <code>null</code>.
109      */
110     private final ExternalCommand extInitializer;
111 
112     /***
113      * External command called for resetting the prediction model; might be
114      * <code>null</code>.
115      */
116     private final ExternalCommand extResetter;
117 
118     /***
119      * External command called for training.
120      */
121     private final ExternalCommand extTrainer;
122 
123     /***
124      * The suffix to append to classes for the classifier (might be null).
125      */
126     private final String classSuffix;
127 
128     /***
129      * The directory to run the classifier in (if null, the current working
130      * directory is used).
131      */
132     private final File workDir;
133 
134     /***
135      * Regular expression to extract the predicted class (group 1) and the
136      * probability (group 2) from the classifier's standard output.
137      */
138     private final Pattern predictionPattern;
139 
140     /***
141      * The classifier is trained if the pR is below this value as well as
142      * on errors ("thick threshold" heuristic). {@link Double#NaN} if not used.
143      */
144     private final double thickThresholdPR;
145 
146     /***
147      * The classifier is trained if the probability is below this value as well
148      * as on errors ("thick threshold" heuristic). {@link Double#NaN} if not
149      * used.
150      */
151     private final double thickThresholdProb;
152 
153     /***
154      * Whether this instance has been initialized.
155      */
156     private boolean initialized = false;
157 
158     /***
159      * Creates a new instance based on the
160      * {@linkplain TiesConfiguration#CONF standard configuration}.
161      *
162      * @param allValidClasses the set of all valid classes
163      * @throws ProcessingException if an I/O error occurs during initialization
164      */
165     public ExternalClassifier(final Set<String> allValidClasses)
166             throws ProcessingException {
167         this(allValidClasses, TiesConfiguration.CONF);
168     }
169 
170     /***
171      * Creates a new instance based on the provided configuration.
172      *
173      * @param allValidClasses the set of all valid classes
174      * @param config contains configuration properties
175      * @throws ProcessingException if an I/O error occurs during initialization
176      */
177     public ExternalClassifier(final Set<String> allValidClasses,
178             final TiesConfiguration config) throws ProcessingException {
179         this(allValidClasses, FeatureTransformer.createTransformer(config),
180             null, config);
181     }
182 
183     /***
184      * Creates a new instance based on the provided arguments.
185      *
186      * @param allValidClasses the set of all valid classes
187      * @param trans the last transformer in the transformer chain to use, or
188      * <code>null</code> if no feature transformers should be used
189      * @param runDirectory the directory to run the classifier in; used instead
190      * of the {@linkplain #CONFIG_DIR configured directory} if not
191      * <code>null</code>
192      * @param config contains configuration properties
193      * @throws ProcessingException if an I/O error occurs during initialization
194      */
195     public ExternalClassifier(final Set<String> allValidClasses,
196             final FeatureTransformer trans, final File runDirectory,
197             final TiesConfiguration config)
198             throws ProcessingException {
199         super(allValidClasses, trans, config);
200 
201         // configure instance
202         final String[] cmdClassify = config.getStringArray(CONFIG_CMD_CLASSIFY);
203         final String[] cmdInit = config.getStringArray(CONFIG_CMD_INIT);
204         final String[] cmdReset = config.getStringArray(CONFIG_CMD_RESET);
205         final String[] cmdTrain = config.getStringArray(CONFIG_CMD_TRAIN);
206         predictionPattern = Pattern.compile(config.getString(CONFIG_REGEX));
207 
208         // handle optional properties
209         classSuffix = config.getString(CONFIG_CLASS_SUFFIX, null);
210 
211         final String rawThresholdProb =
212             config.getString("classifier.ext.threshold.prob", null);
213         if (StringUtils.isNotEmpty(rawThresholdProb)) {
214             // we do it like this to cope with empty strings
215             thickThresholdProb = Util.asDouble(rawThresholdProb);
216         } else {
217             thickThresholdProb = Double.NaN;
218         }
219 
220         final String rawThresholdPR =
221             config.getString("classifier.ext.threshold.pR", null);
222         if (StringUtils.isNotEmpty(rawThresholdPR)) {
223             // we do it like this to cope with empty strings
224             thickThresholdPR = Util.asDouble(rawThresholdPR);
225         } else {
226             thickThresholdPR = Double.NaN;
227         }
228 
229         // set directory
230         if (runDirectory != null) {
231             workDir = runDirectory;
232         } else if (config.containsKey(CONFIG_DIR)) {
233             workDir = new File(config.getString(CONFIG_DIR));
234         } else {
235             workDir = null;
236         }
237 
238         // create wrappers for external programs
239         extClassifier = new ExternalCommand(cmdClassify, workDir);
240         extTrainer = new ExternalCommand(cmdTrain, workDir);
241 
242         if (cmdInit.length > 0) {
243             extInitializer = new ExternalCommand(cmdInit, workDir);
244         } else {
245             extInitializer = null;
246         }
247 
248         if (cmdReset.length > 0) {
249             extResetter = new ExternalCommand(cmdReset, workDir);
250         } else {
251             extResetter = null;
252         }
253     }
254 
255     /***
256      * Helper method building a class name by appending the configured suffix,
257      * if any.
258      *
259      * @param baseName the base name of the class
260      * @return the complete class name
261      */
262     private String buildClassName(final String baseName) {
263         // append suffix, if exists
264         if (classSuffix != null) {
265             return baseName + classSuffix;
266         } else {
267             return baseName;
268         }
269     }
270 
271     /***
272      * Classifies an item that is represented by a feature vector by choosing
273      * the most probable class among a set of candidate classes.
274      *
275      * @param features the feature vector to consider
276      * @param candidateClasses an array of the classes that are allowed for
277      * this item
278      * @param context ignored by this implementation
279      * @return the result of the classification; you can call
280      * {@link PredictionDistribution#best()} to get the most probably class;
281      * this classifier returns only the best prediction, so
282      * {@link PredictionDistribution#size()} will be 1
283      * @throws ProcessingException if an I/O error occurs during communication
284      * with the external program
285      */
286     protected PredictionDistribution doClassify(final FeatureVector features,
287             final Set candidateClasses, final ContextMap context)
288     throws ProcessingException {
289         if (!initialized) {
290             initialize();
291         }
292 
293         // build list of candidate classes as last argument
294         final StringBuffer candidateBuffer = new StringBuffer();
295         final Iterator candidateIter = candidateClasses.iterator();
296 
297         while (candidateIter.hasNext()) {
298             candidateBuffer.append(
299                 buildClassName((String) candidateIter.next()));
300 
301             if (candidateIter.hasNext()) {
302                 // no whitespace required after the last candidate
303                 candidateBuffer.append(' ');
304             }
305         }
306         final String[] furtherArg = new String[] {candidateBuffer.toString()};
307         final String output;
308 
309         // run the external program on the flattened features
310         try {
311             output = extClassifier.execute(furtherArg, features.flatten());
312         } catch (IOException ioe) {
313             // wrap and rethrow exception
314             throw new ProcessingException("I/O error while classifying", ioe);
315         }
316 
317         // match output to regex to extract predicted class + probability + pR
318         final Matcher outputMatcher = predictionPattern.matcher(output);
319         final PredictionDistribution predDist = new PredictionDistribution();
320         final Set<String> addedClasses = new HashSet<String>();
321         Prediction pred;
322         String rawPredictedClass, predictedClass;
323         String rawProbability;
324         double probability;
325         String rawPR;
326 
327         // ensure regex matched okay
328         while (outputMatcher.find()) {
329 /*            throw new IllegalArgumentException(
330                     "No match found for extraction pattern '"
331                     + predictionPattern.pattern() + "' in classifier output: '"
332                     + output + "'"); */
333             if (outputMatcher.groupCount() < 2) {
334                 throw new IllegalArgumentException("Extraction pattern '"
335                         + predictionPattern.pattern()
336                         + "' should match 2 or 3 subgroups but matched only "
337                         + outputMatcher.groupCount()  + " in classifier output: '"
338                         + output + "'");
339             }
340     
341             // match and convert subgroups
342             rawPredictedClass = outputMatcher.group(1);
343     
344             if ((classSuffix != null)
345                     && rawPredictedClass.endsWith(classSuffix)) {
346                 // strip suffix from predicted class
347                 predictedClass = rawPredictedClass.substring(0,
348                         rawPredictedClass.length() - classSuffix.length());
349             } else {
350                 predictedClass = rawPredictedClass;
351             }
352     
353             rawProbability = outputMatcher.group(2);
354             probability = Util.asDouble(rawProbability);
355     
356             // pR is given in 3rd group (optional), otherwise NaN
357             final double pR;
358             if (outputMatcher.groupCount() >= 3) {
359                 rawPR = outputMatcher.group(3);
360                 pR = Util.asDouble(rawPR);
361             } else {
362                 pR = Double.NaN;
363             }
364     
365             // create and store prediction (they will be sorted automatically)
366             if (!addedClasses.contains(predictedClass)) {
367                 // avoid adding the same class multiple times
368                 addedClasses.add(predictedClass);
369                 pred = new Prediction(predictedClass,
370                         new Probability(probability, pR));
371                 predDist.add(pred);
372             }
373         }
374 
375         // check that we found at least one class
376         if (predDist.size() < 1) {
377             throw new IllegalArgumentException(
378                     "No match found for extraction pattern '"
379                     + predictionPattern.pattern() + "' in classifier output: '"
380                     + output + "'");
381         }
382 
383         //Util.LOG.debug(predDist + "; classifier output: " + output);
384 
385         return predDist;
386     }
387 
388     /***
389      * {@inheritDoc}
390      */
391     protected void doTrain(final FeatureVector features,
392             final String targetClass, final ContextMap context)
393     throws ProcessingException {
394         if (!initialized) {
395             initialize();
396         }
397 
398         // pass target class as last argument
399         final String[] furtherArg = new String[] {buildClassName(targetClass)};
400 
401         // run the external program on the flattened features
402         // (output can be ignored)
403         try {
404             extTrainer.execute(furtherArg, features.flatten());
405         } catch (IOException ioe) {
406             // wrap and rethrow exception
407             throw new ProcessingException("I/O error while training", ioe);
408         }
409     }
410 
411     /***
412      * Initializes all classes.
413      * @throws ProcessingException if an I/O error occurs during initialization
414      */
415     private void initialize() throws ProcessingException {
416         if (extInitializer != null) {
417             // initialize all classes
418             final Iterator classIter = getAllClasses().iterator();
419 
420             while (classIter.hasNext()) {
421                 init((String) classIter.next());
422             }
423         }
424 
425         initialized = true;
426     }
427 
428     /***
429      * Initializes a class. This method is called once for each class.
430      *
431      * @param cls the class to initialize
432      * @throws ProcessingException if an I/O error occurs during initialization
433      */
434     private void init(final String cls) throws ProcessingException {
435         // pass target class as last argument
436         final String[] furtherArg = new String[] {buildClassName(cls)};
437 
438         // run the external program (output can be ignored)
439         try {
440             extInitializer.execute(furtherArg);
441         } catch (IOException ioe) {
442             // wrap and rethrow exception
443             throw new ProcessingException("I/O error while initializing the "
444                     + cls + " class", ioe);
445         }
446     }
447 
448     /***
449      * {@inheritDoc}
450      */
451     public void reset() throws ProcessingException {
452         if (extResetter != null) {
453             // reset all classes
454             final Iterator classIter = getAllClasses().iterator();
455 
456             while (classIter.hasNext()) {
457                 resetClass((String) classIter.next());
458             }
459         }
460 
461         // re-initialization necessary
462         initialized = false;
463     }
464 
465     /***
466      * Resets a class.
467      *
468      * @param cls the class to reset
469      * @throws ProcessingException if an I/O error occurs during reset
470      */
471     private void resetClass(final String cls) throws ProcessingException {
472         // pass target class as last argument
473         final String[] furtherArg = new String[] {buildClassName(cls)};
474 
475         // run the external program (output can be ignored)
476         try {
477             extResetter.execute(furtherArg);
478         } catch (IOException ioe) {
479             // wrap and rethrow exception
480             throw new ProcessingException("I/O error while resetting the "
481                     + cls + " class", ioe);
482         }
483     }
484 
485     /***
486      * This implementation uses reinforcement training, if a thick threshold
487      * is configured.
488      * {@inheritDoc}
489      */
490     protected boolean shouldTrain(final String targetClass,
491             final PredictionDistribution predDist, final ContextMap context) {
492         final Prediction best = predDist.best();
493         final Probability bestProb = best.getProbability();
494 
495         if (super.shouldTrain(targetClass, predDist, context)) {
496             // normal train-on-error
497             return true;
498         } else if (!Double.isNaN(thickThresholdPR)
499                 && !Double.isNaN(bestProb.getPR())
500                 && bestProb.getPR() < thickThresholdPR) {
501             Util.LOG.debug("Reinforcement training because pR " 
502                     + bestProb.getPR() + " is below the pR threshold "
503                     + thickThresholdPR);
504             return true;
505         } else if (!Double.isNaN(thickThresholdProb)
506                 && bestProb.getProb() < thickThresholdProb) {
507             Util.LOG.debug("Reinforcement training because probability "
508                     + best.getProbability() + " is below the threshold "
509                     + thickThresholdProb);
510             return true;            
511         } else {
512             // no need to train
513             return false;
514         }
515     }
516 
517     /***
518      * Returns a string representation of this object.
519      *
520      * @return a textual representation
521      */
522     public String toString() {
523         final ToStringBuilder builder = new ToStringBuilder(this)
524             .appendSuper(super.toString())
525             .append("classify command", extClassifier)
526             .append("train command", extTrainer)
527             .append("init command", extInitializer)
528             .append("reset command", extResetter)
529             .append("class suffix", classSuffix)
530             .append("prediction pattern", predictionPattern.pattern());
531 
532         if (!Double.isNaN(thickThresholdProb)) {
533             builder.append("thick threshold prob.", thickThresholdProb);
534         }
535         if (!Double.isNaN(thickThresholdPR)) {
536             builder.append("thick threshold pR", thickThresholdPR);
537         }
538 
539         return builder.toString();
540     }
541 
542 }