package de.fu_berlin.ties.classify;

import de.fu_berlin.ties.Closeable;
import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TextProcessor;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.feature.FeatureExtractor;
import de.fu_berlin.ties.classify.feature.FeatureExtractorFactory;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.eval.Accuracy;
import de.fu_berlin.ties.extract.TrainEval;
import de.fu_berlin.ties.io.FieldContainer;
import de.fu_berlin.ties.io.FieldMap;
import de.fu_berlin.ties.io.IOUtils;
import de.fu_berlin.ties.io.ObjectElement;
import de.fu_berlin.ties.util.Util;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.Writer;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.apache.commons.configuration.Configuration;
import org.apache.commons.lang.StringUtils;

/* loaded from: input_file:de/fu_berlin/ties/classify/ClassTrain.class */
/**
 * Classifies a list of files and trains a {@link TrainableClassifier} on them.
 *
 * <p>The input (read in {@link #doProcess(Reader, Writer, ContextMap)}) is a
 * {@link FieldContainer} whose rows name a file in the {@value #KEY_FILE}
 * column and, optionally, its correct class in the {@value #KEY_CLASS}
 * column. For each processed file the output contains a row with the file
 * name, the expected class (if known) and the {@value #KEY_CLASSIFICATION}
 * value: either {@value #CORRECT_CLASS} if the expected class was predicted,
 * or the predicted class otherwise. Per-iteration accuracy values are written
 * to a separate "metrics" output file.</p>
 */
public class ClassTrain extends TextProcessor implements Closeable {

    /** Configuration key of the extension appended to each listed file name. */
    public static final String CONFIG_FILE_EXT = "file.ext";

    /** Configuration suffix passed to the tuner and to the classifier factory. */
    public static final String CONFIG_SUFFIX_TEXT = "text";

    /** Name of the column holding the file name. */
    public static final String KEY_FILE = "File";

    /** Name of the column holding the correct class of a file, if known. */
    public static final String KEY_CLASS = "Class";

    /** Name of the output column holding the classification result. */
    public static final String KEY_CLASSIFICATION = "Classification";

    /** Value stored in {@link #KEY_CLASSIFICATION} when the prediction was correct. */
    public static final String CORRECT_CLASS = "+";

    /** Converts the contents of each file into a feature vector. */
    private final FeatureExtractor featureExtractor;

    /** Extension appended to each listed file name (never {@code null}). */
    private final String fileExtension;

    /**
     * Directory the classifier is restored from / stored in and the metrics
     * file is written to; set from the context in {@link #doProcess}.
     */
    private File classifierDirectory;

    /** Name of the classifier file within {@link #classifierDirectory}. */
    private final String classifierFileName;

    /** Whether the classifier is restored from disk and kept across calls. */
    private final boolean reUse;

    /** Whether {@link #close(int)} should store the classifier. */
    private final boolean store;

    /** If true, files are only classified (never trained on) and the classifier is never stored. */
    private final boolean testOnly;

    /** The classifier currently in use, or {@code null} if none. */
    private TrainableClassifier classifier;

    /** Provides the training/evaluation splits and controls tuning iterations. */
    private final Tuner tuner;

    /**
     * Creates an instance passing {@code "cls"} to the superclass and using
     * the default configuration.
     *
     * @throws ProcessingException if the feature extractor cannot be created
     */
    public ClassTrain() throws ProcessingException {
        this("cls");
    }

    /**
     * Creates an instance using the default configuration.
     *
     * @param str passed to the {@link TextProcessor} superclass (presumably the
     *     output extension -- TODO confirm)
     * @throws ProcessingException if the feature extractor cannot be created
     */
    public ClassTrain(String str) throws ProcessingException {
        this(str, TiesConfiguration.CONF);
    }

    /**
     * Creates an instance reading all settings from the given configuration
     * (keys {@value #CONFIG_FILE_EXT}, "classifier.file", "classifier.re-use",
     * "classifier.store" and "classifier.test-only").
     *
     * @param str passed to the {@link TextProcessor} superclass
     * @param tiesConfiguration the configuration to use
     * @throws ProcessingException if the feature extractor cannot be created
     */
    public ClassTrain(String str, TiesConfiguration tiesConfiguration) throws ProcessingException {
        this(str, tiesConfiguration, FeatureExtractorFactory.createExtractor(tiesConfiguration, Classifier.CONFIG_CLASSIFIER), new Tuner(tiesConfiguration, CONFIG_SUFFIX_TEXT), tiesConfiguration.getString(CONFIG_FILE_EXT, ""), tiesConfiguration.getString("classifier.file"), tiesConfiguration.getBoolean("classifier.re-use"), tiesConfiguration.getBoolean("classifier.store"), tiesConfiguration.getBoolean("classifier.test-only"));
    }

    /**
     * Creates an instance from explicit settings.
     *
     * @param str passed to the {@link TextProcessor} superclass
     * @param tiesConfiguration the configuration to use
     * @param featureExtractor converts file contents into feature vectors
     * @param tuner controls data splits and tuning iterations
     * @param str2 extension appended to each listed file name; {@code null} is treated as ""
     * @param str3 name of the classifier file to restore from / store to
     * @param z whether to re-use (restore and keep) the classifier
     * @param z2 whether to store the classifier on {@link #close(int)}
     * @param z3 whether to only classify without training
     */
    public ClassTrain(String str, TiesConfiguration tiesConfiguration, FeatureExtractor featureExtractor, Tuner tuner, String str2, String str3, boolean z, boolean z2, boolean z3) {
        super(str, tiesConfiguration);
        this.classifier = null;
        this.featureExtractor = featureExtractor;
        this.tuner = tuner;
        this.fileExtension = str2 != null ? str2 : "";
        this.classifierFileName = str3;
        this.reUse = z;
        this.store = z2;
        this.testOnly = z3;
    }

    /**
     * Classifies the files listed in a field container and trains the
     * classifier on the first part of them, as configured by the tuner.
     *
     * <p>Files up to the training split are trained on (unless in test-only
     * mode or their class is unknown); the following files up to the test
     * split are only classified, whenever the tuner asks for evaluation. The
     * training loop repeats as long as the tuner says to continue.</p>
     *
     * @param fieldContainer rows with {@value #KEY_FILE} and optionally {@value #KEY_CLASS}
     * @param file directory containing the listed files
     * @param str local name used for the metrics output file
     * @param str2 character set used to read the listed files
     * @return a container with one result row per processed file (per iteration)
     * @throws IOException if reading a file or writing the metrics fails
     * @throws ProcessingException if classification, training or classifier restoration fails
     */
    public FieldContainer classifyAndTrain(FieldContainer fieldContainer, File file, String str, String str2) throws IOException, ProcessingException {
        prepareClassifier();

        FieldContainer results = FieldContainer.createFieldContainer(getConfig());
        FieldContainer metrics = FieldContainer.createFieldContainer(getConfig());
        int size = fieldContainer.size();
        String[] fileNames = new String[size];
        String[] classNames = new String[size];

        // A restored classifier already knows its classes; otherwise collect
        // them from the input so a new classifier can be created.
        Set<String> classes = this.classifier == null ? new HashSet<>() : this.classifier.getAllClasses();
        Iterator<FieldMap> entries = fieldContainer.entryIterator();
        int index = 0;
        while (entries.hasNext()) {
            FieldMap entry = entries.next();
            fileNames[index] = StringUtils.trimToNull((String) entry.get(KEY_FILE));
            String className = StringUtils.trimToNull((String) entry.get(KEY_CLASS));
            if (this.classifier == null && className != null) {
                classes.add(className);
            }
            classNames[index] = className;
            index++;
        }

        if (this.classifier == null) {
            this.classifier = TrainableClassifier.createClassifier(classes, getConfig(), CONFIG_SUFFIX_TEXT);
        }

        try {
            if (!this.reUse) {
                this.classifier.reset();
            }
            this.tuner.reset();

            // Clamp so an out-of-range split cannot index past the input arrays.
            int trainCount = Math.max(0, Math.min(size, Math.round(this.tuner.getTrainSplit() * size)));
            int evalLimit = this.tuner.getTestSplit() < 0.0f ? size : Math.min(size, Math.round((this.tuner.getTrainSplit() + this.tuner.getTestSplit()) * size));
            Util.LOG.debug("Using " + evalLimit + " of " + size + " files: " + trainCount + " for training, " + (evalLimit - trainCount) + " only for evaluation");

            boolean continueTraining = true;
            for (int iteration = 1; continueTraining; iteration++) {
                if (this.tuner.getTuneIterations() > 1) {
                    Util.LOG.debug("Starting TUNE iteration " + iteration + "/" + this.tuner.getTuneIterations() + " (will stop if no improvement for " + this.tuner.getTuneStop() + " iterations)");
                }

                Accuracy trainAccuracy = new Accuracy();
                for (int i = 0; i < trainCount; i++) {
                    results.add(processFile(file, fileNames[i], classNames[i], classes, str2, trainAccuracy, true));
                }
                addMetrics(metrics, trainAccuracy, TrainEval.TYPE_TRAIN, iteration);

                continueTraining = this.tuner.continueTraining(new double[]{trainAccuracy.getAccuracy()}, iteration);
                if (this.tuner.shouldEvaluate(continueTraining, iteration)) {
                    Accuracy evalAccuracy = new Accuracy();
                    for (int i = trainCount; i < evalLimit; i++) {
                        results.add(processFile(file, fileNames[i], classNames[i], classes, str2, evalAccuracy, false));
                    }
                    addMetrics(metrics, evalAccuracy, TrainEval.TYPE_EVAL, iteration);
                }
            }

            storeMetrics(metrics, str);
            Util.LOG.debug("Finished classifying and training using " + this.classifier + " and " + this.featureExtractor);
        } finally {
            // A non-reused classifier must be destroyed even if processing failed,
            // otherwise the next call would drop it without cleanup.
            if (!this.reUse) {
                disposeClassifier();
            }
        }
        return results;
    }

    /**
     * Resets or restores the classifier before processing: without re-use the
     * classifier is discarded; with re-use and no current classifier, it is
     * restored from the classifier file if that file exists and is readable.
     *
     * @throws IOException if reading the classifier file fails
     * @throws ProcessingException if deserialization of the classifier fails
     */
    private void prepareClassifier() throws IOException, ProcessingException {
        if (!this.reUse) {
            this.classifier = null;
            return;
        }
        if (this.classifier != null) {
            return;
        }
        File storedClassifier = new File(this.classifierDirectory, this.classifierFileName);
        if (storedClassifier.exists() && storedClassifier.canRead()) {
            try {
                this.classifier = (TrainableClassifier) ObjectElement.createObject(storedClassifier);
                Util.LOG.info("Restored classifier from " + storedClassifier);
            } catch (InstantiationException e) {
                throw new ProcessingException("Deserialization of classifier failed: " + e.getMessage(), e);
            }
        }
    }

    /**
     * Appends a metrics row for the given accuracy, unless no file was counted.
     *
     * @param metrics container collecting metrics rows
     * @param accuracy accuracy gathered in one pass
     * @param type value stored in the "Type" column
     * @param iteration the current tuning iteration (1-based)
     */
    private void addMetrics(FieldContainer metrics, Accuracy accuracy, String type, int iteration) {
        if (accuracy.getTrueCount() + accuracy.getFalseCount() > 0) {
            FieldMap row = accuracy.storeFields();
            row.put("Type", type);
            row.put(TrainEval.KEY_ITERATION, Integer.valueOf(iteration));
            metrics.add(row);
        }
    }

    /**
     * Writes the collected metrics to the "metrics" output file, if any exist.
     * The writer is closed even if storing fails.
     *
     * @param metrics the metrics rows to store
     * @param localName local name used to derive the output file name
     * @throws IOException if writing fails
     * @throws ProcessingException if creating or storing the output fails
     */
    private void storeMetrics(FieldContainer metrics, String localName) throws IOException, ProcessingException {
        if (metrics.size() == 0) {
            return;
        }
        File metricsFile = IOUtils.createOutFile(this.classifierDirectory, localName, "metrics");
        try (Writer writer = IOUtils.openWriter(metricsFile, (Configuration) getConfig())) {
            metrics.store(writer);
            writer.flush();
        }
    }

    /**
     * Destroys the current classifier (if any) and clears the reference.
     *
     * @throws IOException if destroying the classifier fails
     * @throws ProcessingException if destroying the classifier fails
     */
    private void disposeClassifier() throws IOException, ProcessingException {
        TrainableClassifier disposable = this.classifier;
        this.classifier = null;
        if (disposable != null) {
            disposable.destroy();
        }
    }

    /**
     * Stores the classifier in the classifier file if storing is enabled, not
     * in test-only mode, a classifier exists and no errors occurred.
     *
     * @param i number of errors that occurred; storing is skipped if non-zero
     * @throws IOException if storing fails
     */
    @Override // de.fu_berlin.ties.Closeable
    public void close(int i) throws IOException {
        if (!this.store || this.testOnly || this.classifier == null) {
            return;
        }
        if (i != 0) {
            Util.LOG.warn(i + " errors ocurred -- won't store the classifier");
            return;
        }
        File file = new File(this.classifierDirectory, this.classifierFileName);
        this.classifier.toElement().store(file, getConfig());
        Util.LOG.info("Stored classifier in " + file);
    }

    /**
     * Reads the list of files from {@code reader}, classifies and trains on
     * them, and writes the per-file results to {@code writer}. Directory,
     * local name, character set and output directory are taken from the
     * context map.
     *
     * @param reader source of the input field container
     * @param writer destination of the result field container
     * @param contextMap processing context
     * @throws IOException if I/O fails
     * @throws ProcessingException if processing fails
     */
    @Override // de.fu_berlin.ties.TextProcessor
    protected void doProcess(Reader reader, Writer writer, ContextMap contextMap) throws IOException, ProcessingException {
        FieldContainer input = FieldContainer.createFieldContainer(getConfig());
        input.read(reader);
        String charset = (String) contextMap.get(IOUtils.KEY_LOCAL_CHARSET);
        File directory = (File) contextMap.get(TextProcessor.KEY_DIRECTORY);
        String localName = (String) contextMap.get(TextProcessor.KEY_LOCAL_NAME);
        this.classifierDirectory = (File) contextMap.get(TextProcessor.KEY_OUT_DIRECTORY);
        classifyAndTrain(input, directory, localName, charset).store(writer);
    }

    /**
     * Classifies (and possibly trains on) a single file.
     *
     * <p>Training happens only if {@code z} is true, the class {@code str2} is
     * known and not in test-only mode; otherwise the file is only classified.
     * If the class is known, {@code accuracy} is updated.</p>
     *
     * @param file directory containing the file
     * @param str file name, without the configured extension
     * @param str2 the correct class, or {@code null} if unknown
     * @param set candidate classes
     * @param str3 character set used to read the file
     * @param accuracy updated with the outcome if the class is known
     * @param z whether training is allowed for this file
     * @return a result row for the file
     * @throws IOException if reading the file fails
     * @throws ProcessingException if classification or training fails
     */
    private FieldMap processFile(File file, String str, String str2, Set<String> set, String str3, Accuracy accuracy, boolean z) throws IOException, ProcessingException {
        FieldMap result = new FieldMap();
        result.put(KEY_FILE, str);
        String fullName = str + this.fileExtension;
        InputStreamReader reader = IOUtils.openReader(new File(file, fullName), str3);
        try {
            FeatureVector features = this.featureExtractor.buildFeatures(reader);
            boolean train = z && str2 != null && !this.testOnly;
            if (train) {
                // trainOnError returns null if the classification was already correct.
                PredictionDistribution errorDistribution = this.classifier.trainOnError(features, str2, set);
                result.put(KEY_CLASS, str2);
                if (errorDistribution == null) {
                    Util.LOG.debug("Processed " + fullName + ": classification as " + str2 + " was correct");
                    result.put(KEY_CLASSIFICATION, CORRECT_CLASS);
                    accuracy.incTrueCount();
                } else {
                    Prediction wrong = errorDistribution.best();
                    Util.LOG.debug("Processed " + fullName + ": misclassified as " + wrong.getType() + " instead of " + str2);
                    result.put(KEY_CLASSIFICATION, wrong.getType());
                    accuracy.incFalseCount();
                }
            } else {
                Prediction best = this.classifier.classify(features, set).best();
                if (str2 == null) {
                    Util.LOG.debug("Processed " + fullName + ": classified as " + best.getType());
                    result.put(KEY_CLASSIFICATION, best.getType());
                } else {
                    result.put(KEY_CLASS, str2);
                    if (best.getType().equals(str2)) {
                        Util.LOG.debug("Processed " + fullName + ": classification as " + str2 + " was correct");
                        result.put(KEY_CLASSIFICATION, CORRECT_CLASS);
                        accuracy.incTrueCount();
                    } else {
                        Util.LOG.debug("Processed " + fullName + ": misclassified as " + best.getType() + " instead of " + str2 + " (but training is disabled)");
                        result.put(KEY_CLASSIFICATION, best.getType());
                        accuracy.incFalseCount();
                    }
                }
            }
            return result;
        } finally {
            IOUtils.tryToClose(reader);
        }
    }
}
