1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package de.fu_berlin.ties.classify;
23
24 import java.io.File;
25 import java.io.IOException;
26 import java.io.Reader;
27 import java.io.Writer;
28 import java.util.HashSet;
29 import java.util.Iterator;
30 import java.util.Set;
31
32 import org.apache.commons.lang.StringUtils;
33
34 import de.fu_berlin.ties.Closeable;
35 import de.fu_berlin.ties.ContextMap;
36 import de.fu_berlin.ties.ProcessingException;
37 import de.fu_berlin.ties.TextProcessor;
38 import de.fu_berlin.ties.TiesConfiguration;
39 import de.fu_berlin.ties.classify.feature.FeatureExtractor;
40 import de.fu_berlin.ties.classify.feature.FeatureExtractorFactory;
41 import de.fu_berlin.ties.classify.feature.FeatureVector;
42 import de.fu_berlin.ties.eval.Accuracy;
43 import de.fu_berlin.ties.extract.TrainEval;
44 import de.fu_berlin.ties.io.FieldContainer;
45 import de.fu_berlin.ties.io.FieldMap;
46 import de.fu_berlin.ties.io.IOUtils;
47 import de.fu_berlin.ties.io.ObjectElement;
48 import de.fu_berlin.ties.util.Util;
49
50 /***
51 * Classifies a list of files, training the classifier on each error if the
52 * true class is provided. See
53 * {@link #classifyAndTrain(FieldContainer, File, String, String)} for a
54 * description of input and output formats.
55 *
56 * <p>This class does not calculate statistics; you can do so be calling e.g.
57 * <code>tail -q --lines 500 <em>FILENAME</em>|grep -v "|+"|wc</code> on the
58 * output serialized in {@link de.fu_berlin.ties.io.DelimSepValues} format to
59 * get the number of errors during the last 500 classifications (assuming that
60 * classes to not start with a "+" and that the true class is known for all
61 * files).
62 *
63 * <p>Instances of this class are not thread-safe and must be synchronized
64 * externally, if required.
65 *
66 * @author Christian Siefkes
67 * @version $Revision: 1.35 $, $Date: 2006/10/21 16:03:54 $, $Author: siefkes $
68 */
69 public class ClassTrain extends TextProcessor implements Closeable {
70
71 /***
72 * Configuration key: The extension to append to file names given via the
73 * {@linkplain #KEY_FILE File key} (if any).
74 */
75 public static final String CONFIG_FILE_EXT = "file.ext";
76
77 /***
78 * Configuration suffix used for text classification--specific settings.
79 */
80 public static final String CONFIG_SUFFIX_TEXT = "text";
81
82 /***
83 * Serialization key for the name of the file to classify.
84 */
85 public static final String KEY_FILE = "File";
86
87 /***
88 * Serialization key for the correct class.
89 */
90 public static final String KEY_CLASS = "Class";
91
92 /***
93 * Serialization key for the result of the classification: either
94 * {@link #CORRECT_CLASS} if the correct class was predicted or the
95 * wrongly predicted class in case of an error.
96 */
97 public static final String KEY_CLASSIFICATION = "Classification";
98
99 /***
100 * Value of the {@link #KEY_CLASSIFICATION} field for correct predictions:
101 * {@value}.
102 */
103 public static final String CORRECT_CLASS = "+";
104
105 /***
106 * Used to convert text sequences into feature vectors.
107 */
108 private final FeatureExtractor featureExtractor;
109
110 /***
111 * The extension to append to file names given via the
112 * {@linkplain #KEY_FILE File key}; empty string if none.
113 */
114 private final String fileExtension;
115
116 /***
117 * The {@link #classifierFileName} is resolved relative to this directory;
118 * if <code>null</code>, the working directory is used.
119 */
120 private File classifierDirectory;
121
122 /***
123 * Name of the file used for storing the classifier.
124 */
125 private final String classifierFileName;
126
127 /***
128 * Whether to re-use classifiers between several runs (including classifiers
129 * stored in the {@linkplain #classifierFileName classifier file},
130 * if exists).
131 */
132 private final boolean reUse;
133
134 /***
135 * Whether to store the final classifier in the
136 * {@linkplain #classifierFileName classifier file}.
137 */
138 private final boolean store;
139
140 /***
141 * If this is set to <code>true</code>, the classifier will be used only
142 * for prediction -- no training will take place.
143 */
144 private final boolean testOnly;
145
146 /***
147 * The classifier used by this instance.
148 */
149 private TrainableClassifier classifier = null;
150
151 /***
152 * Used for TUNE training (iterative training).
153 */
154 private final Tuner tuner;
155
156
157 /***
158 * Creates a new instance using a default extension and the
159 * {@link TiesConfiguration#CONF standard configuration}.
160 *
161 * @throws ProcessingException if an error occurs while initializing this
162 * instance
163 */
164 public ClassTrain() throws ProcessingException {
165 this("cls");
166 }
167
168 /***
169 * Creates a new instance using the
170 * {@link TiesConfiguration#CONF standard configuration}.
171 *
172 * @param outExt the extension to use for output files
173 * @throws ProcessingException if an error occurs while initializing this
174 * instance
175 */
176 public ClassTrain(final String outExt) throws ProcessingException {
177 this(outExt, TiesConfiguration.CONF);
178 }
179
180
181 /***
182 * Creates a new instance from the provided configuration.
183 *
184 * @param outExt the extension to use for output files
185 * @param conf used to configure this instance; if <code>null</code>,
186 * the {@linkplain TiesConfiguration#CONF standard configuration} is used
187 * @throws ProcessingException if an error occurs while initializing this
188 * instance
189 */
190 public ClassTrain(final String outExt, final TiesConfiguration conf)
191 throws ProcessingException {
192 this(outExt, conf,
193 FeatureExtractorFactory.createExtractor(conf,
194 Classifier.CONFIG_CLASSIFIER),
195 new Tuner(conf, CONFIG_SUFFIX_TEXT),
196 conf.getString(CONFIG_FILE_EXT, ""),
197 conf.getString("classifier.file"),
198 conf.getBoolean("classifier.re-use"),
199 conf.getBoolean("classifier.store"),
200 conf.getBoolean("classifier.test-only"));
201 }
202
203 /***
204 * Creates a new instance.
205 *
206 * @param outExt the extension to use for output files
207 * @param conf used to configure this instance; if <code>null</code>,
208 * the {@linkplain TiesConfiguration#CONF standard configuration} is used
209 * @param featureExt used to convert texts into feature vectors
210 * @param myTuner used to control TUNE training (iterative training)
211 * @param fileExt the extension to append to file names given via the
212 * {@linkplain #KEY_FILE File key}; <code>null</code> or the empty string
213 * if none should be appended
214 * @param classifierFile name of the file used for storing the classifier
215 * @param doReUse whether to re-use classifiers between several runs
216 * (incl. classifiers stored in the <code>classifierFile</code>, if exists)
217 * @param doStore whether to store the final classifier in the
218 * <code>classifierFile</code>
219 * @param doTestOnly If this is set to <code>true</code>, the classifier
220 * will be used only for prediction -- no training will take place
221 */
222 public ClassTrain(final String outExt, final TiesConfiguration conf,
223 final FeatureExtractor featureExt, final Tuner myTuner,
224 final String fileExt, final String classifierFile,
225 final boolean doReUse, final boolean doStore,
226 final boolean doTestOnly) {
227 super(outExt, conf);
228 featureExtractor = featureExt;
229 tuner = myTuner;
230 fileExtension = (fileExt != null) ? fileExt : "";
231 classifierFileName = classifierFile;
232 reUse = doReUse;
233 store = doStore;
234 testOnly = doTestOnly;
235 }
236
237
238 /***
239 * Classifies a list of files, training the classifier on each error if the
240 * true class is known.
241 *
242 * @param filesToClassify a field container of the files to process; each
243 * entry must contain a {@link #KEY_FILE} field giving the name of the file
244 * to classify; if it also contains a {@link #KEY_CLASS} field giving the
245 * true class of the file, the classifier is trained in case of an error
246 * @param directory file names are relative to this directory; if
247 * <code>null</code> they are relative to the working directory
248 * @param baseName the base name of the file listing the files to classify
249 * @param charset the character set of the files to process
250 * @return a field container of the classification results; in addition to
251 * the fields given above, each entry will contain the classification result
252 * in a {@link #KEY_CLASSIFICATION} field: {@link #CORRECT_CLASS} in
253 * case of a classification that is known to be correct (this requires that
254 * the true class is given in the {@link #KEY_CLASS} field, otherwise we
255 * don't know whether a prediction is correct); the name of the predicted
256 * class otherwise
257 * @throws IOException if an I/O error occurs
258 * @throws ProcessingException if an error occurs during processing
259 */
260 public FieldContainer classifyAndTrain(final FieldContainer filesToClassify,
261 final File directory, final String baseName, final String charset)
262 throws IOException, ProcessingException {
263
264 if (reUse) {
265 if (classifier == null) {
266 final File classifierFile =
267 new File(classifierDirectory, classifierFileName);
268 if (classifierFile.exists() && classifierFile.canRead()) {
269 try {
270 classifier = (TrainableClassifier)
271 ObjectElement.createObject(classifierFile);
272 Util.LOG.info("Restored classifier from "
273 + classifierFile);
274 } catch (InstantiationException ie) {
275 throw new ProcessingException(
276 "Deserialization of classifier failed: "
277 + ie.getMessage(), ie);
278 }
279 }
280 }
281 } else {
282
283 classifier = null;
284 }
285
286 final FieldContainer result =
287 FieldContainer.createFieldContainer(getConfig());
288 final FieldContainer accuracyStore =
289 FieldContainer.createFieldContainer(getConfig());
290 final int numFiles = filesToClassify.size();
291 FieldMap inMap;
292 String currentClass;
293 String[] filenames = new String[numFiles];
294 String[] classes = new String[numFiles];
295 final Iterator fileIter = filesToClassify.entryIterator();
296
297
298 final Set<String> classSet = (classifier == null)
299 ? new HashSet<String>() : classifier.getAllClasses();
300 int i = 0;
301
302
303 while (fileIter.hasNext()) {
304 inMap = (FieldMap) fileIter.next();
305 filenames[i] = StringUtils.trimToNull((String) inMap.get(KEY_FILE));
306 currentClass =
307 StringUtils.trimToNull((String) inMap.get(KEY_CLASS));
308
309 if (classifier == null && currentClass != null) {
310 classSet.add(currentClass);
311 }
312
313 classes[i] = currentClass;
314 i++;
315 }
316
317
318 if (classifier == null) {
319 classifier = TrainableClassifier.createClassifier(classSet,
320 getConfig(), CONFIG_SUFFIX_TEXT);
321 }
322 if (!reUse) {
323
324
325 classifier.reset();
326 }
327
328
329 tuner.reset();
330
331
332 boolean continueTraining = true;
333 final int numTrainFiles = Math.round(tuner.getTrainSplit() * numFiles);
334 final int filesToUse;
335 Accuracy trainAccuracy, evalAccuracy;
336 FieldMap accuracyMap;
337
338 if (tuner.getTestSplit() < 0) {
339
340 filesToUse = numFiles;
341 } else {
342 filesToUse = Math.min(numFiles,
343 Math.round((tuner.getTrainSplit() + tuner.getTestSplit())
344 * numFiles));
345 }
346
347 Util.LOG.debug("Using " + filesToUse + " of " + numFiles
348 + " files: " + numTrainFiles + " for training, "
349 + (filesToUse - numTrainFiles) + " only for evaluation");
350
351
352 for (int iteration = 1; continueTraining; iteration++) {
353 if (tuner.getTuneIterations() > 1) {
354 Util.LOG.debug("Starting TUNE iteration " + iteration + "/"
355 + tuner.getTuneIterations()
356 + " (will stop if no improvement for "
357 + tuner.getTuneStop() + " iterations)");
358 }
359
360
361 trainAccuracy = new Accuracy();
362
363
364 for (i = 0; i < numTrainFiles; i++) {
365 result.add(processFile(directory, filenames[i], classes[i],
366 classSet, charset, trainAccuracy, true));
367 }
368
369
370 if ((trainAccuracy.getTrueCount()
371 + trainAccuracy.getFalseCount()) > 0) {
372 accuracyMap = trainAccuracy.storeFields();
373 accuracyMap.put(TrainEval.KEY_TYPE, TrainEval.TYPE_TRAIN);
374 accuracyMap.put(TrainEval.KEY_ITERATION, iteration);
375 accuracyStore.add(accuracyMap);
376 }
377
378 continueTraining = tuner.continueTraining(
379 new double[] {trainAccuracy.getAccuracy()}, iteration);
380
381
382 if (tuner.shouldEvaluate(continueTraining, iteration)) {
383
384 evalAccuracy = new Accuracy();
385
386
387 for (; i < filesToUse; i++) {
388 result.add(processFile(directory, filenames[i], classes[i],
389 classSet, charset, evalAccuracy, false));
390 }
391
392
393 if ((evalAccuracy.getTrueCount()
394 + evalAccuracy.getFalseCount()) > 0) {
395 accuracyMap = evalAccuracy.storeFields();
396 accuracyMap.put(TrainEval.KEY_TYPE, TrainEval.TYPE_EVAL);
397 accuracyMap.put(TrainEval.KEY_ITERATION, iteration);
398 accuracyStore.add(accuracyMap);
399 }
400
401 }
402 }
403
404
405 if (accuracyStore.size() > 0) {
406 final File accFile = IOUtils.createOutFile(classifierDirectory,
407 baseName, TrainEval.EXT_METRICS);
408 final Writer accWriter =
409 IOUtils.openWriter(accFile, getConfig());
410 accuracyStore.store(accWriter);
411 accWriter.flush();
412 accWriter.close();
413 }
414
415 Util.LOG.debug("Finished classifying and training using "
416 + classifier + " and " + featureExtractor);
417
418 if (!reUse) {
419
420 classifier.destroy();
421 classifier = null;
422 }
423 return result;
424 }
425
426 /***
427 * {@inheritDoc}
428 */
429 public void close(final int errorCount) throws IOException {
430
431 if (store && !testOnly && (classifier != null)) {
432
433 if (errorCount == 0) {
434
435 final File classifierFile =
436 new File(classifierDirectory, classifierFileName);
437 classifier.toElement().store(classifierFile, getConfig());
438 Util.LOG.info("Stored classifier in " + classifierFile);
439 } else {
440 Util.LOG.warn(errorCount
441 + " errors ocurred -- won't store the classifier");
442 }
443 }
444 }
445
446 /***
447 * Delegates to
448 * {@link #classifyAndTrain(FieldContainer, File, String, String)}.
449 *
450 * @param reader the {@link FieldContainer} of files to classify is read
451 * from this reader; not closed by this method
452 * @param writer the resulting {@link FieldContainer} containing
453 * classification results is serialized to this writer; not closed by
454 * this method
455 * @param context a map of objects that are made available for processing;
456 * the {@link IOUtils#KEY_LOCAL_CHARSET} is used to determine the character
457 * set of the listed files; the {@link TextProcessor#KEY_DIRECTORY}
458 * {@link File} determines the source of relative file names, if given
459 * (otherwise the current working directory is used)
460 * @throws IOException if an I/O error occurs
461 * @throws ProcessingException if an error occurs during processing
462 */
463 protected void doProcess(final Reader reader, final Writer writer,
464 final ContextMap context)
465 throws IOException, ProcessingException {
466
467 final FieldContainer filesToClassify =
468 FieldContainer.createFieldContainer(getConfig());
469 filesToClassify.read(reader);
470 final String charset = (String) context.get(IOUtils.KEY_LOCAL_CHARSET);
471 final File directory = (File) context.get(KEY_DIRECTORY);
472 final String baseName = (String) context.get(KEY_LOCAL_NAME);
473
474
475 classifierDirectory = (File) context.get(KEY_OUT_DIRECTORY);
476
477
478 final FieldContainer result =
479 classifyAndTrain(filesToClassify, directory, baseName, charset);
480
481
482 result.store(writer);
483 }
484
485 /***
486 * Helper method that processes a file.
487 *
488 * @param directory the directory containing the file
489 * @param filename the name of the file (without extension)
490 * @param currentClass the true class or the file; or <code>null</code>
491 * if not known
492 * @param classSet the set of classes to consider for classification
493 * @param charset the character set of the file
494 * @param accuracy will be updated if the true class of the file is known
495 * @param doTrain whether to train after classifying if the true class of
496 * the file is known
497 * @return the field map of information about the processed file
498 * @throws IOException if an I/O error occurs while reading the file
499 * @throws ProcessingException if an error occurs while processing the file
500 */
501 private FieldMap processFile(final File directory, final String filename,
502 final String currentClass, final Set<String> classSet,
503 final String charset, final Accuracy accuracy,
504 final boolean doTrain) throws IOException, ProcessingException {
505
506 final FieldMap outMap = new FieldMap();
507 outMap.put(KEY_FILE, filename);
508
509 PredictionDistribution predDist;
510 Prediction best;
511
512 final Reader reader = IOUtils.openReader(
513 new File(directory, filename + fileExtension), charset);
514
515 try {
516
517 final FeatureVector features =
518 featureExtractor.buildFeatures(reader);
519
520 if ((currentClass != null) && !testOnly && doTrain) {
521
522
523 predDist = classifier.trainOnError(features,
524 currentClass, classSet);
525 outMap.put(KEY_CLASS, currentClass);
526
527 if (predDist == null) {
528 Util.LOG.debug("Processed " + filename
529 + fileExtension + ": classification as "
530 + currentClass + " was correct");
531 outMap.put(KEY_CLASSIFICATION, CORRECT_CLASS);
532 accuracy.incTrueCount();
533 } else {
534 best = predDist.best();
535 Util.LOG.debug("Processed " + filename
536 + fileExtension + ": misclassified as "
537 + best.getType() + " instead of "
538 + currentClass);
539 outMap.put(KEY_CLASSIFICATION, best.getType());
540 accuracy.incFalseCount();
541 }
542 } else {
543
544 predDist = classifier.classify(features, classSet);
545 best = predDist.best();
546
547 if (currentClass != null) {
548 outMap.put(KEY_CLASS, currentClass);
549
550
551 if (best.getType().equals(currentClass)) {
552 Util.LOG.debug("Processed " + filename
553 + fileExtension + ": classification as "
554 + currentClass + " was correct");
555 outMap.put(KEY_CLASSIFICATION, CORRECT_CLASS);
556 accuracy.incTrueCount();
557 } else {
558 Util.LOG.debug("Processed " + filename
559 + fileExtension + ": misclassified as "
560 + best.getType() + " instead of "
561 + currentClass
562 + " (but training is disabled)");
563 outMap.put(KEY_CLASSIFICATION, best.getType());
564 accuracy.incFalseCount();
565 }
566 } else {
567
568 Util.LOG.debug("Processed " + filename
569 + fileExtension + ": classified as "
570 + best.getType());
571 outMap.put(KEY_CLASSIFICATION, best.getType());
572 }
573 }
574 } finally {
575 IOUtils.tryToClose(reader);
576 }
577 return outMap;
578 }
579
580 }