1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package de.fu_berlin.ties.classify;
23
24 import java.io.File;
25 import java.io.IOException;
26 import java.util.HashSet;
27 import java.util.Iterator;
28 import java.util.Set;
29 import java.util.regex.Matcher;
30 import java.util.regex.Pattern;
31
32 import org.apache.commons.lang.StringUtils;
33 import org.apache.commons.lang.builder.ToStringBuilder;
34
35 import de.fu_berlin.ties.ContextMap;
36 import de.fu_berlin.ties.ProcessingException;
37 import de.fu_berlin.ties.TiesConfiguration;
38 import de.fu_berlin.ties.classify.feature.FeatureTransformer;
39 import de.fu_berlin.ties.classify.feature.FeatureVector;
40 import de.fu_berlin.ties.util.ExternalCommand;
41 import de.fu_berlin.ties.util.Util;
42
43 /***
44 * A proxy that provides a trainable classifier by communicating with an
45 * external (non-Java) program. Program name and command line options of the
46 * external classifier can be configured.
47 *
48 * <p>Instances of this class are thread-safe if and only if several instances
49 * of the external classifier can reliably run in parallel.
50 *
51 * @author Christian Siefkes
52 * @version $Revision: 1.29 $, $Date: 2004/11/19 14:04:19 $, $Author: siefkes $
53 */
54 public class ExternalClassifier extends TrainableClassifier {
55
56 /***
57 * Configuration key: the directory to run the classifier in (optional,
58 * defaults to current working directory).
59 */
60 public static final String CONFIG_DIR = "classifier.ext.directory";
61
62 /***
63 * Configuration key: Command name + arguments to call for
64 * classification (list of possible target classes will be second argument,
65 * feature vector will be provided as standard input).
66 */
67 private static final String CONFIG_CMD_CLASSIFY = "classifier.ext.classify";
68
69 /***
70 * Configuration key: Command name + arguments to call for class
71 * initialization (class to initialize will be last arg).
72 */
73 private static final String CONFIG_CMD_INIT = "classifier.ext.init";
74
75 /***
76 * Configuration key: Command name + arguments to call for resetting the
77 * classifier by deleting the prediction model (class to reset will be last
78 * arg).
79 */
80 private static final String CONFIG_CMD_RESET = "classifier.ext.reset";
81
82 /***
83 * Configuration key: Command name + arguments to call for
84 * training (expected target class will be second argument, feature vector
85 * will be provided as standard input).
86 */
87 private static final String CONFIG_CMD_TRAIN = "classifier.ext.train";
88
89 /***
90 * Configuration key: the suffix to append to classes for the classifier
91 * (optional).
92 */
93 private static final String CONFIG_CLASS_SUFFIX = "classifier.ext.suffix";
94
95 /***
96 * Configuration key: regular expression to extract the predicted class
97 * (group 1) and the probability (group 2) from the classifier's standard
98 * output; for all classes or at least for the best one.
99 */
100 private static final String CONFIG_REGEX = "classifier.ext.regex";
101
102 /***
103 * External command called for classification.
104 */
105 private final ExternalCommand extClassifier;
106
107 /***
108 * External command called for initialization; might be <code>null</code>.
109 */
110 private final ExternalCommand extInitializer;
111
112 /***
113 * External command called for resetting the prediction model; might be
114 * <code>null</code>.
115 */
116 private final ExternalCommand extResetter;
117
118 /***
119 * External command called for training.
120 */
121 private final ExternalCommand extTrainer;
122
123 /***
124 * The suffix to append to classes for the classifier (might be null).
125 */
126 private final String classSuffix;
127
128 /***
129 * The directory to run the classifier in (if null, the current working
130 * directory is used).
131 */
132 private final File workDir;
133
134 /***
135 * Regular expression to extract the predicted class (group 1) and the
136 * probability (group 2) from the classifier's standard output.
137 */
138 private final Pattern predictionPattern;
139
140 /***
141 * The classifier is trained if the pR is below this value as well as
142 * on errors ("thick threshold" heuristic). {@link Double#NaN} if not used.
143 */
144 private final double thickThresholdPR;
145
146 /***
147 * The classifier is trained if the probability is below this value as well
148 * as on errors ("thick threshold" heuristic). {@link Double#NaN} if not
149 * used.
150 */
151 private final double thickThresholdProb;
152
153 /***
154 * Whether this instance has been initialized.
155 */
156 private boolean initialized = false;
157
158 /***
159 * Creates a new instance based on the
160 * {@linkplain TiesConfiguration#CONF standard configuration}.
161 *
162 * @param allValidClasses the set of all valid classes
163 * @throws ProcessingException if an I/O error occurs during initialization
164 */
165 public ExternalClassifier(final Set<String> allValidClasses)
166 throws ProcessingException {
167 this(allValidClasses, TiesConfiguration.CONF);
168 }
169
170 /***
171 * Creates a new instance based on the provided configuration.
172 *
173 * @param allValidClasses the set of all valid classes
174 * @param config contains configuration properties
175 * @throws ProcessingException if an I/O error occurs during initialization
176 */
177 public ExternalClassifier(final Set<String> allValidClasses,
178 final TiesConfiguration config) throws ProcessingException {
179 this(allValidClasses, FeatureTransformer.createTransformer(config),
180 null, config);
181 }
182
183 /***
184 * Creates a new instance based on the provided arguments.
185 *
186 * @param allValidClasses the set of all valid classes
187 * @param trans the last transformer in the transformer chain to use, or
188 * <code>null</code> if no feature transformers should be used
189 * @param runDirectory the directory to run the classifier in; used instead
190 * of the {@linkplain #CONFIG_DIR configured directory} if not
191 * <code>null</code>
192 * @param config contains configuration properties
193 * @throws ProcessingException if an I/O error occurs during initialization
194 */
195 public ExternalClassifier(final Set<String> allValidClasses,
196 final FeatureTransformer trans, final File runDirectory,
197 final TiesConfiguration config)
198 throws ProcessingException {
199 super(allValidClasses, trans, config);
200
201
202 final String[] cmdClassify = config.getStringArray(CONFIG_CMD_CLASSIFY);
203 final String[] cmdInit = config.getStringArray(CONFIG_CMD_INIT);
204 final String[] cmdReset = config.getStringArray(CONFIG_CMD_RESET);
205 final String[] cmdTrain = config.getStringArray(CONFIG_CMD_TRAIN);
206 predictionPattern = Pattern.compile(config.getString(CONFIG_REGEX));
207
208
209 classSuffix = config.getString(CONFIG_CLASS_SUFFIX, null);
210
211 final String rawThresholdProb =
212 config.getString("classifier.ext.threshold.prob", null);
213 if (StringUtils.isNotEmpty(rawThresholdProb)) {
214
215 thickThresholdProb = Util.asDouble(rawThresholdProb);
216 } else {
217 thickThresholdProb = Double.NaN;
218 }
219
220 final String rawThresholdPR =
221 config.getString("classifier.ext.threshold.pR", null);
222 if (StringUtils.isNotEmpty(rawThresholdPR)) {
223
224 thickThresholdPR = Util.asDouble(rawThresholdPR);
225 } else {
226 thickThresholdPR = Double.NaN;
227 }
228
229
230 if (runDirectory != null) {
231 workDir = runDirectory;
232 } else if (config.containsKey(CONFIG_DIR)) {
233 workDir = new File(config.getString(CONFIG_DIR));
234 } else {
235 workDir = null;
236 }
237
238
239 extClassifier = new ExternalCommand(cmdClassify, workDir);
240 extTrainer = new ExternalCommand(cmdTrain, workDir);
241
242 if (cmdInit.length > 0) {
243 extInitializer = new ExternalCommand(cmdInit, workDir);
244 } else {
245 extInitializer = null;
246 }
247
248 if (cmdReset.length > 0) {
249 extResetter = new ExternalCommand(cmdReset, workDir);
250 } else {
251 extResetter = null;
252 }
253 }
254
255 /***
256 * Helper method building a class name by appending the configured suffix,
257 * if any.
258 *
259 * @param baseName the base name of the class
260 * @return the complete class name
261 */
262 private String buildClassName(final String baseName) {
263
264 if (classSuffix != null) {
265 return baseName + classSuffix;
266 } else {
267 return baseName;
268 }
269 }
270
271 /***
272 * Classifies an item that is represented by a feature vector by choosing
273 * the most probable class among a set of candidate classes.
274 *
275 * @param features the feature vector to consider
276 * @param candidateClasses an array of the classes that are allowed for
277 * this item
278 * @param context ignored by this implementation
279 * @return the result of the classification; you can call
280 * {@link PredictionDistribution#best()} to get the most probably class;
281 * this classifier returns only the best prediction, so
282 * {@link PredictionDistribution#size()} will be 1
283 * @throws ProcessingException if an I/O error occurs during communication
284 * with the external program
285 */
286 protected PredictionDistribution doClassify(final FeatureVector features,
287 final Set candidateClasses, final ContextMap context)
288 throws ProcessingException {
289 if (!initialized) {
290 initialize();
291 }
292
293
294 final StringBuffer candidateBuffer = new StringBuffer();
295 final Iterator candidateIter = candidateClasses.iterator();
296
297 while (candidateIter.hasNext()) {
298 candidateBuffer.append(
299 buildClassName((String) candidateIter.next()));
300
301 if (candidateIter.hasNext()) {
302
303 candidateBuffer.append(' ');
304 }
305 }
306 final String[] furtherArg = new String[] {candidateBuffer.toString()};
307 final String output;
308
309
310 try {
311 output = extClassifier.execute(furtherArg, features.flatten());
312 } catch (IOException ioe) {
313
314 throw new ProcessingException("I/O error while classifying", ioe);
315 }
316
317
318 final Matcher outputMatcher = predictionPattern.matcher(output);
319 final PredictionDistribution predDist = new PredictionDistribution();
320 final Set<String> addedClasses = new HashSet<String>();
321 Prediction pred;
322 String rawPredictedClass, predictedClass;
323 String rawProbability;
324 double probability;
325 String rawPR;
326
327
328 while (outputMatcher.find()) {
329
330
331
332
333 if (outputMatcher.groupCount() < 2) {
334 throw new IllegalArgumentException("Extraction pattern '"
335 + predictionPattern.pattern()
336 + "' should match 2 or 3 subgroups but matched only "
337 + outputMatcher.groupCount() + " in classifier output: '"
338 + output + "'");
339 }
340
341
342 rawPredictedClass = outputMatcher.group(1);
343
344 if ((classSuffix != null)
345 && rawPredictedClass.endsWith(classSuffix)) {
346
347 predictedClass = rawPredictedClass.substring(0,
348 rawPredictedClass.length() - classSuffix.length());
349 } else {
350 predictedClass = rawPredictedClass;
351 }
352
353 rawProbability = outputMatcher.group(2);
354 probability = Util.asDouble(rawProbability);
355
356
357 final double pR;
358 if (outputMatcher.groupCount() >= 3) {
359 rawPR = outputMatcher.group(3);
360 pR = Util.asDouble(rawPR);
361 } else {
362 pR = Double.NaN;
363 }
364
365
366 if (!addedClasses.contains(predictedClass)) {
367
368 addedClasses.add(predictedClass);
369 pred = new Prediction(predictedClass,
370 new Probability(probability, pR));
371 predDist.add(pred);
372 }
373 }
374
375
376 if (predDist.size() < 1) {
377 throw new IllegalArgumentException(
378 "No match found for extraction pattern '"
379 + predictionPattern.pattern() + "' in classifier output: '"
380 + output + "'");
381 }
382
383
384
385 return predDist;
386 }
387
388 /***
389 * {@inheritDoc}
390 */
391 protected void doTrain(final FeatureVector features,
392 final String targetClass, final ContextMap context)
393 throws ProcessingException {
394 if (!initialized) {
395 initialize();
396 }
397
398
399 final String[] furtherArg = new String[] {buildClassName(targetClass)};
400
401
402
403 try {
404 extTrainer.execute(furtherArg, features.flatten());
405 } catch (IOException ioe) {
406
407 throw new ProcessingException("I/O error while training", ioe);
408 }
409 }
410
411 /***
412 * Initializes all classes.
413 * @throws ProcessingException if an I/O error occurs during initialization
414 */
415 private void initialize() throws ProcessingException {
416 if (extInitializer != null) {
417
418 final Iterator classIter = getAllClasses().iterator();
419
420 while (classIter.hasNext()) {
421 init((String) classIter.next());
422 }
423 }
424
425 initialized = true;
426 }
427
428 /***
429 * Initializes a class. This method is called once for each class.
430 *
431 * @param cls the class to initialize
432 * @throws ProcessingException if an I/O error occurs during initialization
433 */
434 private void init(final String cls) throws ProcessingException {
435
436 final String[] furtherArg = new String[] {buildClassName(cls)};
437
438
439 try {
440 extInitializer.execute(furtherArg);
441 } catch (IOException ioe) {
442
443 throw new ProcessingException("I/O error while initializing the "
444 + cls + " class", ioe);
445 }
446 }
447
448 /***
449 * {@inheritDoc}
450 */
451 public void reset() throws ProcessingException {
452 if (extResetter != null) {
453
454 final Iterator classIter = getAllClasses().iterator();
455
456 while (classIter.hasNext()) {
457 resetClass((String) classIter.next());
458 }
459 }
460
461
462 initialized = false;
463 }
464
465 /***
466 * Resets a class.
467 *
468 * @param cls the class to reset
469 * @throws ProcessingException if an I/O error occurs during reset
470 */
471 private void resetClass(final String cls) throws ProcessingException {
472
473 final String[] furtherArg = new String[] {buildClassName(cls)};
474
475
476 try {
477 extResetter.execute(furtherArg);
478 } catch (IOException ioe) {
479
480 throw new ProcessingException("I/O error while resetting the "
481 + cls + " class", ioe);
482 }
483 }
484
485 /***
486 * This implementation uses reinforcement training, if a thick threshold
487 * is configured.
488 * {@inheritDoc}
489 */
490 protected boolean shouldTrain(final String targetClass,
491 final PredictionDistribution predDist, final ContextMap context) {
492 final Prediction best = predDist.best();
493 final Probability bestProb = best.getProbability();
494
495 if (super.shouldTrain(targetClass, predDist, context)) {
496
497 return true;
498 } else if (!Double.isNaN(thickThresholdPR)
499 && !Double.isNaN(bestProb.getPR())
500 && bestProb.getPR() < thickThresholdPR) {
501 Util.LOG.debug("Reinforcement training because pR "
502 + bestProb.getPR() + " is below the pR threshold "
503 + thickThresholdPR);
504 return true;
505 } else if (!Double.isNaN(thickThresholdProb)
506 && bestProb.getProb() < thickThresholdProb) {
507 Util.LOG.debug("Reinforcement training because probability "
508 + best.getProbability() + " is below the threshold "
509 + thickThresholdProb);
510 return true;
511 } else {
512
513 return false;
514 }
515 }
516
517 /***
518 * Returns a string representation of this object.
519 *
520 * @return a textual representation
521 */
522 public String toString() {
523 final ToStringBuilder builder = new ToStringBuilder(this)
524 .appendSuper(super.toString())
525 .append("classify command", extClassifier)
526 .append("train command", extTrainer)
527 .append("init command", extInitializer)
528 .append("reset command", extResetter)
529 .append("class suffix", classSuffix)
530 .append("prediction pattern", predictionPattern.pattern());
531
532 if (!Double.isNaN(thickThresholdProb)) {
533 builder.append("thick threshold prob.", thickThresholdProb);
534 }
535 if (!Double.isNaN(thickThresholdPR)) {
536 builder.append("thick threshold pR", thickThresholdPR);
537 }
538
539 return builder.toString();
540 }
541
542 }