package de.fu_berlin.ties.extract;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TextProcessor;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.Prediction;
import de.fu_berlin.ties.classify.PredictionDistribution;
import de.fu_berlin.ties.classify.Probability;
import de.fu_berlin.ties.classify.TrainableClassifier;
import de.fu_berlin.ties.combi.CombinationState;
import de.fu_berlin.ties.combi.CombinationStrategy;
import de.fu_berlin.ties.context.Representation;
import de.fu_berlin.ties.eval.Accuracy;
import de.fu_berlin.ties.eval.AccuracyView;
import de.fu_berlin.ties.eval.FMetricsView;
import de.fu_berlin.ties.filter.EmbeddingElements;
import de.fu_berlin.ties.filter.FilteringTokenWalker;
import de.fu_berlin.ties.filter.Oracle;
import de.fu_berlin.ties.filter.TrainableFilter;
import de.fu_berlin.ties.filter.TrainableFilteringTokenWalker;
import de.fu_berlin.ties.text.TokenDetails;
import de.fu_berlin.ties.text.TokenizerFactory;
import de.fu_berlin.ties.util.CollectionUtils;
import de.fu_berlin.ties.util.Util;
import de.fu_berlin.ties.xml.dom.DOMUtils;
import java.io.File;
import java.io.IOException;
import java.io.Writer;
import java.util.Set;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.dom4j.Document;
import org.dom4j.Element;

/* loaded from: input_file:de/fu_berlin/ties/extract/Trainer.class */
/**
 * Extractor variant that trains its classifiers from documents with known
 * extractions (answer keys). Each token handed to
 * {@link #processToken(Element, String, TokenDetails, String, ContextMap)} is
 * matched against the expected extractions via an {@link ExtractionLocator};
 * the resulting combination state is translated into per-classifier target
 * classes which are then used for training.
 *
 * <p>Two flags modify this behaviour:
 * <ul>
 * <li><em>training only errors</em> (TOE): classifiers are trained via
 * {@code trainOnError} and per-document as well as overall accuracies are
 * recorded;</li>
 * <li><em>testing only</em>: no classifier is trained at all; instead it is
 * checked that the combination strategy translates each state back to itself.
 * This flag takes precedence over the TOE flag.</li>
 * </ul>
 *
 * <p>The trainer also acts as the {@link Oracle} passed to the trainable
 * sentence filter walker.
 */
public class Trainer extends ExtractorBase implements Oracle {
    /** Configuration key of the boolean "training only errors" (TOE) flag. */
    public static final String CONFIG_TOE = "train.only-errors";

    /** Configuration key of the boolean "testing only" flag. */
    public static final String CONFIG_TEST_ONLY = "train.test-only";

    /** String passed to the accuracies kept over all documents. */
    public static final String PREFIX_GLOBAL_ACC = "Overall ";

    /** String passed to the accuracies kept for a single document. */
    public static final String PREFIX_LOCAL_ACC = "Document ";

    /** Whether sentence training is requested from newly created walkers. */
    private boolean sentenceTrainingEnabled = true;

    /** The classifiers to train, one per translated target class. */
    private final TrainableClassifier[] trainableClassifiers;

    /** Whether classifiers are only trained on misclassified instances. */
    private final boolean trainingOnlyErrors;

    /** Whether training is skipped in favour of a strategy self-check. */
    private final boolean testingOnly;

    /** Accuracies over all documents; stays {@code null} outside TOE mode. */
    private Accuracy[] globalAccuracies;

    /** Accuracies of the current document; stays {@code null} outside TOE mode. */
    private Accuracy[] localAccuracies;

    /** Locates the expected extractions in the document being trained. */
    private ExtractionLocator locator;

    /** Elements embedding extractions; only set when sentence filtering. */
    private EmbeddingElements embeddingElements;

    /** The extraction that is currently being assembled token by token. */
    private Extraction partialExtraction;

    /**
     * Creates a trainer using the output extension {@code "tmp"} and the
     * standard configuration.
     *
     * @throws IllegalArgumentException declared by the delegate constructor
     * @throws ProcessingException declared by the delegate constructor
     */
    public Trainer() throws IllegalArgumentException, ProcessingException {
        this("tmp");
    }

    /**
     * Creates a trainer using the standard configuration
     * ({@link TiesConfiguration#CONF}).
     *
     * @param str passed through to the superclass constructor
     * @throws IllegalArgumentException declared by the delegate constructor
     * @throws ProcessingException declared by the delegate constructor
     */
    public Trainer(String str) throws IllegalArgumentException, ProcessingException {
        this(str, TiesConfiguration.CONF);
    }

    /**
     * Creates a trainer without an explicit file argument ({@code null} is
     * passed on instead).
     *
     * @param str passed through to the superclass constructor
     * @param tiesConfiguration the configuration to use
     * @throws IllegalArgumentException declared by the delegate constructor
     * @throws ProcessingException declared by the delegate constructor
     */
    public Trainer(String str, TiesConfiguration tiesConfiguration) throws IllegalArgumentException, ProcessingException {
        this(str, null, tiesConfiguration);
    }

    /**
     * Creates a trainer that reads the TOE and testing-only flags from the
     * given configuration. All classifiers provided by the superclass must be
     * {@link TrainableClassifier}s; otherwise the cast fails with a
     * {@link ClassCastException}.
     *
     * @param str passed through to the superclass constructor
     * @param file passed through to the superclass constructor; may be
     * {@code null}
     * @param tiesConfiguration the configuration to use
     * @throws IllegalArgumentException declared by the superclass constructor
     * @throws ProcessingException declared by the superclass constructor
     */
    public Trainer(String str, File file, TiesConfiguration tiesConfiguration) throws IllegalArgumentException, ProcessingException {
        super(str, file, tiesConfiguration);
        this.trainingOnlyErrors = tiesConfiguration.getBoolean(CONFIG_TOE);
        this.testingOnly = tiesConfiguration.getBoolean(CONFIG_TEST_ONLY);

        // Fetch the classifier array once instead of once per element
        final Object[] classifiers = getClassifiers();
        this.trainableClassifiers = new TrainableClassifier[classifiers.length];
        for (int i = 0; i < classifiers.length; i++) {
            this.trainableClassifiers[i] = (TrainableClassifier) classifiers[i];
        }
        resetGlobalAccuracy();
    }

    /**
     * Creates a trainer from explicit components, reading the relevant
     * punctuation and both mode flags from the standard configuration.
     *
     * @param str passed through to the superclass constructor
     * @param targetStructure the target structure
     * @param trainableClassifierArr the classifiers to train
     * @param representation the context representation
     * @param combinationStrategy the combination strategy
     * @param tokenizerFactory the tokenizer factory
     * @param trainableFilter the sentence filter
     */
    public Trainer(String str, TargetStructure targetStructure, TrainableClassifier[] trainableClassifierArr, Representation representation, CombinationStrategy combinationStrategy, TokenizerFactory tokenizerFactory, TrainableFilter trainableFilter) {
        this(str, targetStructure, trainableClassifierArr, representation, combinationStrategy, tokenizerFactory, trainableFilter, CollectionUtils.arrayAsSet(TiesConfiguration.CONF.getStringArray(ExtractorBase.CONFIG_RELEVANT_PUNCTUATION)), TiesConfiguration.CONF.getBoolean(CONFIG_TOE), TiesConfiguration.CONF.getBoolean(CONFIG_TEST_ONLY), TiesConfiguration.CONF);
    }

    /**
     * Creates a trainer from explicit components and flags.
     *
     * @param str passed through to the superclass constructor
     * @param targetStructure the target structure
     * @param trainableClassifierArr the classifiers to train; the array is
     * stored as is, without copying
     * @param representation the context representation
     * @param combinationStrategy the combination strategy
     * @param tokenizerFactory the tokenizer factory
     * @param trainableFilter the sentence filter
     * @param set passed through to the superclass constructor (the other
     * constructor fills it from the relevant-punctuation configuration)
     * @param z whether to train only on errors (TOE mode)
     * @param z2 whether to test only instead of training
     * @param tiesConfiguration the configuration to use
     */
    public Trainer(String str, TargetStructure targetStructure, TrainableClassifier[] trainableClassifierArr, Representation representation, CombinationStrategy combinationStrategy, TokenizerFactory tokenizerFactory, TrainableFilter trainableFilter, Set<String> set, boolean z, boolean z2, TiesConfiguration tiesConfiguration) {
        super(str, targetStructure, trainableClassifierArr, representation, combinationStrategy, tokenizerFactory, trainableFilter, set, tiesConfiguration);
        this.trainingOnlyErrors = z;
        this.testingOnly = z2;
        this.trainableClassifiers = trainableClassifierArr;
        resetGlobalAccuracy();
    }

    /**
     * Creates a trainable walker that uses this trainer as token processor
     * and as oracle.
     *
     * @param trainableFilter the sentence filter handed to the walker
     * @return the created walker
     */
    @Override // de.fu_berlin.ties.extract.ExtractorBase
    protected FilteringTokenWalker createFilteringTokenWalker(TrainableFilter trainableFilter) {
        return new TrainableFilteringTokenWalker(this, getFactory(), trainableFilter, this, this, this.sentenceTrainingEnabled);
    }

    /**
     * Disables sentence training. The flag is read when a walker is created
     * by {@link #createFilteringTokenWalker(TrainableFilter)}.
     */
    public void disableSentenceTraining() {
        this.sentenceTrainingEnabled = false;
    }

    /**
     * Enables sentence training (the initial setting). The flag is read when
     * a walker is created by
     * {@link #createFilteringTokenWalker(TrainableFilter)}.
     */
    public void enableSentenceTraining() {
        this.sentenceTrainingEnabled = true;
    }

    /**
     * Evaluates sentence filtering against the embedding elements determined
     * by the last {@link #train(Document, ExtractionContainer)} call.
     *
     * @return the metrics computed by the superclass
     */
    public FMetricsView evaluateSentenceFiltering() {
        return evaluateSentenceFiltering(this.embeddingElements);
    }

    /**
     * Returns the classifiers to train.
     *
     * @return the internal classifier array (not a copy)
     */
    private TrainableClassifier[] getTrainableClassifiers() {
        return this.trainableClassifiers;
    }

    /**
     * Creates one fresh accuracy object per classifier.
     *
     * @param str the string passed to each {@link Accuracy} constructor
     * @return the created array
     */
    private Accuracy[] initAccuracies(String str) {
        final Accuracy[] result = new Accuracy[getClassifiers().length];
        for (int i = 0; i < result.length; i++) {
            result[i] = new Accuracy(str);
        }
        return result;
    }

    /**
     * Tells whether this trainer only checks the combination strategy
     * instead of training.
     *
     * @return the value of the testing-only flag
     */
    public boolean isTestingOnly() {
        return this.testingOnly;
    }

    /**
     * Tells whether classifiers are trained only on errors (TOE mode).
     *
     * @return the value of the TOE flag
     */
    public boolean isTrainingOnlyErrors() {
        return this.trainingOnlyErrors;
    }

    /**
     * Trains on a document, reading the corresponding answer keys for the
     * file identified by the {@link TextProcessor#KEY_DIRECTORY} and
     * {@link TextProcessor#KEY_LOCAL_NAME} entries of the context map.
     *
     * @param document the document to train on
     * @param writer not used by this implementation
     * @param contextMap supplies the directory and local name of the document
     * @throws IOException if thrown while reading the answer keys or training
     * @throws ProcessingException if thrown while reading the answer keys or
     * training
     */
    @Override // de.fu_berlin.ties.DocumentReader
    public void process(Document document, Writer writer, ContextMap contextMap) throws IOException, ProcessingException {
        final File directory = (File) contextMap.get(TextProcessor.KEY_DIRECTORY);
        final String localName = (String) contextMap.get(TextProcessor.KEY_LOCAL_NAME);
        train(document, AnswerBuilder.readCorrespondingAnswerKeys(getTargetStructure(), new File(directory, localName), getConfig()));
    }

    /**
     * Processes a single token: determines the expected combination state
     * from the extraction locator, keeps the partial extraction up to date
     * and, if the token is relevant, trains the classifiers (or, in
     * testing-only mode, checks the combination strategy) before updating the
     * strategy's state.
     *
     * <p>A token that is not relevant by itself is marked as relevant when it
     * is the first or last token of an extraction.
     *
     * @param element the element containing the token
     * @param str passed through to {@code updateState}
     * @param tokenDetails the token and its details
     * @param str2 passed through to {@code updateState}
     * @param contextMap not used by this implementation
     * @throws ProcessingException declared for errors during training
     */
    @Override // de.fu_berlin.ties.xml.dom.TokenProcessor
    public void processToken(Element element, String str, TokenDetails tokenDetails, String str2, ContextMap contextMap) throws ProcessingException {
        final String token = tokenDetails.getToken();
        updateState(element, str, token, str2);

        final boolean startOfExtraction = this.locator.startOfExtraction(token, tokenDetails.getRep());
        if (startOfExtraction) {
            Util.LOG.debug("Starting extraction (" + token + " token)");
        }

        // Type of the extraction this token belongs to; null if outside
        String type = null;
        boolean endOfExtraction = false;
        if (this.locator.inExtraction()) {
            if (this.locator.updateExtraction(token, tokenDetails.getRep())) {
                type = this.locator.getCurrentExtraction().getType();
                if (startOfExtraction) {
                    // First token: open a new partial extraction
                    this.partialExtraction = new Extraction(type, new TokenDetails(token, this.locator.getCurrentExtraction().getFirstTokenRep(), this.locator.getCurrentExtraction().getIndex(), false));
                    getPriorRecognitions().add(this.partialExtraction);
                } else {
                    this.partialExtraction.addToken(tokenDetails, true);
                }
            }
            endOfExtraction = this.locator.endOfExtraction();
        }

        final CombinationState currentState = (type == null)
                ? CombinationState.OUTSIDE
                : new CombinationState(type, startOfExtraction, endOfExtraction);

        // Outside of any extraction: seal a still open partial extraction
        if (type == null && this.partialExtraction != null && !this.partialExtraction.isSealed()) {
            this.partialExtraction.setSealed(true);
        }

        boolean relevant = isRelevant(token);
        if (!relevant && this.locator.inExtraction() && (startOfExtraction || endOfExtraction)) {
            // Boundary tokens of an extraction must never be skipped
            markRelevant(token);
            relevant = true;
            Util.LOG.debug("Marked punctuation token " + token + " as relevant since it is the " + (startOfExtraction ? "first" : "last") + " token of a " + type + " extraction");
        }

        if (relevant) {
            final String[] targetClasses = getStrategy().translateCurrentState(currentState);
            Util.LOG.debug("Current state: " + currentState + "; translated state: '" + ArrayUtils.toString(targetClasses) + "'");
            if (this.testingOnly) {
                // The check is independent of the individual classifiers, so
                // run it once per token instead of once per classifier (which
                // repeated identical work and duplicated the error message).
                // It is still skipped if there are no target classes.
                if (targetClasses.length > 0) {
                    checkStateRetranslation(currentState, targetClasses);
                }
            } else {
                trainOnToken(element, token, targetClasses);
            }
            getStrategy().updateState(currentState);
        } else {
            Util.LOG.debug("Skipping over irrelevant punctuation token " + token);
        }

        if (this.locator.endOfExtraction()) {
            this.locator.switchToNextExtraction();
        }
    }

    /**
     * Checks that the combination strategy translates the given target
     * classes back to the state they were derived from, logging an error if
     * the type or (for typed states) the begin flag differs.
     *
     * @param currentState the state the target classes were derived from
     * @param targetClasses the translated state, one class per classifier
     * @throws ProcessingException declared for errors during translation
     */
    private void checkStateRetranslation(CombinationState currentState, String[] targetClasses) throws ProcessingException {
        // Pretend each classifier predicted its target class with certainty
        final PredictionDistribution[] mockDists = new PredictionDistribution[targetClasses.length];
        for (int i = 0; i < mockDists.length; i++) {
            mockDists[i] = new PredictionDistribution(new Prediction(targetClasses[i], new Probability(1.0d)));
        }

        final CombinationState retranslated = getStrategy().translateResult(mockDists);
        if (!StringUtils.equals(currentState.getType(), retranslated.getType())
                || (currentState.getType() != null && currentState.isBegin() != retranslated.isBegin())) {
            Util.LOG.error("Error in combination strategy: incorrect re-translation " + retranslated + " of current state " + currentState);
        }
    }

    /**
     * Trains each classifier on its target class using the current features.
     * In TOE mode the local and global accuracies are updated as well: a
     * {@code null} result of {@code trainOnError} is counted as a correct
     * classification, any other result as a misclassification.
     *
     * @param element the element containing the token (used for logging)
     * @param token the current token (used for logging)
     * @param targetClasses the class to train, one per classifier
     * @throws ProcessingException declared for errors during training
     */
    private void trainOnToken(Element element, String token, String[] targetClasses) throws ProcessingException {
        final TrainableClassifier[] classifiers = getTrainableClassifiers();
        for (int i = 0; i < targetClasses.length; i++) {
            if (!this.trainingOnlyErrors) {
                classifiers[i].train(getFeatures(), targetClasses[i]);
                Util.LOG.debug("Trained classifier");
                continue;
            }

            final PredictionDistribution wrongDist = classifiers[i].trainOnError(getFeatures(), targetClasses[i], getActiveClasses()[i]);
            if (wrongDist == null) {
                this.localAccuracies[i].incTrueCount();
                this.globalAccuracies[i].incTrueCount();
            } else {
                this.localAccuracies[i].incFalseCount();
                this.globalAccuracies[i].incFalseCount();
                Util.LOG.debug(DOMUtils.showToken(element, token) + " classifier " + i + " misclassified " + wrongDist.best() + " instead of " + targetClasses[i]);
            }
        }
    }

    /**
     * Resets all trainable classifiers.
     *
     * @throws ProcessingException declared for errors while resetting
     */
    public void reset() throws ProcessingException {
        for (TrainableClassifier classifier : getTrainableClassifiers()) {
            classifier.reset();
        }
    }

    /**
     * Replaces the overall accuracies by fresh ones. Does nothing unless TOE
     * mode is active.
     */
    public void resetGlobalAccuracy() {
        if (this.trainingOnlyErrors) {
            this.globalAccuracies = initAccuracies(PREFIX_GLOBAL_ACC);
        }
    }

    /**
     * Resets the combination strategy, logging a warning if the strategy
     * reports that the last extraction should be discarded.
     */
    @Override // de.fu_berlin.ties.extract.ExtractorBase
    protected void resetStrategy() {
        if (getStrategy().reset()) {
            Util.LOG.warn("Combination strategy " + getStrategy() + " ordered to discard the last extraction -- this is not supposed to happen when training");
        }
    }

    /**
     * Tells whether the given element contains an extraction according to
     * the embedding elements of the current document. The embedding elements
     * are only set by {@link #train(Document, ExtractionContainer)} when
     * sentence filtering is active.
     *
     * @param element the element to check
     * @return the answer of the embedding elements
     */
    @Override // de.fu_berlin.ties.filter.Oracle
    public boolean shouldMatch(Element element) {
        return this.embeddingElements.containsExtraction(element);
    }

    /**
     * Returns a string representation of this trainer, extending the
     * superclass representation by the mode flags and the locator; the
     * sentence training flag is included if a sentence filter is set.
     *
     * @return the string representation
     */
    @Override // de.fu_berlin.ties.extract.ExtractorBase, de.fu_berlin.ties.TextProcessor
    public String toString() {
        final ToStringBuilder builder = new ToStringBuilder(this)
                .appendSuper(super.toString())
                .append("training only errors", this.trainingOnlyErrors)
                .append("testing only", this.testingOnly)
                .append("locator", this.locator);
        if (getSentenceFilter() != null) {
            builder.append("sentence training enabled", this.sentenceTrainingEnabled);
        }
        return builder.toString();
    }

    /**
     * Trains on a document whose expected extractions are given. A new
     * extraction locator is created for the document, then the walker
     * traverses the document, which presumably results in
     * {@link #processToken(Element, String, TokenDetails, String, ContextMap)}
     * being called for the tokens (the walker is created with this trainer
     * as its processor).
     *
     * @param document the document to train on
     * @param extractionContainer the expected extractions (answer keys)
     * @return the accuracies measured on this document in TOE mode;
     * {@code null} otherwise
     * @throws IOException declared for errors during processing
     * @throws ProcessingException declared for errors during processing
     */
    public Accuracy[] train(Document document, ExtractionContainer extractionContainer) throws IOException, ProcessingException {
        initFields();
        this.locator = new ExtractionLocator(document, extractionContainer, getFactory().createTokenizer(""));
        if (isSentenceFiltering()) {
            this.embeddingElements = new EmbeddingElements(document, extractionContainer, getFactory());
        }
        if (this.trainingOnlyErrors) {
            // Per-document accuracies start from scratch for each document
            this.localAccuracies = initAccuracies(PREFIX_LOCAL_ACC);
        }

        getWalker().walk(document, null);
        this.locator.reachedEndOfDocument();

        if (!this.trainingOnlyErrors) {
            return null;
        }
        Util.LOG.debug("Finished training in TOE mode: " + ArrayUtils.toString(this.localAccuracies) + ", " + ArrayUtils.toString(this.globalAccuracies));
        return this.localAccuracies;
    }

    /**
     * Returns a view on the overall accuracies.
     *
     * @return the internal array of overall accuracies (not a copy);
     * {@code null} unless TOE mode is active
     */
    public AccuracyView[] viewGlobalAccuracy() {
        return this.globalAccuracies;
    }
}
