package de.fu_berlin.ties.classify;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.feature.FeatureTransformer;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.io.ObjectElement;
import de.fu_berlin.ties.util.Util;
import de.fu_berlin.ties.xml.dom.DOMUtils;
import java.io.File;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.dom4j.Element;
import org.dom4j.QName;
import org.dom4j.tree.DefaultElement;

/* loaded from: input_file:de/fu_berlin/ties/classify/TieClassifier.class */
/**
 * Trainable classifier that delegates to a stack of inner classifiers
 * ("layers"). During classification the layers are consulted in order; the
 * next layer is only invoked when the current layer's two leading predictions
 * are considered a <em>tie</em>, i.e. when the probability of the second
 * prediction is at least the probability of the first prediction multiplied
 * by the tie threshold. The distribution of the last layer that was consulted
 * is returned.
 *
 * <p>Only error-driven training is supported: {@link #doTrain} and
 * {@link #shouldTrain} always throw {@link UnsupportedOperationException}.
 */
public class TieClassifier extends TrainableClassifier {
    /** Name of the XML attribute that stores the tie threshold. */
    static final QName ATTRIB_TIE_THRESHOLD = DOMUtils.defaultName("tieThreshold");

    /** Context key under which the per-layer context maps are stored. */
    private static final String KEY_INNER_CONTEXTS = "inner-context";

    /** Context key under which the per-layer prediction distributions are stored. */
    private static final String KEY_INNER_DISTS = "inner-dist";

    /** The inner classifiers, one per layer, in the order they are consulted. */
    private final TrainableClassifier[] inner;

    /** Factor applied to the best probability when testing for a tie; within [0, 1]. */
    private final double tieThreshold;

    /**
     * Restores an instance from its XML representation as written by
     * {@link #toElement()}.
     *
     * @param element the element to read the tie threshold attribute and the
     *        nested inner classifiers from
     * @throws InstantiationException if the container element for the inner
     *         classifiers is missing or has no child elements
     * @throws IllegalArgumentException if the stored tie threshold is not a
     *         number within [0, 1]
     */
    public TieClassifier(Element element) throws InstantiationException {
        super(element);
        double threshold = Util.asDouble(element.attributeValue(ATTRIB_TIE_THRESHOLD));
        checkTieThreshold(threshold);
        this.tieThreshold = threshold;

        // A missing container is reported like an empty one instead of
        // surfacing as a NullPointerException.
        Element innerContainer = element.element(MultiBinaryClassifier.ELEMENT_INNER);
        if (innerContainer == null) {
            throw new InstantiationException("TieClassifier: no inner classifiers found");
        }
        List children = innerContainer.elements();
        if (children.isEmpty()) {
            throw new InstantiationException("TieClassifier: no inner classifiers found");
        }

        this.inner = new TrainableClassifier[children.size()];
        int layer = 0;
        for (Iterator it = children.iterator(); it.hasNext(); layer++) {
            this.inner[layer] = (TrainableClassifier) ObjectElement.createObject((Element) it.next());
        }
    }

    /**
     * Creates an instance whose number of layers and tie threshold are read
     * from the configuration keys {@code classifier.tie.layers} and
     * {@code classifier.tie.threshold}.
     *
     * @param set passed on to the superclass and to each inner classifier
     * @param featureTransformer passed on to the superclass
     * @param file passed on to each inner classifier
     * @param strArr passed on to each inner classifier
     * @param tiesConfiguration the configuration to read the settings from;
     *        also passed on to the superclass and the inner classifiers
     * @throws ProcessingException if creating an inner classifier fails
     * @throws IllegalArgumentException if the configured values are invalid
     */
    public TieClassifier(Set<String> set, FeatureTransformer featureTransformer, File file, String[] strArr, TiesConfiguration tiesConfiguration) throws ProcessingException {
        this(set, featureTransformer, file, strArr, tiesConfiguration.getInt("classifier.tie.layers"), tiesConfiguration.getDouble("classifier.tie.threshold"), tiesConfiguration);
    }

    /**
     * Creates an instance with an explicit number of layers and tie threshold.
     * Each layer is created through
     * {@link TrainableClassifier#createClassifier}; note that {@code null}
     * (not {@code featureTransformer}) is passed as its third argument.
     *
     * @param set passed on to the superclass and to each inner classifier
     * @param featureTransformer passed on to the superclass
     * @param file passed on to each inner classifier
     * @param strArr passed on to each inner classifier
     * @param i the number of layers; must be at least 1
     * @param d the tie threshold; must be within [0, 1]
     * @param tiesConfiguration passed on to the superclass and to each inner
     *        classifier
     * @throws ProcessingException if creating an inner classifier fails
     * @throws IllegalArgumentException if {@code i < 1} or {@code d} is not a
     *         number within [0, 1]
     */
    public TieClassifier(Set<String> set, FeatureTransformer featureTransformer, File file, String[] strArr, int i, double d, TiesConfiguration tiesConfiguration) throws ProcessingException {
        super(set, featureTransformer, tiesConfiguration);
        if (i < 1) {
            throw new IllegalArgumentException("TieClassifier requires at least 1 layer instead of " + i);
        }
        checkTieThreshold(d);
        this.tieThreshold = d;
        this.inner = new TrainableClassifier[i];
        for (int layer = 0; layer < this.inner.length; layer++) {
            this.inner[layer] = TrainableClassifier.createClassifier(set, file, null, strArr, tiesConfiguration);
        }
    }

    /**
     * Validates a tie threshold.
     *
     * @param d the value to check
     * @throws IllegalArgumentException if {@code d} is NaN or outside [0, 1]
     */
    private static void checkTieThreshold(double d) throws IllegalArgumentException {
        // Written as a negated range test so that NaN (for which every
        // comparison is false) is rejected as well.
        if (!(d >= 0.0d && d <= 1.0d)) {
            throw new IllegalArgumentException("Tie threshold must be in the [0, 1] range: " + d);
        }
    }

    /**
     * Destroys every inner classifier.
     *
     * @throws ProcessingException if destroying an inner classifier fails
     */
    @Override
    public void destroy() throws ProcessingException {
        for (TrainableClassifier layer : this.inner) {
            layer.destroy();
        }
    }

    /**
     * Classifies by consulting the layers in order until one of them yields a
     * result that is not a tie (see {@link #isTie}) or the layers are
     * exhausted. The context map and distribution of every consulted layer
     * are stored in {@code contextMap} for later use by
     * {@link #doTrainOnError} and {@link #trainOnErrorHook}; the array slots
     * of layers that were not consulted stay {@code null}.
     *
     * @param featureVector the features, handed to each consulted layer
     * @param set handed unchanged to each consulted layer
     * @param contextMap receives the per-layer contexts and distributions
     * @return the distribution produced by the last consulted layer
     * @throws ProcessingException if an inner classifier fails
     */
    @Override
    public PredictionDistribution doClassify(FeatureVector featureVector, Set set, ContextMap contextMap) throws ProcessingException {
        final int layerCount = this.inner.length;
        PredictionDistribution[] layerDists = new PredictionDistribution[layerCount];
        ContextMap[] layerContexts = new ContextMap[layerCount];
        PredictionDistribution result = null;

        for (int layer = 0; layer < layerCount; layer++) {
            ContextMap layerContext = new ContextMap();
            result = this.inner[layer].doClassify(featureVector, set, layerContext);
            layerContexts[layer] = layerContext;
            layerDists[layer] = result;

            if (!isTie(result, layer)) {
                break;
            }
        }

        contextMap.put(KEY_INNER_CONTEXTS, layerContexts);
        contextMap.put(KEY_INNER_DISTS, layerDists);
        return result;
    }

    /**
     * Decides whether the first two predictions of a distribution are tied,
     * i.e. whether the probability of the second one is at least the
     * probability of the first one multiplied by the tie threshold. A
     * distribution with fewer than two predictions is never a tie, since
     * there is no runner-up to compare against.
     *
     * <p>NOTE(review): this relies on the distribution iterating from the
     * most to the least probable prediction -- confirm against
     * {@code PredictionDistribution}.
     *
     * @param dist the distribution produced by a layer
     * @param layer index of that layer, used for the debug message only
     * @return {@code true} if the next layer should be invoked
     */
    private boolean isTie(PredictionDistribution dist, int layer) {
        Iterator<Prediction> ranked = dist.iterator();
        if (!ranked.hasNext()) {
            return false;
        }
        double best = ranked.next().getProbability().getProb();
        if (!ranked.hasNext()) {
            return false;
        }
        double secondBest = ranked.next().getProbability().getProb();

        boolean tie = secondBest >= best * this.tieThreshold;
        if (tie) {
            Util.LOG.debug("Layer " + layer + " of TieClassifier: will invoke next layer (if exists) since probability of 2nd best prediction (" + secondBest + ")  >= best prediction (" + best + ") * tie threshold (" + this.tieThreshold + ")");
        }
        return tie;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always; use error-driven training
     */
    @Override
    public void doTrain(FeatureVector featureVector, String str, ContextMap contextMap) throws UnsupportedOperationException {
        throw new UnsupportedOperationException("TieClassifier supports only error-driven training -- call trainOnError instead of train");
    }

    /**
     * Delegates error-driven training to every layer that was consulted by
     * the preceding {@link #doClassify} call, using the per-layer
     * distributions and contexts stored in {@code contextMap}.
     *
     * <p>NOTE(review): the result of each layer overwrites that of the
     * previous one, so only the last trained layer determines the return
     * value, whereas {@link #trainOnErrorHook} ORs the results. This looks
     * intentional but should be confirmed.
     *
     * @param predictionDistribution not used directly; each layer receives
     *        its own stored distribution
     * @param featureVector the features, handed to each trained layer
     * @param str handed unchanged to each trained layer
     * @param set handed unchanged to each trained layer
     * @param contextMap must contain the entries stored by
     *        {@link #doClassify}
     * @return the value returned by the last trained layer, or {@code true}
     *         if no layer was trained
     * @throws ProcessingException if an inner classifier fails
     */
    @Override
    public boolean doTrainOnError(PredictionDistribution predictionDistribution, FeatureVector featureVector, String str, Set set, ContextMap contextMap) throws ProcessingException {
        PredictionDistribution[] layerDists = (PredictionDistribution[]) contextMap.get(KEY_INNER_DISTS);
        ContextMap[] layerContexts = (ContextMap[]) contextMap.get(KEY_INNER_CONTEXTS);
        boolean result = true;
        // A null context marks the first layer that was not consulted.
        for (int layer = 0; layer < layerDists.length && layerContexts[layer] != null; layer++) {
            result = this.inner[layer].doTrainOnError(layerDists[layer], featureVector, str, set, layerContexts[layer]);
        }
        return result;
    }

    /**
     * Resets every inner classifier.
     *
     * @throws ProcessingException if resetting an inner classifier fails
     */
    @Override
    public void reset() throws ProcessingException {
        for (TrainableClassifier layer : this.inner) {
            layer.reset();
        }
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public boolean shouldTrain(String str, PredictionDistribution predictionDistribution, ContextMap contextMap) {
        throw new UnsupportedOperationException("TieClassifier: shouldTrain is not required and thus not supported");
    }

    /**
     * Serializes this classifier: the element created by the superclass is
     * extended by the tie threshold attribute and a container element that
     * holds the serialized inner classifiers in layer order.
     *
     * @return the element representing this classifier
     */
    @Override
    public ObjectElement toElement() {
        ObjectElement result = super.toElement();
        result.addAttribute(ATTRIB_TIE_THRESHOLD, Double.toString(this.tieThreshold));
        DefaultElement innerContainer = new DefaultElement(MultiBinaryClassifier.ELEMENT_INNER);
        result.add(innerContainer);
        for (TrainableClassifier layer : this.inner) {
            innerContainer.add(layer.toElement());
        }
        return result;
    }

    /**
     * Returns a textual description listing the superclass description, the
     * inner classifiers and the tie threshold.
     *
     * @return the description
     */
    @Override
    public String toString() {
        return new ToStringBuilder(this)
                .appendSuper(super.toString())
                .append("inner classifiers", ArrayUtils.toString(this.inner))
                .append("tie threshold", this.tieThreshold)
                .toString();
    }

    /**
     * Invokes the training hook of every layer that was consulted by the
     * preceding {@link #doClassify} call. Every such layer is invoked, even
     * after an earlier one has already returned {@code true}.
     *
     * @param predictionDistribution not used directly; each layer receives
     *        its own stored distribution
     * @param featureVector the features, handed to each layer
     * @param str handed unchanged to each layer
     * @param set handed unchanged to each layer
     * @param contextMap must contain the entries stored by
     *        {@link #doClassify}
     * @return {@code true} if at least one layer's hook returned {@code true}
     * @throws ProcessingException if an inner classifier fails
     */
    @Override
    public boolean trainOnErrorHook(PredictionDistribution predictionDistribution, FeatureVector featureVector, String str, Set set, ContextMap contextMap) throws ProcessingException {
        PredictionDistribution[] layerDists = (PredictionDistribution[]) contextMap.get(KEY_INNER_DISTS);
        ContextMap[] layerContexts = (ContextMap[]) contextMap.get(KEY_INNER_CONTEXTS);
        boolean anyHandled = false;
        // A null context marks the first layer that was not consulted.
        for (int layer = 0; layer < layerDists.length && layerContexts[layer] != null; layer++) {
            // Non-short-circuit OR: the hook must run for every consulted layer.
            anyHandled |= this.inner[layer].trainOnErrorHook(layerDists[layer], featureVector, str, set, layerContexts[layer]);
        }
        return anyHandled;
    }
}
