View Javadoc

1   /*
2    * Copyright (C) 2004-2006 Christian Siefkes <christian@siefkes.net>.
3    * Development of this software is supported by the German Research Society,
4    * Berlin-Brandenburg Graduate School in Distributed Information Systems
5    * (DFG grant no. GRK 316).
6    *
7    * This program is free software; you can redistribute it and/or modify
8    * it under the terms of the GNU General Public License as published by
9    * the Free Software Foundation; either version 2 of the License, or
10   * (at your option) any later version.
11   *
12   * This program is distributed in the hope that it will be useful,
13   * but WITHOUT ANY WARRANTY; without even the implied warranty of
14   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15   * GNU General Public License for more details.
16   *
17   * You should have received a copy of the GNU General Public License
18   * along with this program; if not, visit
19   * http://www.gnu.org/licenses/gpl.html or write to the Free Software
20   * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
21   */
22  package de.fu_berlin.ties.classify;
23  
24  import java.io.File;
25  import java.io.IOException;
26  import java.io.Reader;
27  import java.io.Writer;
28  import java.util.HashSet;
29  import java.util.Iterator;
30  import java.util.Set;
31  
32  import org.apache.commons.lang.StringUtils;
33  
34  import de.fu_berlin.ties.Closeable;
35  import de.fu_berlin.ties.ContextMap;
36  import de.fu_berlin.ties.ProcessingException;
37  import de.fu_berlin.ties.TextProcessor;
38  import de.fu_berlin.ties.TiesConfiguration;
39  import de.fu_berlin.ties.classify.feature.FeatureExtractor;
40  import de.fu_berlin.ties.classify.feature.FeatureExtractorFactory;
41  import de.fu_berlin.ties.classify.feature.FeatureVector;
42  import de.fu_berlin.ties.eval.Accuracy;
43  import de.fu_berlin.ties.extract.TrainEval;
44  import de.fu_berlin.ties.io.FieldContainer;
45  import de.fu_berlin.ties.io.FieldMap;
46  import de.fu_berlin.ties.io.IOUtils;
47  import de.fu_berlin.ties.io.ObjectElement;
48  import de.fu_berlin.ties.util.Util;
49  
50  /***
51   * Classifies a list of files, training the classifier on each error if the
52   * true class is provided. See
53   * {@link #classifyAndTrain(FieldContainer, File, String, String)} for a
54   * description of input and output formats.
55   *
56   * <p>This class does not calculate statistics; you can do so be calling e.g.
57   * <code>tail -q --lines 500 <em>FILENAME</em>|grep -v "|+"|wc</code> on the
58   * output serialized in {@link de.fu_berlin.ties.io.DelimSepValues} format to
59   * get the number of errors during the last 500 classifications (assuming that
60   * classes to not start with a "+" and that the true class is known for all
61   * files).
62   *
63   * <p>Instances of this class are not thread-safe and must be synchronized
64   * externally, if required.
65   *
66   * @author Christian Siefkes
67   * @version $Revision: 1.35 $, $Date: 2006/10/21 16:03:54 $, $Author: siefkes $
68   */
69  public class ClassTrain extends TextProcessor implements Closeable {
70  
71      /***
72       * Configuration key: The extension to append to file names given via the
73       * {@linkplain #KEY_FILE File key} (if any).
74       */
75      public static final String CONFIG_FILE_EXT = "file.ext";
76  
77      /***
78       * Configuration suffix used for text classification--specific settings.
79       */
80      public static final String CONFIG_SUFFIX_TEXT = "text";
81  
82      /***
83       * Serialization key for the name of the file to classify.
84       */
85      public static final String KEY_FILE = "File";
86  
87      /***
88       * Serialization key for the correct class.
89       */
90      public static final String KEY_CLASS = "Class";
91  
92      /***
93       * Serialization key for the result of the classification: either
94       * {@link #CORRECT_CLASS} if the correct class was predicted or the
95       * wrongly predicted class in case of an error.
96       */
97      public static final String KEY_CLASSIFICATION = "Classification";
98  
99      /***
100      * Value of the {@link #KEY_CLASSIFICATION} field for correct predictions:
101      * {@value}.
102      */
103     public static final String CORRECT_CLASS = "+";
104 
105     /***
106      * Used to convert text sequences into feature vectors.
107      */
108     private final FeatureExtractor featureExtractor;
109 
110     /***
111      * The extension to append to file names given via the
112      * {@linkplain #KEY_FILE File key}; empty string if none.
113      */
114     private final String fileExtension;
115 
116     /***
117      * The {@link #classifierFileName} is resolved relative to this directory;
118      * if <code>null</code>, the working directory is used.
119      */
120     private File classifierDirectory;
121 
122     /***
123      * Name of the file used for storing the classifier.
124      */
125     private final String classifierFileName;
126 
127     /***
128      * Whether to re-use classifiers between several runs (including classifiers
129      * stored in the {@linkplain #classifierFileName classifier file},
130      * if exists).
131      */
132     private final boolean reUse;
133 
134     /***
135      * Whether to store the final classifier in the
136      * {@linkplain #classifierFileName classifier file}.
137      */
138     private final boolean store;
139 
140     /***
141      * If this is set to <code>true</code>, the classifier will be used only
142      * for prediction -- no training will take place.
143      */
144     private final boolean testOnly;
145 
146     /***
147      * The classifier used by this instance.
148      */
149     private TrainableClassifier classifier = null;
150 
151     /***
152      * Used for TUNE training (iterative training).
153      */
154     private final Tuner tuner;
155 
156 
157     /***
158      * Creates a new instance using a default extension and the
159      * {@link TiesConfiguration#CONF standard configuration}.
160      *
161      * @throws ProcessingException if an error occurs while initializing this
162      * instance
163      */
164     public ClassTrain() throws ProcessingException {
165         this("cls");
166     }
167 
168     /***
169      * Creates a new instance using the
170      * {@link TiesConfiguration#CONF standard configuration}.
171      *
172      * @param outExt the extension to use for output files
173      * @throws ProcessingException if an error occurs while initializing this
174      * instance
175      */
176     public ClassTrain(final String outExt) throws ProcessingException {
177         this(outExt, TiesConfiguration.CONF);
178     }
179 
180 
181     /***
182      * Creates a new instance from the provided configuration.
183      *
184      * @param outExt the extension to use for output files
185      * @param conf used to configure this instance; if <code>null</code>,
186      * the {@linkplain TiesConfiguration#CONF standard configuration} is used
187      * @throws ProcessingException if an error occurs while initializing this
188      * instance
189      */
190     public ClassTrain(final String outExt, final TiesConfiguration conf)
191     throws ProcessingException {
192         this(outExt, conf,
193             FeatureExtractorFactory.createExtractor(conf,
194                     Classifier.CONFIG_CLASSIFIER),
195             new Tuner(conf, CONFIG_SUFFIX_TEXT),
196             conf.getString(CONFIG_FILE_EXT, ""),
197             conf.getString("classifier.file"),
198             conf.getBoolean("classifier.re-use"),
199             conf.getBoolean("classifier.store"),
200             conf.getBoolean("classifier.test-only"));
201     }
202 
203     /***
204      * Creates a new instance.
205      *
206      * @param outExt the extension to use for output files
207      * @param conf used to configure this instance; if <code>null</code>,
208      * the {@linkplain TiesConfiguration#CONF standard configuration} is used
209      * @param featureExt used to convert texts into feature vectors
210      * @param myTuner used to control TUNE training (iterative training)
211      * @param fileExt the extension to append to file names given via the
212      * {@linkplain #KEY_FILE File key}; <code>null</code> or the empty string
213      * if none should be appended
214      * @param classifierFile name of the file used for storing the classifier
215      * @param doReUse whether to re-use classifiers between several runs
216      * (incl. classifiers stored in the <code>classifierFile</code>, if exists)
217      * @param doStore whether to store the final classifier in the
218      * <code>classifierFile</code>
219      * @param doTestOnly If this is set to <code>true</code>, the classifier
220      * will be used only for prediction -- no training will take place
221      */
222     public ClassTrain(final String outExt, final TiesConfiguration conf,
223             final FeatureExtractor featureExt, final Tuner myTuner,
224             final String fileExt, final String classifierFile,
225             final boolean doReUse, final boolean doStore,
226             final boolean doTestOnly) {
227         super(outExt, conf);
228         featureExtractor = featureExt;
229         tuner = myTuner;
230         fileExtension = (fileExt != null) ? fileExt : "";
231         classifierFileName = classifierFile;
232         reUse = doReUse;
233         store = doStore;
234         testOnly = doTestOnly;
235     }
236 
237 
238     /***
239      * Classifies a list of files, training the classifier on each error if the
240      * true class is known.
241      *
242      * @param filesToClassify a field container of the files to process; each
243      * entry must contain a {@link #KEY_FILE} field giving the name of the file
244      * to classify; if it also contains a {@link #KEY_CLASS} field giving the
245      * true class of the file, the classifier is trained in case of an error
246      * @param directory file names are relative to this directory; if
247      * <code>null</code> they are relative to the working directory
248      * @param baseName the base name of the file listing the files to classify
249      * @param charset the character set of the files to process
250      * @return a field container of the classification results; in addition to
251      * the fields given above, each entry will contain the classification result
252      * in a {@link #KEY_CLASSIFICATION} field: {@link #CORRECT_CLASS} in
253      * case of a classification that is known to be correct (this requires that
254      * the true class is given in the {@link #KEY_CLASS} field, otherwise we
255      * don't know whether a prediction is correct); the name of the predicted
256      * class otherwise
257      * @throws IOException if an I/O error occurs
258      * @throws ProcessingException if an error occurs during processing
259      */
260     public FieldContainer classifyAndTrain(final FieldContainer filesToClassify,
261             final File directory, final String baseName, final String charset)
262     throws IOException, ProcessingException {
263         // try to load classifier from existing file if configured and necessary
264         if (reUse) {
265             if (classifier == null) {
266                 final File classifierFile =
267                     new File(classifierDirectory, classifierFileName);
268                 if (classifierFile.exists() && classifierFile.canRead()) {
269                     try {
270                         classifier = (TrainableClassifier)
271                             ObjectElement.createObject(classifierFile);
272                         Util.LOG.info("Restored classifier from "
273                                 + classifierFile);
274                     } catch (InstantiationException ie) {
275                         throw new ProcessingException(
276                                 "Deserialization of classifier failed: "
277                                 + ie.getMessage(), ie);
278                     }
279                 }
280             }
281         } else {
282             // no re-use: reset classifier so it will be re-initialized
283             classifier = null;
284         }
285 
286         final FieldContainer result =
287             FieldContainer.createFieldContainer(getConfig());
288         final FieldContainer accuracyStore =
289             FieldContainer.createFieldContainer(getConfig());
290         final int numFiles = filesToClassify.size();
291         FieldMap inMap;
292         String currentClass;
293         String[] filenames = new String[numFiles];
294         String[] classes = new String[numFiles];
295         final Iterator fileIter = filesToClassify.entryIterator();
296         // class set is determined from the training data
297         // if it's not already stored in the classifier
298         final Set<String> classSet = (classifier == null)
299                 ? new HashSet<String>() : classifier.getAllClasses();
300         int i = 0;
301 
302         // collect files to process and determine set of classes
303         while (fileIter.hasNext()) {
304             inMap = (FieldMap) fileIter.next();
305             filenames[i] = StringUtils.trimToNull((String) inMap.get(KEY_FILE));
306             currentClass =
307                 StringUtils.trimToNull((String) inMap.get(KEY_CLASS));
308 
309             if (classifier == null && currentClass != null) {
310                 classSet.add(currentClass);
311             }
312 
313             classes[i] = currentClass;
314             i++;
315         }
316 
317         // initialize classifier if necessary
318         if (classifier == null) {
319             classifier = TrainableClassifier.createClassifier(classSet,
320                     getConfig(), CONFIG_SUFFIX_TEXT);
321         }
322         if (!reUse) {
323             // no re-use: reset might be necessary if predictions are stored
324             // externally
325             classifier.reset();
326         }
327 
328         // reset TUNER state
329         tuner.reset();
330 
331         // prepare (TUNE) training
332         boolean continueTraining = true;
333         final int numTrainFiles = Math.round(tuner.getTrainSplit() * numFiles);
334         final int filesToUse;
335         Accuracy trainAccuracy, evalAccuracy;
336         FieldMap accuracyMap;
337 
338         if (tuner.getTestSplit() < 0) {
339             // use all remaining files for evaluation
340             filesToUse = numFiles;
341         } else {
342             filesToUse = Math.min(numFiles,
343                 Math.round((tuner.getTrainSplit() + tuner.getTestSplit())
344                         * numFiles));
345         }
346 
347         Util.LOG.debug("Using " + filesToUse + " of " + numFiles
348                 + " files: " + numTrainFiles + " for training, "
349                 + (filesToUse - numTrainFiles) + " only for evaluation");
350 
351         // TUNE training loop
352         for (int iteration = 1; continueTraining; iteration++) {
353             if (tuner.getTuneIterations() > 1) {
354                 Util.LOG.debug("Starting TUNE iteration " + iteration + "/" 
355                         + tuner.getTuneIterations()
356                         + " (will stop if no improvement for "
357                         + tuner.getTuneStop() + " iterations)");
358             }
359 
360             // reset accuracy for each TUNE iteration
361             trainAccuracy = new Accuracy();
362 
363             // classify and if possible + necessary train each listed file
364             for (i = 0; i < numTrainFiles; i++) {
365                 result.add(processFile(directory, filenames[i], classes[i],
366                         classSet, charset, trainAccuracy, true));
367             } // for i
368 
369             // serialize accuracy (if there is anything to serialize)
370             if ((trainAccuracy.getTrueCount()
371                     + trainAccuracy.getFalseCount()) > 0) {
372                 accuracyMap = trainAccuracy.storeFields();
373                 accuracyMap.put(TrainEval.KEY_TYPE, TrainEval.TYPE_TRAIN);
374                 accuracyMap.put(TrainEval.KEY_ITERATION, iteration);
375                 accuracyStore.add(accuracyMap);
376             }
377 
378             continueTraining = tuner.continueTraining(
379                     new double[] {trainAccuracy.getAccuracy()}, iteration);
380 
381             // evaluate on remaining files if required
382             if (tuner.shouldEvaluate(continueTraining, iteration)) {
383                 // reset accuracy for each iteration
384                 evalAccuracy = new Accuracy();
385 
386                 // evaluate remaining files
387                 for (; i < filesToUse; i++) {
388                     result.add(processFile(directory, filenames[i], classes[i],
389                             classSet, charset, evalAccuracy, false));
390                 } // for i
391 
392                 // serialize accuracy (if there is anything to serialize)
393                 if ((evalAccuracy.getTrueCount()
394                         + evalAccuracy.getFalseCount()) > 0) {
395                     accuracyMap = evalAccuracy.storeFields();
396                     accuracyMap.put(TrainEval.KEY_TYPE, TrainEval.TYPE_EVAL);
397                     accuracyMap.put(TrainEval.KEY_ITERATION, iteration);
398                     accuracyStore.add(accuracyMap);
399                 }
400 
401             } // if shouldEvaluate
402         } // for iteration
403 
404         // store accuracies in metrics file, if any
405         if (accuracyStore.size() > 0) {
406             final File accFile = IOUtils.createOutFile(classifierDirectory,
407                     baseName, TrainEval.EXT_METRICS);
408                 final Writer accWriter =
409                     IOUtils.openWriter(accFile, getConfig());
410                 accuracyStore.store(accWriter);
411                 accWriter.flush();
412                 accWriter.close();
413         }
414 
415         Util.LOG.debug("Finished classifying and training using "
416                 + classifier + " and " + featureExtractor);
417 
418         if (!reUse) {
419             // no re-use: destroy the classifier
420             classifier.destroy();
421             classifier = null;
422         }
423         return result;
424     }
425 
426     /***
427      * {@inheritDoc}
428      */
429     public void close(final int errorCount) throws IOException {
430         // no need to store if training is disabled or nothing was trained
431         if (store && !testOnly && (classifier != null)) {
432             // store the classifier unless an error occurred
433             if (errorCount == 0) {
434                 // serialize classifier in (compressed) XML
435                 final File classifierFile =
436                     new File(classifierDirectory, classifierFileName);
437                 classifier.toElement().store(classifierFile, getConfig());
438                 Util.LOG.info("Stored classifier in " + classifierFile);
439             } else {
440                 Util.LOG.warn(errorCount
441                         + " errors ocurred -- won't store the classifier");
442             }
443         }
444     }
445 
446     /***
447      * Delegates to
448      * {@link #classifyAndTrain(FieldContainer, File, String, String)}.
449      *
450      * @param reader the {@link FieldContainer} of files to classify is read
451      * from this reader; not closed by this method
452      * @param writer the resulting {@link FieldContainer} containing
453      * classification results is serialized to this writer; not closed by
454      * this method
455      * @param context a map of objects that are made available for processing;
456      * the {@link IOUtils#KEY_LOCAL_CHARSET} is used to determine the character
457      * set of the listed files; the {@link TextProcessor#KEY_DIRECTORY}
458      * {@link File} determines the source of relative file names, if given
459      * (otherwise the current working directory is used)
460      * @throws IOException if an I/O error occurs
461      * @throws ProcessingException if an error occurs during processing
462      */
463     protected void doProcess(final Reader reader, final Writer writer,
464                              final ContextMap context)
465             throws IOException, ProcessingException {
466         // read input + determine charset + directory (if any)
467         final FieldContainer filesToClassify =
468             FieldContainer.createFieldContainer(getConfig());
469         filesToClassify.read(reader);
470         final String charset = (String) context.get(IOUtils.KEY_LOCAL_CHARSET);
471         final File directory = (File) context.get(KEY_DIRECTORY);
472         final String baseName = (String) context.get(KEY_LOCAL_NAME);
473 
474         // classifier file is resolved relative to output directory
475         classifierDirectory = (File) context.get(KEY_OUT_DIRECTORY);
476 
477         // delegate to classifyAndTrain
478         final FieldContainer result =
479             classifyAndTrain(filesToClassify, directory, baseName, charset);
480 
481         // serialize results
482         result.store(writer);
483     }
484 
485     /***
486      * Helper method that processes a file.
487      *
488      * @param directory the directory containing the file
489      * @param filename the name of the file (without extension)
490      * @param currentClass the true class or the file; or <code>null</code>
491      * if not known
492      * @param classSet the set of classes to consider for classification
493      * @param charset the character set of the file
494      * @param accuracy will be updated if the true class of the file is known
495      * @param doTrain whether to train after classifying if the true class of
496      * the file is known
497      * @return the field map of information about the processed file
498      * @throws IOException if an I/O error occurs while reading the file
499      * @throws ProcessingException if an error occurs while processing the file
500      */
501     private FieldMap processFile(final File directory, final String filename,
502             final String currentClass, final Set<String> classSet,
503             final String charset, final Accuracy accuracy,
504             final boolean doTrain) throws IOException, ProcessingException {
505         // init field map to store results
506         final FieldMap outMap = new FieldMap();
507         outMap.put(KEY_FILE, filename);
508 
509         PredictionDistribution predDist;
510         Prediction best;
511         // read contents of file (relative to given directory if any)
512         final Reader reader = IOUtils.openReader(
513             new File(directory, filename + fileExtension), charset);
514 
515         try {
516             // extract features from text
517             final FeatureVector features =
518                 featureExtractor.buildFeatures(reader);
519 
520             if ((currentClass != null) && !testOnly && doTrain) {
521                 // true class known + training is allowed:
522                 // TOE train classifier + evaluate result
523                 predDist = classifier.trainOnError(features,
524                         currentClass, classSet);
525                 outMap.put(KEY_CLASS, currentClass);
526 
527                 if (predDist == null) {
528                     Util.LOG.debug("Processed " + filename
529                             + fileExtension + ": classification as "
530                             + currentClass + " was correct");
531                     outMap.put(KEY_CLASSIFICATION, CORRECT_CLASS);
532                     accuracy.incTrueCount();
533                 } else {
534                     best = predDist.best();
535                     Util.LOG.debug("Processed " + filename
536                             + fileExtension + ": misclassified as "
537                             + best.getType() + " instead of "
538                             + currentClass);
539                     outMap.put(KEY_CLASSIFICATION, best.getType());
540                     accuracy.incFalseCount();
541                 }
542             } else {
543                 // no training allowed/possible: invoke classifier
544                 predDist = classifier.classify(features, classSet);
545                 best = predDist.best();
546 
547                 if (currentClass != null) {
548                     outMap.put(KEY_CLASS, currentClass);
549 
550                     // evaluate by comparing with true class
551                     if (best.getType().equals(currentClass)) {
552                         Util.LOG.debug("Processed " + filename
553                                 + fileExtension + ": classification as "
554                                 + currentClass + " was correct");
555                         outMap.put(KEY_CLASSIFICATION, CORRECT_CLASS);
556                         accuracy.incTrueCount();
557                     } else {
558                         Util.LOG.debug("Processed " + filename
559                                 + fileExtension + ": misclassified as "
560                                 + best.getType() + " instead of "
561                                 + currentClass
562                                 + " (but training is disabled)");
563                         outMap.put(KEY_CLASSIFICATION, best.getType());
564                         accuracy.incFalseCount();
565                     }
566                 } else {
567                     // true class unknown: only store prediction
568                     Util.LOG.debug("Processed " + filename
569                             + fileExtension + ": classified as "
570                             + best.getType());
571                     outMap.put(KEY_CLASSIFICATION, best.getType());
572                 }
573             }
574         } finally {
575             IOUtils.tryToClose(reader);
576         }
577         return outMap;
578     }
579 
580 }