package de.fu_berlin.ties.classify;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.feature.FeatureTransformer;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.util.ExternalCommand;
import de.fu_berlin.ties.util.Util;
import java.io.File;
import java.io.IOException;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.builder.ToStringBuilder;

/* loaded from: input_file:de/fu_berlin/ties/classify/ExternalClassifier.class */
/**
 * Trainable classifier that delegates classification, training, initialization
 * and resetting to external commands.
 *
 * <p>The commands, their working directory, an optional class-name suffix and
 * the regular expression used to parse the classify command's output are read
 * from a {@link TiesConfiguration} under the {@code classifier.ext.*} keys.
 * Before the first classification or training call (and again after
 * {@link #reset()}), the init command, if one is configured, is executed once
 * for every class returned by {@code getAllClasses()}.
 */
public class ExternalClassifier extends TrainableClassifier {
    /** Configuration key: working directory of the external commands. */
    public static final String CONFIG_DIR = "classifier.ext.directory";

    /** Configuration key: command (as string array) used for classification. */
    private static final String CONFIG_CMD_CLASSIFY = "classifier.ext.classify";

    /** Configuration key: optional command used to initialize a class. */
    private static final String CONFIG_CMD_INIT = "classifier.ext.init";

    /** Configuration key: optional command used to reset a class. */
    private static final String CONFIG_CMD_RESET = "classifier.ext.reset";

    /** Configuration key: command (as string array) used for training. */
    private static final String CONFIG_CMD_TRAIN = "classifier.ext.train";

    /** Configuration key: optional suffix appended to every class name. */
    private static final String CONFIG_CLASS_SUFFIX = "classifier.ext.suffix";

    /** Configuration key: regex that extracts predictions from the output. */
    private static final String CONFIG_REGEX = "classifier.ext.regex";

    /** Configuration key: optional probability threshold for reinforcement training. */
    private static final String CONFIG_THRESHOLD_PROB = "classifier.ext.threshold.prob";

    /** Configuration key: optional pR threshold for reinforcement training. */
    private static final String CONFIG_THRESHOLD_PR = "classifier.ext.threshold.pR";

    /** Command invoked to classify a feature vector. */
    private final ExternalCommand extClassifier;

    /** Command invoked once per class on initialization; {@code null} if none is configured. */
    private final ExternalCommand extInitializer;

    /** Command invoked once per class on reset; {@code null} if none is configured. */
    private final ExternalCommand extResetter;

    /** Command invoked to train a feature vector for a class. */
    private final ExternalCommand extTrainer;

    /** Suffix appended to class names passed to the commands; may be {@code null}. */
    private final String classSuffix;

    /** Working directory handed to the external commands; may be {@code null}. */
    private final File workDir;

    /** Pattern applied repeatedly to the output of the classify command. */
    private final Pattern predictionPattern;

    /** pR threshold below which training is requested; NaN when not configured. */
    private final double thickThresholdPR;

    /** Probability threshold below which training is requested; NaN when not configured. */
    private final double thickThresholdProb;

    /** Whether {@link #initialize()} has run since construction or the last reset. */
    private boolean initialized = false;

    /**
     * Creates an instance configured from {@code TiesConfiguration.CONF}.
     *
     * @param set the classes handed to the superclass constructor
     * @throws ProcessingException declared by the delegated constructor
     */
    public ExternalClassifier(Set<String> set) throws ProcessingException {
        this(set, TiesConfiguration.CONF);
    }

    /**
     * Creates an instance using the given configuration, a feature transformer
     * created from that configuration, and no explicit working directory.
     *
     * @param set the classes handed to the superclass constructor
     * @param tiesConfiguration the configuration to read all settings from
     * @throws ProcessingException declared by the delegated constructor
     */
    public ExternalClassifier(Set<String> set, TiesConfiguration tiesConfiguration) throws ProcessingException {
        this(set, FeatureTransformer.createTransformer(tiesConfiguration), null, tiesConfiguration);
    }

    /**
     * Creates an instance, reading commands, prediction regex, class suffix and
     * thresholds from the configuration.
     *
     * @param set the classes handed to the superclass constructor
     * @param featureTransformer the transformer handed to the superclass constructor
     * @param file the working directory for the external commands; if
     *        {@code null}, the directory configured under {@link #CONFIG_DIR}
     *        is used when present, otherwise no directory is set
     * @param tiesConfiguration the configuration to read all settings from
     * @throws ProcessingException declared for callers; presumably raised by the
     *         superclass constructor (not visible here)
     */
    public ExternalClassifier(Set<String> set, FeatureTransformer featureTransformer, File file, TiesConfiguration tiesConfiguration) throws ProcessingException {
        super(set, featureTransformer, tiesConfiguration);

        final String[] classifyCmd = tiesConfiguration.getStringArray(CONFIG_CMD_CLASSIFY);
        final String[] initCmd = tiesConfiguration.getStringArray(CONFIG_CMD_INIT);
        final String[] resetCmd = tiesConfiguration.getStringArray(CONFIG_CMD_RESET);
        final String[] trainCmd = tiesConfiguration.getStringArray(CONFIG_CMD_TRAIN);

        this.predictionPattern = Pattern.compile(tiesConfiguration.getString(CONFIG_REGEX));
        this.classSuffix = tiesConfiguration.getString(CONFIG_CLASS_SUFFIX, null);
        this.thickThresholdProb = readThreshold(tiesConfiguration, CONFIG_THRESHOLD_PROB);
        this.thickThresholdPR = readThreshold(tiesConfiguration, CONFIG_THRESHOLD_PR);

        // An explicitly passed directory wins over the configured one.
        this.workDir = (file != null)
                ? file
                : (tiesConfiguration.containsKey(CONFIG_DIR)
                        ? new File(tiesConfiguration.getString(CONFIG_DIR))
                        : null);

        this.extClassifier = new ExternalCommand(classifyCmd, this.workDir);
        this.extTrainer = new ExternalCommand(trainCmd, this.workDir);
        // Init and reset commands are optional: an empty array means "not configured".
        this.extInitializer = (initCmd.length > 0) ? new ExternalCommand(initCmd, this.workDir) : null;
        this.extResetter = (resetCmd.length > 0) ? new ExternalCommand(resetCmd, this.workDir) : null;
    }

    /**
     * Reads an optional threshold from the configuration.
     *
     * @param conf the configuration to query
     * @param key the configuration key of the threshold
     * @return the configured value, or NaN if the entry is missing or empty
     */
    private static double readThreshold(final TiesConfiguration conf, final String key) {
        final String raw = conf.getString(key, null);
        return StringUtils.isNotEmpty(raw) ? Util.asDouble(raw) : Double.NaN;
    }

    /**
     * Converts a class name into the form passed to the external commands.
     *
     * @param str the plain class name
     * @return the class name with the configured suffix appended, or the name
     *         unchanged if no suffix is configured
     */
    private String buildClassName(String str) {
        if (this.classSuffix == null) {
            return str;
        }
        return str + this.classSuffix;
    }

    /**
     * Runs {@link #initialize()} unless it has already run since construction
     * or the last {@link #reset()}.
     *
     * @throws ProcessingException if initialization fails
     */
    private void ensureInitialized() throws ProcessingException {
        if (!this.initialized) {
            initialize();
        }
    }

    /**
     * Builds the single argument handed to the classify command: the external
     * names of all candidate classes, separated by single spaces.
     *
     * @param classes the candidate classes; each element is cast to {@code String}
     * @return the space-separated external class names
     */
    private String joinClassNames(final Set<?> classes) {
        final StringBuilder joined = new StringBuilder();
        boolean first = true;
        for (final Object candidate : classes) {
            if (!first) {
                joined.append(' ');
            }
            joined.append(buildClassName((String) candidate));
            first = false;
        }
        return joined.toString();
    }

    /**
     * Extracts predictions from the output of the classify command by applying
     * the prediction pattern repeatedly. Group 1 is the class name (a trailing
     * class suffix is stripped), group 2 the first value and the optional
     * group 3 the second value handed to {@code Probability}; only the first
     * match for each class name is kept.
     *
     * @param output the text returned by the classify command
     * @return the distribution of all extracted predictions
     * @throws IllegalArgumentException if the pattern defines fewer than two
     *         groups or does not match the output at all
     */
    private PredictionDistribution parsePredictions(final String output) {
        final Matcher matcher = this.predictionPattern.matcher(output);
        final PredictionDistribution result = new PredictionDistribution();
        final Set<String> seenClasses = new HashSet<String>();

        while (matcher.find()) {
            final int groups = matcher.groupCount();
            if (groups < 2) {
                throw new IllegalArgumentException("Extraction pattern '" + this.predictionPattern.pattern() + "' should match 2 or 3 subgroups but matched only " + groups + " in classifier output: '" + output + "'");
            }

            String className = matcher.group(1);
            if (this.classSuffix != null && className.endsWith(this.classSuffix)) {
                className = className.substring(0, className.length() - this.classSuffix.length());
            }

            // Both numbers are parsed before the duplicate check, so malformed
            // values are reported even for repeated classes.
            final double prob = Util.asDouble(matcher.group(2));
            final double pR = (groups >= 3) ? Util.asDouble(matcher.group(3)) : Double.NaN;

            // add() returns false for a class that was already recorded.
            if (seenClasses.add(className)) {
                result.add(new Prediction(className, new Probability(prob, pR)));
            }
        }

        if (result.size() < 1) {
            throw new IllegalArgumentException("No match found for extraction pattern '" + this.predictionPattern.pattern() + "' in classifier output: '" + output + "'");
        }
        return result;
    }

    /**
     * Classifies a feature vector by running the external classify command
     * with the candidate class names as argument and the flattened feature
     * vector as second parameter, then parsing the command's output.
     * (The decompiler notes that the original access modifier was protected.)
     *
     * @param featureVector the features to classify
     * @param set the candidate classes
     * @param contextMap not used by this implementation
     * @return the predictions extracted from the command output
     * @throws ProcessingException if initialization fails or an I/O error
     *         occurs while running the command
     * @throws IllegalArgumentException if the command output cannot be parsed
     */
    @Override
    public PredictionDistribution doClassify(FeatureVector featureVector, Set set, ContextMap contextMap) throws ProcessingException {
        ensureInitialized();
        final String candidateClasses = joinClassNames(set);
        try {
            final String output = this.extClassifier.execute(new String[] {candidateClasses}, featureVector.flatten());
            return parsePredictions(output);
        } catch (IOException ioe) {
            throw new ProcessingException("I/O error while classifying", ioe);
        }
    }

    /**
     * Trains the external classifier by running the train command with the
     * external class name as argument and the flattened feature vector as
     * second parameter.
     * (The decompiler notes that the original access modifier was protected.)
     *
     * @param featureVector the features to train
     * @param str the class the features belong to
     * @param contextMap not used by this implementation
     * @throws ProcessingException if initialization fails or an I/O error
     *         occurs while running the command
     */
    @Override
    public void doTrain(FeatureVector featureVector, String str, ContextMap contextMap) throws ProcessingException {
        ensureInitialized();
        final String[] arguments = {buildClassName(str)};
        try {
            this.extTrainer.execute(arguments, featureVector.flatten());
        } catch (IOException ioe) {
            throw new ProcessingException("I/O error while training", ioe);
        }
    }

    /**
     * Runs the init command (if configured) for every known class and marks
     * this classifier as initialized.
     *
     * @throws ProcessingException if initializing a class fails
     */
    private void initialize() throws ProcessingException {
        if (this.extInitializer != null) {
            for (final Object knownClass : getAllClasses()) {
                init((String) knownClass);
            }
        }
        this.initialized = true;
    }

    /**
     * Runs the init command for a single class.
     *
     * @param str the plain name of the class to initialize
     * @throws ProcessingException if an I/O error occurs while running the command
     */
    private void init(String str) throws ProcessingException {
        final String[] arguments = {buildClassName(str)};
        try {
            this.extInitializer.execute(arguments);
        } catch (IOException ioe) {
            throw new ProcessingException("I/O error while initializing the " + str + " class", ioe);
        }
    }

    /**
     * Runs the reset command (if configured) for every known class and marks
     * this classifier as uninitialized, so the next classification or training
     * call initializes it again.
     *
     * @throws ProcessingException if resetting a class fails
     */
    @Override
    public void reset() throws ProcessingException {
        if (this.extResetter != null) {
            for (final Object knownClass : getAllClasses()) {
                resetClass((String) knownClass);
            }
        }
        this.initialized = false;
    }

    /**
     * Runs the reset command for a single class.
     *
     * @param str the plain name of the class to reset
     * @throws ProcessingException if an I/O error occurs while running the command
     */
    private void resetClass(String str) throws ProcessingException {
        final String[] arguments = {buildClassName(str)};
        try {
            this.extResetter.execute(arguments);
        } catch (IOException ioe) {
            throw new ProcessingException("I/O error while resetting the " + str + " class", ioe);
        }
    }

    /**
     * Decides whether an instance should be trained. Training is requested if
     * the superclass requests it, or if the best prediction's pR or probability
     * lies below the corresponding configured threshold.
     * (The decompiler notes that the original access modifier was protected.)
     *
     * @param str the class handed on to the superclass check
     * @param predictionDistribution the predictions for the instance
     * @param contextMap handed on to the superclass check
     * @return {@code true} if the instance should be trained
     */
    @Override
    public boolean shouldTrain(String str, PredictionDistribution predictionDistribution, ContextMap contextMap) {
        final Prediction best = predictionDistribution.best();
        final Probability probability = best.getProbability();

        if (super.shouldTrain(str, predictionDistribution, contextMap)) {
            return true;
        }

        // The pR check only applies when both threshold and pR are real numbers.
        final boolean belowPR = !Double.isNaN(this.thickThresholdPR)
                && !Double.isNaN(probability.getPR())
                && probability.getPR() < this.thickThresholdPR;
        if (belowPR) {
            Util.LOG.debug("Reinforcement training because pR " + probability.getPR() + " is below the pR threshold " + this.thickThresholdPR);
            return true;
        }

        // Written as a negated ">=" on purpose: a NaN probability fails the
        // comparison and therefore also triggers training.
        final boolean belowProb = !Double.isNaN(this.thickThresholdProb)
                && !(probability.getProb() >= this.thickThresholdProb);
        if (belowProb) {
            Util.LOG.debug("Reinforcement training because probability " + best.getProbability() + " is below the threshold " + this.thickThresholdProb);
            return true;
        }
        return false;
    }

    /**
     * Returns a description listing the superclass description, the four
     * commands, the class suffix, the prediction pattern and every threshold
     * that is configured.
     *
     * @return a textual representation of this classifier
     */
    @Override
    public String toString() {
        final ToStringBuilder description = new ToStringBuilder(this);
        description.appendSuper(super.toString());
        description.append("classify command", this.extClassifier);
        description.append("train command", this.extTrainer);
        description.append("init command", this.extInitializer);
        description.append("reset command", this.extResetter);
        description.append("class suffix", this.classSuffix);
        description.append("prediction pattern", this.predictionPattern.pattern());
        // Thresholds are listed only when configured (NaN means "not set").
        if (!Double.isNaN(this.thickThresholdProb)) {
            description.append("thick threshold prob.", this.thickThresholdProb);
        }
        if (!Double.isNaN(this.thickThresholdPR)) {
            description.append("thick threshold pR", this.thickThresholdPR);
        }
        return description.toString();
    }
}
