package de.fu_berlin.ties.filter;

import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.Classifier;
import de.fu_berlin.ties.classify.PredictionDistribution;
import de.fu_berlin.ties.classify.Reranker;
import de.fu_berlin.ties.classify.TrainableClassifier;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.util.Util;
import java.io.File;
import java.io.IOException;
import java.util.Collections;
import java.util.SortedSet;
import java.util.TreeSet;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.NodeFilter;

/* loaded from: input_file:de/fu_berlin/ties/filter/TrainableFilter.class */
/**
 * An {@link ElementFilter} whose {@link #matches(Element)} decision is
 * delegated to a {@link TrainableClassifier} choosing between the two classes
 * in {@link #BOOLEAN_CLASSES}. The "prefers" and "avoids" decisions are
 * delegated to two {@link NodeFilter}s supplied at construction time.
 *
 * <p>Subclasses define how an element is turned into a feature vector by
 * implementing {@link #buildFeatures(Element)}.
 *
 * <p>The element and feature vector of the most recent
 * {@link #matches(Element)} call are kept in mutable instance fields; no
 * synchronization is performed by this class.
 */
public abstract class TrainableFilter implements ElementFilter {

    /**
     * Unmodifiable sorted set holding the string forms of
     * {@link Boolean#FALSE} and {@link Boolean#TRUE}; these are the classes
     * handed to the classifier.
     */
    public static final SortedSet<String> BOOLEAN_CLASSES;

    static {
        // Parameterized (not raw) so the assignment below is type-checked.
        final SortedSet<String> classes = new TreeSet<String>();
        classes.add(Boolean.FALSE.toString());
        classes.add(Boolean.TRUE.toString());
        BOOLEAN_CLASSES = Collections.unmodifiableSortedSet(classes);
    }

    /** Classifier used for the matching decision and for training. */
    private final TrainableClassifier classifier;

    /** Filter consulted by {@link #avoids(Element)}. */
    private final NodeFilter avoidFilter;

    /** Filter consulted by {@link #prefers(Element)}. */
    private final NodeFilter preferredFilter;

    /** Optional reranker applied to classifier output; may be {@code null}. */
    private final Reranker reranker;

    /** Element passed to the most recent {@link #matches(Element)} call. */
    private Element lastElement;

    /** Feature vector built for {@link #lastElement}. */
    private FeatureVector lastFeatures;

    /**
     * Creates a new instance. The classifier is obtained from
     * {@link TrainableClassifier#createClassifier} using
     * {@link #BOOLEAN_CLASSES}, the given configuration and the literal
     * {@code "filter"}.
     *
     * @param config the configuration handed to the classifier factory
     * @param preferredFilter the filter consulted by {@link #prefers(Element)}
     * @param avoidFilter the filter consulted by {@link #avoids(Element)}
     * @param reranker reranker applied to each prediction distribution in
     *        {@link #matches(Element)}; if {@code null}, the classifier's
     *        distribution is used unchanged
     * @throws ProcessingException declared for failures while creating the
     *         classifier
     */
    public TrainableFilter(TiesConfiguration config, NodeFilter preferredFilter, NodeFilter avoidFilter, Reranker reranker) throws ProcessingException {
        this.classifier = TrainableClassifier.createClassifier(BOOLEAN_CLASSES, config, "filter");
        this.preferredFilter = preferredFilter;
        this.avoidFilter = avoidFilter;
        this.reranker = reranker;
    }

    /**
     * Builds the feature vector that represents the given element for
     * classification and training.
     *
     * @param element the element to represent
     * @return the feature vector for the element
     */
    public abstract FeatureVector buildFeatures(Element element);

    /**
     * Does nothing in this class.
     *
     * @param document unused here
     * @param file unused here
     * @throws ProcessingException never thrown by this implementation
     * @throws IOException never thrown by this implementation
     */
    @Override // de.fu_berlin.ties.filter.ElementFilter
    public void init(Document document, File file) throws ProcessingException, IOException {
        // Intentionally empty.
    }

    /**
     * Delegates to the avoid filter given at construction time.
     *
     * @param element the element to test
     * @return the result of the avoid filter's {@code matches} call
     */
    @Override // de.fu_berlin.ties.filter.ElementFilter
    public boolean avoids(Element element) {
        return this.avoidFilter.matches(element);
    }

    /**
     * Classifies the given element and converts the best prediction's type
     * into a boolean via {@link Util#asBoolean}. The element and its feature
     * vector are remembered so that a following
     * {@link #trainIfNecessary(Element, boolean)} call for the same element
     * can reuse the features.
     *
     * @param element the element to classify
     * @return the boolean form of the best (optionally reranked) prediction
     * @throws ProcessingException declared for failures during classification
     */
    @Override // de.fu_berlin.ties.filter.ElementFilter
    public boolean matches(Element element) throws ProcessingException {
        final FeatureVector features = buildFeatures(element);

        // Remember the pair before classifying, as the original code does.
        this.lastElement = element;
        this.lastFeatures = features;

        final PredictionDistribution predicted = this.classifier.classify(features, BOOLEAN_CLASSES);
        final PredictionDistribution effective;

        if (this.reranker != null) {
            effective = this.reranker.rerank(predicted);
        } else {
            effective = predicted;
        }

        return Util.asBoolean(effective.best().getType());
    }

    /**
     * Delegates to the preferred filter given at construction time.
     *
     * @param element the element to test
     * @return the result of the preferred filter's {@code matches} call
     */
    @Override // de.fu_berlin.ties.filter.ElementFilter
    public boolean prefers(Element element) {
        return this.preferredFilter.matches(element);
    }

    /**
     * Returns a string representation listing the classifier, the preferred
     * filter and the reranker.
     *
     * @return the textual representation built by {@link ToStringBuilder}
     */
    @Override
    public String toString() {
        return new ToStringBuilder(this)
            .append(Classifier.CONFIG_CLASSIFIER, this.classifier)
            .append("preferred filter", this.preferredFilter)
            .append("reranker", this.reranker)
            .toString();
    }

    /**
     * Invokes the classifier's {@code trainOnError} for the given element,
     * using the string form of {@code shouldMatch} as the class argument.
     * If the element equals the one most recently passed to
     * {@link #matches(Element)}, the cached feature vector is reused;
     * otherwise the features are built afresh.
     *
     * @param element the element to train on; must not be {@code null}
     * @param shouldMatch the boolean whose string form is passed to the
     *        classifier as the class to train on
     * @return whatever the classifier's {@code trainOnError} returns
     * @throws ProcessingException declared for failures during training
     */
    public PredictionDistribution trainIfNecessary(Element element, boolean shouldMatch) throws ProcessingException {
        final FeatureVector features;

        if (element.equals(this.lastElement)) {
            // Same element as the last matches() call: reuse its features.
            features = this.lastFeatures;
        } else {
            features = buildFeatures(element);
        }

        return this.classifier.trainOnError(features, Boolean.toString(shouldMatch), BOOLEAN_CLASSES);
    }
}
