package de.fu_berlin.ties.extract;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.Classifier;
import de.fu_berlin.ties.classify.Prediction;
import de.fu_berlin.ties.classify.PredictionDistribution;
import de.fu_berlin.ties.classify.Probability;
import de.fu_berlin.ties.classify.Reranker;
import de.fu_berlin.ties.combi.CombinationState;
import de.fu_berlin.ties.combi.CombinationStrategy;
import de.fu_berlin.ties.context.Recognition;
import de.fu_berlin.ties.context.Representation;
import de.fu_berlin.ties.eval.FMetricsView;
import de.fu_berlin.ties.filter.EmbeddingElements;
import de.fu_berlin.ties.filter.FilteringTokenWalker;
import de.fu_berlin.ties.filter.TrainableFilter;
import de.fu_berlin.ties.io.FieldContainer;
import de.fu_berlin.ties.text.TokenDetails;
import de.fu_berlin.ties.text.TokenizerFactory;
import de.fu_berlin.ties.util.MathUtils;
import de.fu_berlin.ties.util.Util;
import java.io.File;
import java.io.IOException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.dom4j.Document;
import org.dom4j.Element;

/* loaded from: input_file:de/fu_berlin/ties/extract/Extractor.class */
/**
 * Token-level extractor: walks a document, classifies every relevant token
 * with the configured classifiers, reranks the resulting prediction
 * distributions, lets the combination strategy translate them into a
 * {@link CombinationState}, and assembles {@link Extraction}s from these
 * states in an {@link ExtractionContainer}.
 *
 * <p>Irrelevant tokens (as decided by {@code isRelevant}) that follow a
 * still-open extraction are buffered and only become part of that extraction
 * if the extraction is continued by a later token.
 *
 * <p>NOTE(review): instances keep per-document state in fields (last document,
 * predicted extractions, buffered punctuation), so they are presumably not
 * meant to be shared between threads -- confirm against the callers.
 */
public class Extractor extends ExtractorBase {
    /**
     * Value handed to the superclass by the no-argument constructor
     * (presumably the output file extension -- TODO confirm in ExtractorBase).
     */
    public static final String EXT_EXTRACTIONS = "ext";

    /** The document most recently passed to {@link #extract(Document)}. */
    private Document lastDocument;

    /** The extractions predicted for the document currently/last processed. */
    private ExtractionContainer predictedExtractions;

    /**
     * Irrelevant tokens seen since the last token of the current open
     * extraction; appended to it only if the extraction is continued.
     */
    private final List<TokenDetails> punctuationDetails;

    /** Reranks the prediction distributions returned by the classifiers. */
    private final Reranker reranker;

    /**
     * Builds a reranker from the {@code "extract"} subset of a configuration.
     *
     * @param tiesConfiguration the configuration to read from
     * @return the created reranker
     */
    private static Reranker createReranker(TiesConfiguration tiesConfiguration) {
        return new Reranker(tiesConfiguration.subset("extract"));
    }

    /**
     * Creates an instance using {@link #EXT_EXTRACTIONS} and the standard
     * configuration.
     *
     * @throws IllegalArgumentException propagated from the delegated constructor
     * @throws ProcessingException propagated from the delegated constructor
     */
    public Extractor() throws IllegalArgumentException, ProcessingException {
        this(EXT_EXTRACTIONS);
    }

    /**
     * Creates an instance using {@link TiesConfiguration#CONF}.
     *
     * @param str passed through to the superclass (presumably the output
     *        extension -- TODO confirm)
     * @throws IllegalArgumentException propagated from the delegated constructor
     * @throws ProcessingException propagated from the delegated constructor
     */
    public Extractor(String str) throws IllegalArgumentException, ProcessingException {
        this(str, TiesConfiguration.CONF);
    }

    /**
     * Creates an instance configured from the given configuration; the
     * reranker is built from its {@code "extract"} subset.
     *
     * @param str passed through to the superclass
     * @param tiesConfiguration the configuration to use
     * @throws IllegalArgumentException propagated from the superclass constructor
     * @throws ProcessingException propagated from the superclass constructor
     */
    public Extractor(String str, TiesConfiguration tiesConfiguration) throws IllegalArgumentException, ProcessingException {
        super(str, tiesConfiguration);
        this.punctuationDetails = new ArrayList<TokenDetails>();
        this.reranker = createReranker(tiesConfiguration);
    }

    /**
     * Creates an instance from a file and a configuration; the reranker is
     * built from the configuration's {@code "extract"} subset.
     *
     * @param str passed through to the superclass
     * @param file passed through to the superclass (its role is defined there)
     * @param tiesConfiguration the configuration to use
     * @throws IllegalArgumentException propagated from the superclass constructor
     * @throws ProcessingException propagated from the superclass constructor
     */
    public Extractor(String str, File file, TiesConfiguration tiesConfiguration) throws IllegalArgumentException, ProcessingException {
        super(str, file, tiesConfiguration);
        this.punctuationDetails = new ArrayList<TokenDetails>();
        this.reranker = createReranker(tiesConfiguration);
    }

    /**
     * Creates an instance that re-uses the components of a trainer (target
     * structure, classifiers, representation, strategy, tokenizer factory,
     * sentence filter, relevant punctuation and configuration). A new reranker
     * is created from the trainer's configuration.
     *
     * @param str passed through to the superclass
     * @param trainer the trainer whose components are re-used
     */
    public Extractor(String str, Trainer trainer) {
        this(str, trainer.getTargetStructure(), trainer.getClassifiers(), trainer.getRepresentation(), trainer.getStrategy(), trainer.getFactory(), trainer.getSentenceFilter(), createReranker(trainer.getConfig()), trainer.viewRelevantPunctuation(), trainer.getConfig());
    }

    /**
     * Creates an instance from explicitly given components. All arguments
     * except {@code reranker} are handed to the superclass unchanged.
     *
     * @param str passed through to the superclass
     * @param targetStructure the target structure
     * @param classifierArr the classifiers to use
     * @param representation the context representation
     * @param combinationStrategy the combination strategy
     * @param tokenizerFactory the tokenizer factory
     * @param trainableFilter the sentence filter
     * @param reranker the reranker applied to each prediction distribution
     * @param set passed through to the superclass (the trainer-based
     *        constructor supplies the trainer's relevant punctuation here)
     * @param tiesConfiguration the configuration to use
     */
    public Extractor(String str, TargetStructure targetStructure, Classifier[] classifierArr, Representation representation, CombinationStrategy combinationStrategy, TokenizerFactory tokenizerFactory, TrainableFilter trainableFilter, Reranker reranker, Set<String> set, TiesConfiguration tiesConfiguration) {
        super(str, targetStructure, classifierArr, representation, combinationStrategy, tokenizerFactory, trainableFilter, set, tiesConfiguration);
        this.punctuationDetails = new ArrayList<TokenDetails>();
        this.reranker = reranker;
    }

    /**
     * Buffers the details of an irrelevant token so it can later be appended
     * to the current extraction.
     *
     * @param tokenDetails the token details to buffer
     */
    protected void addPunctuationDetails(TokenDetails tokenDetails) {
        this.punctuationDetails.add(tokenDetails);
    }

    /**
     * Appends all buffered punctuation tokens to the given extraction (with a
     * {@code null} probability) and empties the buffer. Does nothing, and does
     * not call {@link #clearPunctuation()}, when the buffer is empty.
     *
     * @param extraction the extraction to append the buffered tokens to
     */
    protected void appendPunctuation(Extraction extraction) {
        if (this.punctuationDetails.isEmpty()) {
            return;
        }
        for (final TokenDetails buffered : this.punctuationDetails) {
            extraction.addToken(buffered, null, true);
        }
        clearPunctuation();
    }

    /** Empties the buffer of collected punctuation tokens. */
    protected void clearPunctuation() {
        this.punctuationDetails.clear();
    }

    /**
     * Creates a filtering token walker that uses this instance for both of the
     * walker's callback arguments, together with this instance's tokenizer
     * factory and the given filter.
     *
     * @param trainableFilter the filter handed to the walker
     * @return the created walker
     */
    @Override // de.fu_berlin.ties.extract.ExtractorBase
    protected FilteringTokenWalker createFilteringTokenWalker(TrainableFilter trainableFilter) {
        return new FilteringTokenWalker(this, getFactory(), trainableFilter, this);
    }

    /**
     * Removes the last extraction from the predicted extractions and the last
     * entry from the prior recognitions, checking that both refer to the same
     * item whenever a prior recognition was actually removed.
     *
     * @throws IllegalStateException if a prior recognition was removed that is
     *         not equal to the extraction removed from the container
     */
    private void discardLastExtraction() {
        final Extraction discarded = getPredictedExtractions().removeLast();
        final Recognition discardedRecognition = getPriorRecognitions().removeLast();
        // A null extraction paired with a non-null recognition is a mismatch
        // as well; report it as such rather than failing with a bare NPE.
        if (discardedRecognition != null
                && (discarded == null || !discarded.equals(discardedRecognition))) {
            throw new IllegalStateException("Extractions discarded from container " + discarded + " and from prior recognitions " + discardedRecognition + " differ");
        }
        // Log the extraction itself: the prior recognition may legitimately be
        // null, which would hide what has been discarded.
        Util.LOG.debug("Discarded last extraction " + discarded);
    }

    /**
     * Extracts from a document: (re-)initialises the per-document fields,
     * remembers the document, starts a fresh extraction container, walks the
     * document (which calls back into
     * {@link #processToken(Element, String, TokenDetails, String, ContextMap)})
     * and finally resets the combination strategy.
     *
     * @param document the document to process
     * @return the container of predicted extractions
     * @throws IOException propagated from the token walker
     * @throws ProcessingException propagated from the token walker
     */
    public ExtractionContainer extract(Document document) throws IOException, ProcessingException {
        initFields();
        this.lastDocument = document;
        this.predictedExtractions = new ExtractionContainer(getTargetStructure());
        getWalker().walk(document, null);
        resetStrategy();
        return getPredictedExtractions();
    }

    /**
     * Evaluates sentence filtering for the document most recently passed to
     * {@link #extract(Document)}, using embedding elements built from that
     * document, the given extractions and this instance's tokenizer factory.
     *
     * @param extractionContainer the extractions used to build the embedding
     *        elements (presumably the expected/true extractions -- TODO confirm)
     * @return the metrics computed by the overloaded superclass method
     */
    public FMetricsView evaluateSentenceFiltering(ExtractionContainer extractionContainer) {
        return evaluateSentenceFiltering(new EmbeddingElements(this.lastDocument, extractionContainer, getFactory()));
    }

    /**
     * Returns the extractions predicted so far for the current document.
     *
     * @return the container; {@code null} before the first call to
     *         {@link #extract(Document)}
     */
    protected ExtractionContainer getPredictedExtractions() {
        return this.predictedExtractions;
    }

    /**
     * Extracts from the document and stores the resulting entries, via a new
     * field container, to the given writer. The context map is not used.
     *
     * @param document the document to process
     * @param writer the writer the serialised extractions are stored to
     * @param contextMap not used by this implementation
     * @throws IOException propagated from extraction or from storing
     * @throws ProcessingException propagated from extraction
     */
    @Override // de.fu_berlin.ties.DocumentReader
    public void process(Document document, Writer writer, ContextMap contextMap) throws IOException, ProcessingException {
        final ExtractionContainer extractions = extract(document);
        final FieldContainer fields = FieldContainer.createFieldContainer();
        extractions.storeEntries(fields);
        fields.store(writer);
    }

    /**
     * Processes one token. Irrelevant tokens are buffered or skipped (see
     * {@link #bufferOrSkipPunctuation(TokenDetails)}). For relevant tokens the
     * internal state is updated, each classifier's distribution is reranked,
     * the reranked distributions are translated into a combination state by
     * the strategy, and the predicted extractions are updated accordingly.
     *
     * @param element forwarded to {@code updateState}
     * @param str forwarded to {@code updateState} (presumably the text left of
     *        the token -- TODO confirm)
     * @param tokenDetails the token to process
     * @param str2 forwarded to {@code updateState} (presumably the text right
     *        of the token -- TODO confirm)
     * @param contextMap not used by this implementation
     * @throws ProcessingException propagated from the classifiers
     * @throws IllegalStateException if the translated state continues an
     *         extraction that does not exist or has a different type, or if
     *         discarding the preceding extraction finds inconsistent state
     */
    @Override // de.fu_berlin.ties.xml.dom.TokenProcessor
    public void processToken(Element element, String str, TokenDetails tokenDetails, String str2, ContextMap contextMap) throws ProcessingException {
        if (!isRelevant(tokenDetails.getToken())) {
            bufferOrSkipPunctuation(tokenDetails);
            return;
        }

        updateState(element, str, tokenDetails.getToken(), str2);

        final Classifier[] classifiers = getClassifiers();
        final int classifierCount = classifiers.length;
        final PredictionDistribution[] rerankedDists = new PredictionDistribution[classifierCount];
        final String[] predictedTypes = new String[classifierCount];
        final double[] probs = new double[classifierCount];
        final double[] pRs = new double[classifierCount];

        for (int i = 0; i < classifierCount; i++) {
            final PredictionDistribution rawDist = classifiers[i].classify(getFeatures(), getActiveClasses()[i]);
            rerankedDists[i] = this.reranker.rerank(rawDist);
            final Prediction best = rerankedDists[i].best();
            predictedTypes[i] = best.getType();
            final Probability bestProbability = best.getProbability();
            probs[i] = bestProbability.getProb();
            pRs[i] = bestProbability.getPR();
        }

        final CombinationState translatedState = getStrategy().translateResult(rerankedDists);
        final double meanProb = MathUtils.mean(probs);
        final double meanPR = MathUtils.mean(pRs);
        Util.LOG.debug("Predicted types: '" + ArrayUtils.toString(predictedTypes) + "'; translated state: " + translatedState + ", mean prob.: " + meanProb + ", mean pR: " + meanPR);

        if (translatedState.isDiscardPreceding()) {
            discardLastExtraction();
        }
        applyTranslatedState(translatedState, tokenDetails, meanProb, meanPR);
        getStrategy().updateState(translatedState);
    }

    /**
     * Handles an irrelevant token: it is buffered if there is an unsealed last
     * extraction (it might become part of it), otherwise it is skipped.
     *
     * @param tokenDetails the irrelevant token
     */
    private void bufferOrSkipPunctuation(TokenDetails tokenDetails) {
        final Extraction current = getPredictedExtractions().last();
        if (current == null || current.isSealed()) {
            Util.LOG.debug("Skipping over irrelevant punctuation token " + tokenDetails.getToken());
        } else {
            addPunctuationDetails(tokenDetails);
            Util.LOG.debug("Keeping irrelevant punctuation token " + tokenDetails.getToken() + " -- might become part of the current " + current.getType() + " extraction");
        }
    }

    /**
     * Updates the predicted extractions according to a translated state:
     * a state without type seals the last open extraction; a begin state
     * starts a new extraction; any other state continues the last extraction.
     *
     * @param state the translated combination state
     * @param tokenDetails the token the state was predicted for
     * @param meanProb mean probability over all classifiers
     * @param meanPR mean pR value over all classifiers
     * @throws IllegalStateException if the state should continue an extraction
     *         but there is none, or its type differs from the state's type
     */
    private void applyTranslatedState(CombinationState state, TokenDetails tokenDetails, double meanProb, double meanPR) {
        if (state.getType() == null) {
            // Nothing to extract here: close whatever extraction is still open.
            final Extraction open = getPredictedExtractions().last();
            if (open != null && !open.isSealed()) {
                open.setSealed(true);
            }
            clearPunctuation();
        } else if (state.isBegin()) {
            final Extraction started = new Extraction(state.getType(), tokenDetails, new Probability(meanProb, meanPR));
            getPredictedExtractions().add(started);
            getPriorRecognitions().add(started);
            clearPunctuation();
        } else {
            final Extraction current = getPredictedExtractions().last();
            // Validate before mutating anything, so that a failure leaves the
            // extraction and the punctuation buffer untouched.
            if (current == null) {
                throw new IllegalStateException("No extraction to continue: " + state + " does not begin an extraction but none has been predicted yet");
            }
            if (!current.getType().equals(state.getType())) {
                throw new IllegalStateException("Type mismatch: " + state + " cannot continue extraction " + current);
            }
            appendPunctuation(current);
            current.addToken(tokenDetails, new Probability(meanProb, meanPR), true);
        }
    }

    /**
     * Resets the combination strategy; if the strategy's {@code reset()}
     * returns {@code true}, the last extraction is discarded (presumably
     * because it was left incomplete -- TODO confirm in CombinationStrategy).
     */
    @Override // de.fu_berlin.ties.extract.ExtractorBase
    protected void resetStrategy() {
        if (getStrategy().reset()) {
            discardLastExtraction();
        }
    }

    /**
     * Returns a string representation consisting of the superclass
     * representation plus the reranker and the predicted extractions.
     *
     * @return the textual representation
     */
    @Override // de.fu_berlin.ties.extract.ExtractorBase, de.fu_berlin.ties.TextProcessor
    public String toString() {
        return new ToStringBuilder(this).appendSuper(super.toString()).append("reranker", this.reranker).append("predicted extractions", this.predictedExtractions).toString();
    }
}
