package de.fu_berlin.ties.classify;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.feature.FeatureTransformer;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.classify.winnow.UltraconservativeWinnow;
import de.fu_berlin.ties.classify.winnow.Winnow;
import de.fu_berlin.ties.extract.Extractor;
import de.fu_berlin.ties.io.ObjectElement;
import de.fu_berlin.ties.io.XMLStorable;
import de.fu_berlin.ties.text.TextUtils;
import de.fu_berlin.ties.util.CollUtils;
import de.fu_berlin.ties.util.Util;
import de.fu_berlin.ties.xml.dom.DOMUtils;
import java.io.File;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Locale;
import java.util.Set;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.dom4j.Element;
import org.dom4j.QName;

/* loaded from: input_file:de/fu_berlin/ties/classify/TrainableClassifier.class */
/**
 * Abstract base class for classifiers that can be trained.
 *
 * <p>This class provides static factory methods that build a concrete
 * classifier from a textual specification, validates target and candidate
 * classes against the set of known classes, optionally runs every feature
 * vector through a {@link FeatureTransformer} before handing it to the
 * subclass, and serializes the shared state as XML via {@link #toElement()}.
 *
 * <p>The result of the most recent {@link #classify(FeatureVector, Set)} call
 * is remembered so that a following
 * {@link #trainOnError(FeatureVector, String, Set)} call on the very same
 * feature vector instance does not have to classify again. This cache is kept
 * in plain instance fields without any synchronization.
 */
public abstract class TrainableClassifier implements Classifier, XMLStorable {
    /** Name of the XML element written by {@link #toElement()}. */
    public static final QName ELEMENT_MAIN = DOMUtils.defaultName(Classifier.CONFIG_CLASSIFIER);

    /** XML attribute holding the flattened set of all classes. */
    static final QName ATTRIB_CLASSES = DOMUtils.defaultName("classes");

    /** XML attribute holding the "train all classes" flag. */
    static final QName ATTRIB_TRAIN_ALL = DOMUtils.defaultName("train-all");

    /** Specification keyword selecting a {@code MetaClassifier}. */
    public static final String META_CLASSIFIER = "meta";

    /** Specification keyword selecting a {@code MultiBinaryClassifier}. */
    public static final String MULTI_CLASSIFIER = "multi";

    /** Specification keyword selecting a {@code OneAgainstTheRestClassifier}. */
    public static final String OAR_CLASSIFIER = "oar";

    /** Specification keyword selecting a {@code TieClassifier}. */
    public static final String TIE_CLASSIFIER = "tie";

    /** Immutable set of the keywords that denote wrapping classifiers. */
    private static final Set<String> WRAPPING_CLASSIFIERS = wrappingClassifierNames();

    /** Unmodifiable view of all classes known to this classifier. */
    private final Set<String> allClasses;

    /** Configuration this classifier was created with. */
    private final TiesConfiguration config;

    /**
     * If {@code true}, {@link #trainOnError(FeatureVector, String, Set)}
     * classifies against all classes instead of only the candidate classes.
     */
    private final boolean trainingAll;

    /** Optional transformer applied to feature vectors; may be {@code null}. */
    private final FeatureTransformer transformer;

    /** Distribution returned by the most recent {@code classify} call. */
    private PredictionDistribution cachedPredictions;

    /** Untransformed feature vector of the most recent {@code classify} call. */
    private FeatureVector cachedOrgFeatures;

    /** Feature vector actually passed to {@code doClassify} in that call. */
    private FeatureVector cachedActualFeatures;

    /** Context map filled by {@code doClassify} in that call. */
    private ContextMap cachedContext;

    /**
     * Builds the immutable set of wrapping classifier keywords.
     *
     * @return an unmodifiable set containing the four wrapping keywords
     */
    private static Set<String> wrappingClassifierNames() {
        final Set<String> names = new HashSet<String>();
        names.add(META_CLASSIFIER);
        names.add(MULTI_CLASSIFIER);
        names.add(OAR_CLASSIFIER);
        names.add(TIE_CLASSIFIER);
        return Collections.unmodifiableSet(names);
    }

    /**
     * Creates a classifier using the standard configuration
     * ({@code TiesConfiguration.CONF}).
     *
     * @param set the set of all classes
     * @return the created classifier
     * @throws IllegalArgumentException if the configured specification is
     *         null or empty
     * @throws ProcessingException if the classifier cannot be created
     */
    public static TrainableClassifier createClassifier(Set<String> set) throws IllegalArgumentException, ProcessingException {
        return createClassifier(set, TiesConfiguration.CONF);
    }

    /**
     * Creates a classifier from the given configuration, without a key
     * suffix.
     *
     * @param set the set of all classes
     * @param tiesConfiguration the configuration to read the specification from
     * @return the created classifier
     * @throws IllegalArgumentException if the configured specification is
     *         null or empty
     * @throws ProcessingException if the classifier cannot be created
     */
    public static TrainableClassifier createClassifier(Set<String> set, TiesConfiguration tiesConfiguration) throws IllegalArgumentException, ProcessingException {
        return createClassifier(set, tiesConfiguration, null);
    }

    /**
     * Creates a classifier from the given configuration, without a run
     * directory ({@code null} is passed on as the file).
     *
     * @param set the set of all classes
     * @param tiesConfiguration the configuration to read the specification from
     * @param str suffix handed to {@code adaptKey} when looking up the
     *        specification; may be {@code null}
     * @return the created classifier
     * @throws IllegalArgumentException if the configured specification is
     *         null or empty
     * @throws ProcessingException if the classifier cannot be created
     */
    public static TrainableClassifier createClassifier(Set<String> set, TiesConfiguration tiesConfiguration, String str) throws IllegalArgumentException, ProcessingException {
        return createClassifier(set, null, tiesConfiguration, str);
    }

    /**
     * Creates a classifier, reading the feature transformer and the
     * classifier specification from the given configuration.
     *
     * @param set the set of all classes
     * @param file file passed on to classifiers that accept one; may be
     *        {@code null}
     * @param tiesConfiguration the configuration to use
     * @param str suffix handed to {@code adaptKey} when looking up the
     *        specification; may be {@code null}
     * @return the created classifier
     * @throws IllegalArgumentException if the configured specification is
     *         null or empty
     * @throws ProcessingException if the classifier cannot be created
     */
    public static TrainableClassifier createClassifier(Set<String> set, File file, TiesConfiguration tiesConfiguration, String str) throws IllegalArgumentException, ProcessingException {
        final FeatureTransformer configuredTransformer = FeatureTransformer.createTransformer(tiesConfiguration);
        final String specKey = tiesConfiguration.adaptKey(Classifier.CONFIG_CLASSIFIER, str);
        final String[] spec = tiesConfiguration.getStringArray(specKey);
        return createClassifier(set, file, configuredTransformer, spec, tiesConfiguration);
    }

    /**
     * Creates a classifier from an explicit specification.
     *
     * <p>The first element of the specification selects the classifier. It is
     * compared case-insensitively against the built-in keywords; anything
     * else is treated as a fully qualified class name that must offer a
     * constructor taking
     * {@code (Set, FeatureTransformer, TiesConfiguration)}. For wrapping
     * classifiers the remaining elements form the specification of the inner
     * classifier.
     *
     * @param set the set of all classes
     * @param file file passed on to classifiers that accept one; may be
     *        {@code null}
     * @param featureTransformer transformer handed to the classifier; may be
     *        {@code null}
     * @param strArr the classifier specification
     * @param tiesConfiguration the configuration to use
     * @return the created classifier
     * @throws IllegalArgumentException if the specification is null or empty
     * @throws ProcessingException if a custom classifier class cannot be
     *         found or instantiated
     */
    public static TrainableClassifier createClassifier(Set<String> set, File file, FeatureTransformer featureTransformer, String[] strArr, TiesConfiguration tiesConfiguration) throws IllegalArgumentException, ProcessingException {
        if (strArr == null || strArr.length < 1) {
            throw new IllegalArgumentException("Cannot create classifier -- specification is null or empty");
        }

        // Locale.ROOT keeps keyword matching independent of the default
        // locale (e.g. Turkish would map 'I' to a dotless 'i').
        final String keyword = strArr[0].toLowerCase(Locale.ROOT);

        if (Extractor.EXT_EXTRACTIONS.equals(keyword)) {
            return new ExternalClassifier(set, featureTransformer, file, tiesConfiguration);
        }
        if ("winnow".equals(keyword)) {
            return new Winnow(set, featureTransformer, tiesConfiguration);
        }
        if ("ucwinnow".equals(keyword)) {
            return new UltraconservativeWinnow(set, featureTransformer, tiesConfiguration);
        }
        if ("moon".equals(keyword)) {
            return new MoonClassifier(set, featureTransformer, tiesConfiguration);
        }
        if (WRAPPING_CLASSIFIERS.contains(keyword)) {
            return createWrapper(keyword, set, file, featureTransformer, strArr, tiesConfiguration);
        }
        return instantiateByClassName(set, featureTransformer, strArr, tiesConfiguration);
    }

    /**
     * Creates one of the wrapping classifiers. For "multi" and "oar" the
     * wrapper is skipped and the inner classifier is created directly when
     * there are at most two classes.
     */
    private static TrainableClassifier createWrapper(String keyword, Set<String> set, File file, FeatureTransformer featureTransformer, String[] strArr, TiesConfiguration tiesConfiguration) throws IllegalArgumentException, ProcessingException {
        // everything after the wrapper keyword specifies the inner classifier
        final String[] innerSpec = new String[strArr.length - 1];
        System.arraycopy(strArr, 1, innerSpec, 0, innerSpec.length);

        if (TIE_CLASSIFIER.equals(keyword)) {
            return new TieClassifier(set, featureTransformer, file, innerSpec, tiesConfiguration);
        }
        if (META_CLASSIFIER.equals(keyword)) {
            return new MetaClassifier(set, featureTransformer, file, innerSpec, tiesConfiguration);
        }
        if (set.size() <= 2) {
            return createClassifier(set, file, featureTransformer, innerSpec, tiesConfiguration);
        }
        if (MULTI_CLASSIFIER.equals(keyword)) {
            return new MultiBinaryClassifier(set, featureTransformer, file, innerSpec, tiesConfiguration);
        }
        if (OAR_CLASSIFIER.equals(keyword)) {
            return new OneAgainstTheRestClassifier(set, featureTransformer, file, innerSpec, tiesConfiguration);
        }
        // only reachable if WRAPPING_CLASSIFIERS gains a keyword not handled above
        throw new RuntimeException("Implementation error: unknown wrapping classifier: " + strArr[0]);
    }

    /**
     * Instantiates a classifier whose fully qualified class name is the
     * first element of the specification.
     */
    private static TrainableClassifier instantiateByClassName(Set<String> set, FeatureTransformer featureTransformer, String[] strArr, TiesConfiguration tiesConfiguration) throws ProcessingException {
        final Object[] ctorArgs = new Object[]{set, featureTransformer, tiesConfiguration};
        final Class[] ctorTypes = new Class[]{Set.class, FeatureTransformer.class, TiesConfiguration.class};
        try {
            return (TrainableClassifier) Util.createObject(Class.forName(strArr[0]), ctorArgs, ctorTypes);
        } catch (ClassNotFoundException e) {
            // message kept as before; the cause is now chained as well
            throw new ProcessingException("Cannot create classifier from specification " + ArrayUtils.toString(strArr) + ": " + e.toString(), e);
        } catch (InstantiationException e) {
            throw new ProcessingException("Cannot create classifier from specification " + ArrayUtils.toString(strArr), e);
        }
    }

    /**
     * Restores the shared state from an XML element as written by
     * {@link #toElement()}. The standard configuration
     * ({@code TiesConfiguration.CONF}) is used.
     *
     * @param element the element to read from
     * @throws InstantiationException if the state cannot be restored
     */
    public TrainableClassifier(Element element) throws InstantiationException {
        this(
            CollUtils.asStringSet(element.attributeValue(ATTRIB_CLASSES)),
            (FeatureTransformer) ObjectElement.createNextObject(element.elementIterator(FeatureTransformer.ELEMENT_MAIN)),
            Util.asBoolean(element.attributeValue(ATTRIB_TRAIN_ALL)),
            TiesConfiguration.CONF);
    }

    /**
     * Creates a new instance, reading the "train all classes" flag from the
     * configuration key {@code classifier.train.all}.
     *
     * @param set the set of all classes
     * @param featureTransformer transformer applied to feature vectors; may
     *        be {@code null}
     * @param tiesConfiguration the configuration to use
     */
    public TrainableClassifier(Set<String> set, FeatureTransformer featureTransformer, TiesConfiguration tiesConfiguration) {
        this(set, featureTransformer, tiesConfiguration.getBoolean("classifier.train.all"), tiesConfiguration);
    }

    /**
     * Creates a new instance.
     *
     * @param set the set of all classes; every name is checked with
     *        {@code TextUtils.ensurePrintableName}. The set is wrapped, not
     *        copied.
     * @param featureTransformer transformer applied to feature vectors; may
     *        be {@code null}
     * @param z whether train-on-error should classify against all classes
     * @param tiesConfiguration the configuration to use
     */
    public TrainableClassifier(Set<String> set, FeatureTransformer featureTransformer, boolean z, TiesConfiguration tiesConfiguration) {
        for (String className : set) {
            TextUtils.ensurePrintableName(className);
        }
        this.allClasses = Collections.unmodifiableSet(set);
        this.config = tiesConfiguration;
        this.transformer = featureTransformer;
        this.trainingAll = z;
    }

    /**
     * Returns the feature vector to hand to the subclass: the transformed
     * vector if a transformer is configured, otherwise the input itself.
     */
    private FeatureVector applyTransformer(FeatureVector original) throws ProcessingException {
        if (this.transformer == null) {
            return original;
        }
        return this.transformer.transform(original);
    }

    /**
     * Ensures that every candidate class is one of the known classes.
     *
     * @throws IllegalArgumentException if a candidate is unknown
     */
    private void checkCandidateClass(Set<?> set) throws IllegalArgumentException {
        for (Object candidate : set) {
            // cast retained: a non-String element fails here, as it always did
            final String candidateName = (String) candidate;
            if (!this.allClasses.contains(candidateName)) {
                throw new IllegalArgumentException("Candidate class " + candidateName + " is not in the set of valid classes: " + this.allClasses);
            }
        }
    }

    /**
     * Ensures that the target class is one of the known classes.
     *
     * @throws IllegalArgumentException if the target is unknown
     */
    private void checkTargetClass(String str) throws IllegalArgumentException {
        if (this.allClasses.contains(str)) {
            return;
        }
        throw new IllegalArgumentException("Target class " + str + " is not in the set of valid classes: " + this.allClasses);
    }

    /**
     * Classifies a feature vector and remembers the result for a possible
     * follow-up {@link #trainOnError(FeatureVector, String, Set)} call.
     *
     * @param featureVector the (untransformed) feature vector
     * @param set the candidate classes; each must be a known class
     * @return the distribution returned by
     *         {@link #doClassify(FeatureVector, Set, ContextMap)}
     * @throws IllegalArgumentException if a candidate class is unknown
     * @throws ProcessingException if classification fails
     */
    @Override // de.fu_berlin.ties.classify.Classifier
    public final PredictionDistribution classify(FeatureVector featureVector, Set set) throws IllegalArgumentException, ProcessingException {
        checkCandidateClass(set);
        final FeatureVector actualFeatures = applyTransformer(featureVector);
        final ContextMap context = new ContextMap();
        final PredictionDistribution predictions = doClassify(actualFeatures, set, context);

        // the cache is only updated once classification has succeeded
        this.cachedPredictions = predictions;
        this.cachedOrgFeatures = featureVector;
        this.cachedActualFeatures = actualFeatures;
        this.cachedContext = context;
        return predictions;
    }

    /**
     * Destroys this classifier by delegating to {@link #reset()}.
     *
     * @throws ProcessingException if resetting fails
     */
    @Override // de.fu_berlin.ties.classify.Classifier
    public void destroy() throws ProcessingException {
        reset();
    }

    /**
     * Performs the actual classification; implemented by subclasses.
     *
     * @param featureVector the feature vector, already transformed if a
     *        transformer is configured
     * @param set the candidate classes
     * @param contextMap map created by the caller for this invocation; the
     *        same instance is later passed to the training methods when the
     *        cached result is reused
     * @return the prediction distribution
     * @throws ProcessingException if classification fails
     */
    public abstract PredictionDistribution doClassify(FeatureVector featureVector, Set set, ContextMap contextMap) throws ProcessingException;

    /**
     * Performs the actual training; implemented by subclasses.
     *
     * @param featureVector the feature vector, already transformed if a
     *        transformer is configured
     * @param str the target class
     * @param contextMap context map for this invocation
     * @throws ProcessingException if training fails
     */
    public abstract void doTrain(FeatureVector featureVector, String str, ContextMap contextMap) throws ProcessingException;

    /**
     * Decides whether training is required and, if so, trains.
     *
     * <p>{@link #trainOnErrorHook} is invoked first, then
     * {@link #shouldTrain}. {@link #doTrain} is called only if training is
     * required and the hook returned {@code false}.
     *
     * @param predictionDistribution the predictions for the instance
     * @param featureVector the (transformed) feature vector
     * @param str the target class
     * @param set the candidate classes used for classification
     * @param contextMap context map of the classification
     * @return the result of {@link #shouldTrain}, regardless of whether the
     *         hook or {@link #doTrain} did the training
     * @throws ProcessingException if training fails
     */
    public boolean doTrainOnError(PredictionDistribution predictionDistribution, FeatureVector featureVector, String str, Set set, ContextMap contextMap) throws ProcessingException {
        // order matters: the hook always runs, and it runs before shouldTrain
        final boolean handledByHook = trainOnErrorHook(predictionDistribution, featureVector, str, set, contextMap);
        final boolean trainingNeeded = shouldTrain(str, predictionDistribution, contextMap);
        if (trainingNeeded && !handledByHook) {
            doTrain(featureVector, str, contextMap);
        }
        return trainingNeeded;
    }

    /**
     * Returns the set of all classes known to this classifier.
     *
     * @return an unmodifiable set of class names
     */
    public Set<String> getAllClasses() {
        return this.allClasses;
    }

    /**
     * Returns the configuration this classifier was created with.
     *
     * @return the configuration
     */
    public TiesConfiguration getConfig() {
        return this.config;
    }

    /**
     * Resets this classifier; implemented by subclasses. Also invoked by
     * {@link #destroy()}.
     *
     * @throws ProcessingException if resetting fails
     */
    public abstract void reset() throws ProcessingException;

    /**
     * Decides whether the classifier should be trained on an instance.
     *
     * @param str the target class
     * @param predictionDistribution the predictions for the instance
     * @param contextMap context map of the classification (unused here)
     * @return {@code true} if the best prediction is not the target class,
     *         or if its probability is NaN or not greater than zero
     */
    public boolean shouldTrain(String str, PredictionDistribution predictionDistribution, ContextMap contextMap) {
        final Prediction top = predictionDistribution.best();
        final double topProb = top.getProbability().getProb();
        if (!top.getType().equals(str)) {
            return true;
        }
        return Double.isNaN(topProb) || topProb <= 0.0d;
    }

    /**
     * Serializes the shared state: the classes, the "train all" flag and,
     * if present, the feature transformer as a child element.
     *
     * @return the created element
     */
    public ObjectElement toElement() {
        final ObjectElement result = new ObjectElement(ELEMENT_MAIN, getClass());
        result.addAttribute(ATTRIB_CLASSES, CollUtils.flatten(this.allClasses.iterator()));
        result.addAttribute(ATTRIB_TRAIN_ALL, Boolean.toString(this.trainingAll));
        if (this.transformer != null) {
            result.add(this.transformer.toElement());
        }
        return result;
    }

    /**
     * Returns a description listing all classes and, where set, the
     * "training all classes" flag and the transformer.
     *
     * @return a textual representation of this classifier
     */
    @Override
    public String toString() {
        final ToStringBuilder description = new ToStringBuilder(this);
        description.append("all classes", this.allClasses);
        if (this.trainingAll) {
            description.append("training all classes", this.trainingAll);
        }
        if (this.transformer != null) {
            description.append("transformer", this.transformer);
        }
        return description.toString();
    }

    /**
     * Unconditionally trains the classifier on an instance.
     *
     * @param featureVector the (untransformed) feature vector
     * @param str the target class; must be a known class
     * @throws IllegalArgumentException if the target class is unknown
     * @throws ProcessingException if training fails
     */
    public final void train(FeatureVector featureVector, String str) throws IllegalArgumentException, ProcessingException {
        checkTargetClass(str);
        final FeatureVector actualFeatures = applyTransformer(featureVector);
        doTrain(actualFeatures, str, new ContextMap());
    }

    /**
     * Trains the classifier on an instance only if
     * {@link #doTrainOnError} reports that training was required.
     *
     * <p>If the feature vector is the very same object that was passed to the
     * last {@link #classify(FeatureVector, Set)} call, the cached result of
     * that call is reused; otherwise the instance is classified first. With
     * "train all" enabled, classification uses all classes and predictions
     * for classes outside {@code set} are removed from the distribution
     * afterwards.
     *
     * @param featureVector the (untransformed) feature vector
     * @param str the target class; must be a known class
     * @param set the candidate classes; each must be a known class
     * @return the prediction distribution if training was required and the
     *         best remaining prediction differs from the target class;
     *         {@code null} otherwise
     * @throws ProcessingException if classification or training fails
     */
    public final PredictionDistribution trainOnError(FeatureVector featureVector, String str, Set set) throws ProcessingException {
        checkTargetClass(str);
        checkCandidateClass(set);
        final Set effectiveCandidates = this.trainingAll ? this.allClasses : set;

        final FeatureVector actualFeatures;
        final ContextMap context;
        final PredictionDistribution predictions;
        // NOTE(review): reuse is keyed on feature vector identity only; the
        // candidate set used for the cached classification is not compared.
        if (featureVector == this.cachedOrgFeatures) {
            actualFeatures = this.cachedActualFeatures;
            context = this.cachedContext;
            predictions = this.cachedPredictions;
        } else {
            actualFeatures = applyTransformer(featureVector);
            context = new ContextMap();
            predictions = doClassify(actualFeatures, effectiveCandidates, context);
        }

        if (!doTrainOnError(predictions, actualFeatures, str, effectiveCandidates, context)) {
            return null;
        }

        if (this.trainingAll) {
            // drop predictions for classes the caller did not ask about
            final Iterator<Prediction> remaining = predictions.iterator();
            while (remaining.hasNext()) {
                final Prediction prediction = remaining.next();
                if (!set.contains(prediction.getType())) {
                    remaining.remove();
                }
            }
        }

        final boolean stillWrong = predictions.size() > 0 && !predictions.best().getType().equals(str);
        return stillWrong ? predictions : null;
    }

    /**
     * Hook invoked by {@link #doTrainOnError} before the training decision.
     * Subclasses may override it to do the training themselves; returning
     * {@code true} suppresses the call to {@link #doTrain}. This default
     * implementation does nothing and returns {@code false}.
     *
     * @param predictionDistribution the predictions for the instance
     * @param featureVector the (transformed) feature vector
     * @param str the target class
     * @param set the candidate classes used for classification
     * @param contextMap context map of the classification
     * @return {@code false}
     * @throws ProcessingException declared for overriding implementations
     */
    public boolean trainOnErrorHook(PredictionDistribution predictionDistribution, FeatureVector featureVector, String str, Set set, ContextMap contextMap) throws ProcessingException {
        return false;
    }
}
