package de.fu_berlin.ties.classify;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.feature.FeatureTransformer;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.io.ObjectElement;
import de.fu_berlin.ties.xml.dom.DOMUtils;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.dom4j.Element;
import org.dom4j.QName;
import org.dom4j.tree.DefaultElement;

/* loaded from: input_file:de/fu_berlin/ties/classify/MultiBinaryClassifier.class */
/**
 * A multi-class classifier built from several binary classifiers.
 *
 * <p>One class is designated the <em>background class</em>. Every other
 * (<em>foreground</em>) class gets its own inner binary classifier that decides
 * between that foreground class and the background class. Classification
 * merges the foreground predictions of all relevant inner classifiers.
 * Training forwards each example to every inner classifier, relabelled as
 * background wherever the true class is not that classifier's foreground class.
 */
public class MultiBinaryClassifier extends TrainableClassifier {

    /** Name of the XML element wrapping each serialized inner classifier. */
    static final QName ELEMENT_INNER = DOMUtils.defaultName("inner");

    /** Name of the attribute that stores the foreground class an inner classifier handles. */
    static final QName ATTRIB_FOR = DOMUtils.defaultName("for");

    /** Name of the attribute that stores the background class. */
    private static final QName ATTRIB_BACKGROUND = DOMUtils.defaultName("background-class");

    /** Prefix of context-map keys under which the context of an inner classifier is stored. */
    private static final String PREFIX_CONTEXT = "context-";

    /** Prefix of context-map keys under which the distribution of an inner classifier is stored. */
    private static final String PREFIX_DIST = "dist-";

    /** The background class shared by all inner binary classifiers. */
    private final String backgroundClass;

    /**
     * Maps each foreground class to its inner binary classifier. Insertion
     * order is kept so that training, reset/destroy and serialization iterate
     * in a reproducible order.
     */
    private final Map<String, TrainableClassifier> innerClassifiers;

    /**
     * Restores a classifier from its XML serialization (see {@link #toElement()}).
     *
     * @param element the element to read
     * @throws InstantiationException if the background-class attribute is
     *     missing, or if the inner classifiers do not match the foreground
     *     classes one-to-one
     */
    public MultiBinaryClassifier(Element element) throws InstantiationException {
        super(element);
        this.innerClassifiers = new LinkedHashMap<String, TrainableClassifier>();

        final String background = element.attributeValue(ATTRIB_BACKGROUND);
        if (background == null) {
            // Fail with a clear message instead of an NPE in the validation below.
            throw new InstantiationException("Serialization error: Missing "
                + ATTRIB_BACKGROUND.getName() + " attribute");
        }
        this.backgroundClass = background;

        final Iterator<?> innerIter = element.elementIterator(ELEMENT_INNER);
        while (innerIter.hasNext()) {
            final Element innerElement = (Element) innerIter.next();
            this.innerClassifiers.put(innerElement.attributeValue(ATTRIB_FOR),
                (TrainableClassifier) ObjectElement.createObject(
                    innerElement.element(TrainableClassifier.ELEMENT_MAIN)));
        }

        // Every class except the background class needs exactly one inner classifier.
        final int foregroundCount = getAllClasses().size() - 1;
        if (foregroundCount != this.innerClassifiers.size()) {
            throw new InstantiationException("Serialization error: Found "
                + this.innerClassifiers.size() + " inner classifiers but there are "
                + foregroundCount + " foreground classes");
        }
        for (final String className : getAllClasses()) {
            if (!this.innerClassifiers.containsKey(className)
                    && !this.backgroundClass.equals(className)) {
                throw new InstantiationException(
                    "Serialization error: No inner classifier exists for the foreground class "
                        + className);
            }
        }
    }

    /**
     * Creates a new classifier. The first class returned by {@code set}'s
     * iterator becomes the background class. Each remaining class gets an
     * inner binary classifier created by
     * {@link TrainableClassifier#createClassifier}.
     *
     * @param set all classes, with the background class first in iteration
     *     order; at least 3 classes are required
     * @param featureTransformer transformer passed to the superclass
     * @param file passed unchanged to each inner classifier's creation
     * @param strArr passed unchanged to each inner classifier's creation
     *     (presumably the inner classifier specification; TODO confirm)
     * @param tiesConfiguration configuration used by this and the inner classifiers
     * @throws IllegalArgumentException if fewer than 3 classes are given
     * @throws ProcessingException if an inner classifier cannot be created
     */
    public MultiBinaryClassifier(Set<String> set, FeatureTransformer featureTransformer, File file, String[] strArr, TiesConfiguration tiesConfiguration) throws IllegalArgumentException, ProcessingException {
        super(set, featureTransformer, tiesConfiguration);
        this.innerClassifiers = new LinkedHashMap<String, TrainableClassifier>();
        if (set.size() < 3) {
            throw new IllegalArgumentException(
                "MultiBinaryClassifier requires at least 3 classes instead of " + set.size());
        }
        final Iterator<String> classIter = set.iterator();
        this.backgroundClass = classIter.next();
        while (classIter.hasNext()) {
            final String foreground = classIter.next();
            // The third argument is null; presumably no separate feature
            // transformer for inner classifiers (NOTE(review): confirm).
            this.innerClassifiers.put(foreground, TrainableClassifier.createClassifier(
                createBinarySet(foreground), file, null, strArr, tiesConfiguration));
        }
    }

    /**
     * Creates the sorted two-element class set of an inner classifier.
     *
     * <p>Note: called from the constructor, so overriding implementations must
     * not rely on subclass state.
     *
     * @param str the foreground class
     * @return a set containing the background class and {@code str}
     */
    protected Set<String> createBinarySet(String str) {
        final Set<String> binarySet = new TreeSet<String>();
        binarySet.add(this.backgroundClass);
        binarySet.add(str);
        return binarySet;
    }

    /**
     * Destroys all inner classifiers.
     *
     * @throws ProcessingException if an inner classifier fails to destroy itself
     */
    @Override
    public void destroy() throws ProcessingException {
        for (final TrainableClassifier inner : this.innerClassifiers.values()) {
            inner.destroy();
        }
    }

    /**
     * Classifies {@code featureVector} with the inner classifier of every
     * foreground class in {@code set}.
     *
     * <p>Each inner classifier's context and distribution are stored in
     * {@code contextMap} under {@code "context-" + class} and
     * {@code "dist-" + class}. {@link #doTrain} and {@link #trainOnErrorHook}
     * read them from there. All non-background predictions are returned. Of
     * the background predictions, only the one whose probability is closest to
     * 0.5 is kept, and only if the background class is itself a candidate.
     *
     * @param featureVector the features to classify
     * @param set the candidate classes (raw type kept to match the overridden
     *     signature)
     * @param contextMap context map receiving the inner contexts and distributions
     * @return the merged prediction distribution
     * @throws IllegalArgumentException if a foreground candidate has no inner classifier
     * @throws ProcessingException if an inner classifier fails
     */
    @Override
    public PredictionDistribution doClassify(FeatureVector featureVector, Set set, ContextMap contextMap) throws ProcessingException {
        final List<Prediction> predictions = new ArrayList<Prediction>(set.size());
        Prediction backgroundPred = null;
        double backgroundDistance = Double.MAX_VALUE;

        for (final Object candidate : set) {
            final String foreground = (String) candidate;
            if (this.backgroundClass.equals(foreground)) {
                continue;
            }
            final TrainableClassifier inner = this.innerClassifiers.get(foreground);
            if (inner == null) {
                throw new IllegalArgumentException(
                    "No inner classifier exists for the candidate class " + foreground);
            }
            final ContextMap innerContext = new ContextMap();
            final PredictionDistribution innerDist =
                inner.doClassify(featureVector, inner.getAllClasses(), innerContext);
            contextMap.put(PREFIX_CONTEXT + foreground, innerContext);
            contextMap.put(PREFIX_DIST + foreground, innerDist);

            final Iterator<Prediction> predIter = innerDist.iterator();
            while (predIter.hasNext()) {
                final Prediction pred = predIter.next();
                if (this.backgroundClass.equals(pred.getType())) {
                    // Keep the background prediction whose probability is closest to 0.5.
                    final double distance = Math.abs(pred.getProbability().getProb() - 0.5d);
                    if (distance < backgroundDistance) {
                        backgroundPred = pred;
                        backgroundDistance = distance;
                    }
                } else {
                    predictions.add(pred);
                }
            }
        }

        if (backgroundPred != null && set.contains(this.backgroundClass)) {
            predictions.add(backgroundPred);
        }
        final PredictionDistribution result = new PredictionDistribution();
        for (final Prediction pred : predictions) {
            result.add(pred);
        }
        return result;
    }

    /**
     * Trains every inner classifier on {@code featureVector}. An inner
     * classifier is trained with {@code str} if {@code str} is its foreground
     * class, and with the background class otherwise. If
     * {@code contextMap} holds a context for that inner classifier (stored by
     * {@link #doClassify}), it is reused; otherwise a fresh context is used.
     *
     * @param featureVector the features to train on
     * @param str the true class
     * @param contextMap the context map filled by {@link #doClassify}, if any
     * @throws ProcessingException if an inner classifier fails
     */
    @Override
    public void doTrain(FeatureVector featureVector, String str, ContextMap contextMap) throws ProcessingException {
        for (final Map.Entry<String, TrainableClassifier> entry : this.innerClassifiers.entrySet()) {
            final String foreground = entry.getKey();
            final String innerTarget = foreground.equals(str) ? str : this.backgroundClass;
            final String contextKey = PREFIX_CONTEXT + foreground;
            final ContextMap innerContext = contextMap.containsKey(contextKey)
                ? (ContextMap) contextMap.get(contextKey)
                : new ContextMap();
            entry.getValue().doTrain(featureVector, innerTarget, innerContext);
        }
    }

    /**
     * Returns the background class.
     *
     * @return the background class shared by all inner classifiers
     */
    public String getBackgroundClass() {
        return this.backgroundClass;
    }

    /**
     * Serializes this classifier. The superclass element gets a
     * background-class attribute plus one {@code inner} child per foreground
     * class. Each child wraps the serialized inner classifier.
     *
     * @return the serialized form
     */
    @Override
    public ObjectElement toElement() {
        final ObjectElement element = super.toElement();
        element.addAttribute(ATTRIB_BACKGROUND, this.backgroundClass);
        for (final Map.Entry<String, TrainableClassifier> entry : this.innerClassifiers.entrySet()) {
            final DefaultElement innerElement = new DefaultElement(ELEMENT_INNER);
            element.add(innerElement);
            innerElement.addAttribute(ATTRIB_FOR, entry.getKey());
            innerElement.add(entry.getValue().toElement());
        }
        return element;
    }

    @Override
    public String toString() {
        return new ToStringBuilder(this)
            .appendSuper(super.toString())
            .append("background class", this.backgroundClass)
            .append("inner classifiers", this.innerClassifiers.values())
            .toString();
    }

    /**
     * Delegates train-on-error to the inner classifier of each foreground class
     * in {@code predictionDistribution}. Each inner classifier receives the
     * distribution and context that {@link #doClassify} stored in
     * {@code contextMap}. Its target is {@code str} if that is its foreground
     * class, otherwise the background class.
     *
     * @param predictionDistribution the distribution returned by {@link #doClassify}
     * @param featureVector the features
     * @param str the true class
     * @param set the candidate classes (unused here; raw type kept to match the
     *     overridden signature)
     * @param contextMap the context map filled by {@link #doClassify}
     * @return {@code true} if at least one inner classifier reported training
     * @throws ProcessingException if an inner classifier fails
     */
    @Override
    public boolean trainOnErrorHook(PredictionDistribution predictionDistribution, FeatureVector featureVector, String str, Set set, ContextMap contextMap) throws ProcessingException {
        boolean trained = false;
        final Iterator<Prediction> predIter = predictionDistribution.iterator();
        while (predIter.hasNext()) {
            final String type = predIter.next().getType();
            if (this.backgroundClass.equals(type)) {
                continue;
            }
            final TrainableClassifier inner = this.innerClassifiers.get(type);
            final PredictionDistribution innerDist =
                (PredictionDistribution) contextMap.get(PREFIX_DIST + type);
            final ContextMap innerContext = (ContextMap) contextMap.get(PREFIX_CONTEXT + type);
            final String innerTarget = type.equals(str) ? str : this.backgroundClass;
            // Call the hook unconditionally so every inner classifier gets its chance to train.
            if (inner.trainOnErrorHook(innerDist, featureVector, innerTarget,
                    inner.getAllClasses(), innerContext)) {
                trained = true;
            }
        }
        return trained;
    }

    /**
     * Resets all inner classifiers.
     *
     * @throws ProcessingException if an inner classifier fails to reset
     */
    @Override
    public void reset() throws ProcessingException {
        for (final TrainableClassifier inner : this.innerClassifiers.values()) {
            inner.reset();
        }
    }
}
