package de.fu_berlin.ties.extract;

import de.fu_berlin.ties.Closeable;
import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ParsingException;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TextProcessor;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.Prediction;
import de.fu_berlin.ties.classify.Tuner;
import de.fu_berlin.ties.eval.Accuracy;
import de.fu_berlin.ties.eval.AccuracyView;
import de.fu_berlin.ties.eval.FMetrics;
import de.fu_berlin.ties.eval.FMetricsView;
import de.fu_berlin.ties.eval.MultiFMetrics;
import de.fu_berlin.ties.eval.MultiFMetricsView;
import de.fu_berlin.ties.io.FieldContainer;
import de.fu_berlin.ties.io.FieldMap;
import de.fu_berlin.ties.io.IOUtils;
import de.fu_berlin.ties.util.Util;
import de.fu_berlin.ties.xml.dom.DOMUtils;
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.SortedMap;
import java.util.TreeMap;
import org.apache.commons.configuration.Configuration;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.apache.commons.math.stat.descriptive.moment.StandardDeviation;
import org.dom4j.Document;
import org.dom4j.DocumentException;

/* loaded from: input_file:de/fu_berlin/ties/extract/TrainEval.class */
public class TrainEval extends TextProcessor implements Closeable {
    /** Configuration key read as a boolean by the configuration-based constructor; the value becomes the feedback flag. */
    public static final String CONFIG_FEEDBACK = "eval.feedback";
    /** Configuration key read as an int by the configuration-based constructor; the value becomes the sentence iteration limit. */
    public static final String CONFIG_SENTENCE_TUNE = "sent.tune";
    /** Key under which the number of the tuning iteration is stored in serialized fields. */
    public static final String KEY_ITERATION = "Iteration";
    /** Not referenced within this class; presumably a serialization key for the run number -- TODO confirm against callers. */
    public static final String KEY_RUN = "Run";
    /** Key under which sentence metrics entries record whether they stem from training or from evaluation. */
    public static final String KEY_TYPE = "Type";
    /** Value stored for sentence metrics measured while training. */
    public static final String TYPE_TRAIN = "Train";
    /** Value stored for sentence metrics measured while evaluating. */
    public static final String TYPE_EVAL = "Eval";
    /** The string "metrics"; the same value is passed on by the no-argument constructor and handed to storeInFile by close(int). */
    public static final String EXT_METRICS = "metrics";
    /** With feedback enabled, feedback metrics are recorded after every this-many evaluated documents (and after the last one). */
    private static final int MEASURE_FEEDBACK = 10;
    /** Whether the trainer is additionally trained on each evaluated document after extracting from it. */
    private final boolean feedback;
    /** If positive: the iteration after which sentence training is disabled on the trainer. */
    private final int sentenceIterations;
    /** Highest tuning iteration executed so far by trainAndEval; -1 until the first iteration has completed. */
    private int lastUsedTuneIteration;
    /** Metrics accumulated over all processed runs, one slot per tuning iteration (index = iteration - 1); slots stay null until used. */
    private final MultiFMetrics[] averages;
    /** Accumulated feedback metrics, one entry per group of MEASURE_FEEDBACK evaluated documents; null unless feedback is enabled. */
    private final List<MultiFMetrics> feedbackAverages;
    /** Store for sentence filtering metrics; re-created by each trainAndEval call and null if the trainer does no sentence filtering. */
    private FieldContainer sentenceMetricsStore;
    /** Incremented on each doProcess call; not read anywhere else in this class. */
    private int runNo;
    /** Selects the training/evaluation files and decides how many iterations to run and when to evaluate. */
    private final Tuner tuner;
    /** Whether extractions are evaluated against answer keys (true) or only stored as predictions (false). */
    private final boolean evaluate;
    /** File extension used for prediction files written when evaluation is disabled. */
    private final String predExtension;
    /** Whether prediction files go to the output directory (true) or next to the processed file (false). */
    private final boolean storePredsInOutDir;

    /* loaded from: input_file:de/fu_berlin/ties/extract/TrainEval$Results.class */
    /**
     * Collects the evaluated extraction containers produced by a training and
     * evaluation run, keyed by the number of the iteration that produced them.
     */
    public static class Results {
        /** Maps iteration numbers to the container evaluated in that iteration, sorted by iteration. */
        private final SortedMap<Integer, EvaluatedExtractionContainer> evaluated =
                new TreeMap<Integer, EvaluatedExtractionContainer>();

        /**
         * Registers the container evaluated in the given iteration.
         *
         * @param i the number of the iteration
         * @param evaluatedExtractionContainer the container to store
         * @throws IllegalArgumentException if a container is already stored for
         *         this iteration; the stored container is left untouched in that case
         */
        public void addEvaluated(int i, EvaluatedExtractionContainer evaluatedExtractionContainer) throws IllegalArgumentException {
            final Integer iteration = Integer.valueOf(i);
            final EvaluatedExtractionContainer existing = this.evaluated.get(iteration);
            // Check before storing so a rejected duplicate does not replace the existing entry.
            if (existing != null) {
                throw new IllegalArgumentException("Cannot store two extraction containers for same iteration (" + i + "): " + existing + "; " + evaluatedExtractionContainer);
            }
            this.evaluated.put(iteration, evaluatedExtractionContainer);
        }

        /**
         * Returns the container evaluated in the given iteration.
         *
         * @param num the number of the iteration
         * @return the stored container, or <code>null</code> if there is none
         */
        public EvaluatedExtractionContainer getEvaluated(Integer num) {
            return this.evaluated.get(num);
        }

        /**
         * Convenience variant of {@link #getEvaluated(Integer)} taking a primitive.
         *
         * @param i the number of the iteration
         * @return the stored container, or <code>null</code> if there is none
         */
        public EvaluatedExtractionContainer getEvaluated(int i) {
            return getEvaluated(Integer.valueOf(i));
        }

        /**
         * Iterates over the numbers of all iterations for which a container is
         * stored, in ascending order.
         *
         * @return an iterator over the key set of the underlying sorted map
         */
        public Iterator<Integer> iterations() {
            return this.evaluated.keySet().iterator();
        }

        /**
         * Returns a short description containing the number of stored containers.
         *
         * @return a textual representation of this object
         */
        @Override
        public String toString() {
            return new ToStringBuilder(this).append("size", this.evaluated.size()).toString();
        }
    }

    /**
     * Creates a new instance, delegating to {@link #TrainEval(String)} with
     * {@link #EXT_METRICS} as argument.
     *
     * @throws IllegalArgumentException declared by the constructor this one delegates to
     * @throws ClassCastException declared by the constructor this one delegates to
     * @throws NoSuchElementException declared by the constructor this one delegates to
     */
    public TrainEval() throws IllegalArgumentException, ClassCastException, NoSuchElementException {
        this(EXT_METRICS);
    }

    /**
     * Creates a new instance that is configured from {@link TiesConfiguration#CONF}.
     *
     * @param str passed through unchanged to the superclass constructor
     *        (presumably the extension of output files -- TODO confirm)
     * @throws IllegalArgumentException declared by the constructor this one delegates to
     * @throws ClassCastException declared by the constructor this one delegates to
     * @throws NoSuchElementException declared by the constructor this one delegates to
     */
    public TrainEval(String str) throws IllegalArgumentException, ClassCastException, NoSuchElementException {
        this(str, TiesConfiguration.CONF);
    }

    /**
     * Creates a new instance whose settings are read from the given configuration:
     * a new {@link Tuner}, the sentence iteration limit ({@link #CONFIG_SENTENCE_TUNE}),
     * the feedback flag ({@link #CONFIG_FEEDBACK}), the evaluation flag
     * ("extract.evaluate"), the prediction file extension ("extract.pred.ext") and
     * the flag whether predictions go to the output directory ("extract.pred.use-outdir").
     *
     * @param str passed through unchanged to the superclass constructor
     * @param tiesConfiguration the configuration to read the settings from
     * @throws IllegalArgumentException declared here; also thrown by the main
     *         constructor for a disallowed feedback/tuning combination
     * @throws ClassCastException declared here; presumably raised by the
     *         configuration accessors for values of the wrong type -- TODO confirm
     * @throws NoSuchElementException declared here; presumably raised by the
     *         configuration accessors for missing keys -- TODO confirm
     */
    public TrainEval(String str, TiesConfiguration tiesConfiguration) throws IllegalArgumentException, ClassCastException, NoSuchElementException {
        this(str, new Tuner(tiesConfiguration, ExtractorBase.CONFIG_SUFFIX_IE), tiesConfiguration.getInt(CONFIG_SENTENCE_TUNE), tiesConfiguration.getBoolean(CONFIG_FEEDBACK), tiesConfiguration.getBoolean("extract.evaluate"), tiesConfiguration.getString("extract.pred.ext"), tiesConfiguration.getBoolean("extract.pred.use-outdir"), tiesConfiguration);
    }

    public TrainEval(String str, Tuner tuner, int i, boolean z, boolean z2, String str2, boolean z3, TiesConfiguration tiesConfiguration) {
        super(str, tiesConfiguration);
        this.lastUsedTuneIteration = -1;
        this.runNo = 0;
        this.tuner = tuner;
        this.sentenceIterations = i;
        this.feedback = z;
        this.averages = new MultiFMetrics[this.tuner.getTuneIterations()];
        this.evaluate = z2;
        this.predExtension = str2;
        this.storePredsInOutDir = z3;
        if (z && this.tuner.getTuneIterations() > 1 && this.tuner.isTuneEach()) {
            throw new IllegalArgumentException("It's not allowed give feedback when evaluating after each of several iteration because that would mean to evaluate on the training set");
        }
        if (this.feedback) {
            this.feedbackAverages = new ArrayList();
        } else {
            this.feedbackAverages = null;
        }
        new MultiFMetrics(true).storeEntries(new FieldContainer());
        new StandardDeviation().clear();
    }

    /**
     * Called after the training part of an iteration: disables sentence training
     * on the trainer once the configured iteration limit is reached and records
     * the sentence metrics of the iteration, if there are any.
     *
     * @param trainer the trainer in use
     * @param sentenceMetrics the sentence metrics to record as training metrics;
     *        nothing is recorded if <code>null</code>
     * @param iteration the number of the current iteration
     */
    private void checkSentenceTraining(Trainer trainer, FMetrics sentenceMetrics, int iteration) {
        final boolean limitReached = this.sentenceIterations > 0 && this.sentenceIterations == iteration;
        if (limitReached) {
            trainer.disableSentenceTraining();
            Util.LOG.info("Disabled sentence training after " + iteration + " iterations");
        }

        if (sentenceMetrics == null) {
            return;
        }
        updateSentenceMetricsStore(sentenceMetrics, true, iteration);
    }

    /**
     * Stores the metrics accumulated over all processed runs in the output
     * directory: the per-iteration averages in a file named "All" and, if
     * feedback is enabled, the feedback averages in a file named "Feedback".
     * Does nothing if the argument is positive.
     *
     * @param i nothing is stored if this value is greater than 0 (presumably an
     *        error indicator supplied by the caller -- TODO confirm)
     * @throws IOException declared by the interface; may be raised while storing
     * @throws ProcessingException declared by the interface
     */
    @Override
    public void close(int i) throws IOException, ProcessingException {
        if (i > 0) {
            return;
        }

        final File outDir = IOUtils.determineOutputDirectory(getConfig());
        final FieldContainer averageStore = FieldContainer.createFieldContainer(getConfig());
        final boolean severalIterations = this.averages.length > 1;
        final int lastIteration = Math.min(this.averages.length, this.lastUsedTuneIteration);

        for (int iteration = 1; iteration <= lastIteration; iteration++) {
            final MultiFMetrics average = this.averages[iteration - 1];
            if (average == null) {
                continue;
            }
            // Only iterations that were actually evaluated are stored; the
            // last used iteration is always included.
            final boolean wanted = this.tuner.isTuneEach()
                    || this.tuner.getTuneEvaluations().contains(Integer.valueOf(iteration))
                    || iteration == lastIteration;
            if (!wanted) {
                continue;
            }
            if (severalIterations) {
                averageStore.backgroundMap().put(KEY_ITERATION, Integer.valueOf(iteration));
            }
            average.storeEntries(averageStore);
        }
        averageStore.storeInFile(outDir, "All", EXT_METRICS, getConfig());

        if (this.feedbackAverages == null) {
            return;
        }
        final FieldContainer feedbackStore = FieldContainer.createFieldContainer(getConfig());
        for (MultiFMetrics feedbackAverage : this.feedbackAverages) {
            feedbackAverage.storeEntries(feedbackStore);
        }
        feedbackStore.storeInFile(outDir, "Feedback", EXT_METRICS, getConfig());
    }

    /**
     * Reads a list of URIs from the reader, runs {@link #trainAndEval} on them
     * and adds the resulting metrics to the per-iteration averages. Iterations
     * following the last evaluated one are updated with the metrics of that
     * last evaluated iteration.
     *
     * @param reader the reader to read the list of files from
     * @param writer handed on to <code>trainAndEval</code>
     * @param contextMap must map {@link TextProcessor#KEY_DIRECTORY} to a
     *        {@link File} and {@link TextProcessor#KEY_LOCAL_NAME} to a string
     * @throws IOException if an I/O error occurs
     * @throws ProcessingException if an error occurs during processing
     */
    @Override
    protected void doProcess(Reader reader, Writer writer, ContextMap contextMap) throws IOException, ProcessingException {
        this.runNo++;
        final String[] uris = IOUtils.readURIList(reader);
        if (uris.length == 0) {
            Util.LOG.info("No files to process");
            return;
        }

        final File inputDir = (File) contextMap.get(TextProcessor.KEY_DIRECTORY);
        final File outDir = IOUtils.determineOutputDirectory(getConfig());
        final String localName = (String) contextMap.get(TextProcessor.KEY_LOCAL_NAME);
        final String baseName = IOUtils.getBaseName(new File(localName));

        final Results results = trainAndEval(uris, inputDir, outDir, baseName, writer);
        if (results == null) {
            return;
        }

        MultiFMetricsView latestView = null;
        int iteration = 0;
        for (Iterator<Integer> it = results.iterations(); it.hasNext();) {
            final Integer evaluatedIteration = it.next();
            iteration = evaluatedIteration.intValue();
            latestView = results.getEvaluated(evaluatedIteration).viewMetrics();
            updateAverage(iteration, latestView);
        }

        // Carry the most recent metrics forward into all remaining iterations.
        for (int remaining = iteration + 1; remaining <= this.averages.length; remaining++) {
            updateAverage(remaining, latestView);
        }
    }

    /** Adds the given metrics to the average of an iteration, creating the average on first use. */
    private void updateAverage(int iteration, MultiFMetricsView view) {
        final int slot = iteration - 1;
        if (this.averages[slot] == null) {
            this.averages[slot] = new MultiFMetrics(true);
        }
        this.averages[slot].update(view);
    }

    /**
     * Creates the extractor used for the extraction passes. Subclasses may
     * override this method to supply a different extractor.
     *
     * @param trainer the trainer the extractor is created from
     * @return a new extractor, created with a <code>null</code> string as first
     *         constructor argument
     */
    protected Extractor initExtractor(Trainer trainer) {
        // A String-typed null selects the same constructor overload as a cast would.
        final String nullString = null;
        return new Extractor(nullString, trainer);
    }

    /**
     * Creates the trainer used by {@link #trainAndEval} and resets it before it
     * is returned. Subclasses may override this method to supply a different trainer.
     *
     * @param file handed to the trainer constructor; <code>trainAndEval</code>
     *        passes its output directory here
     * @return the newly created and reset trainer
     * @throws ProcessingException declared for failures while creating or resetting the trainer
     */
    protected Trainer initTrainer(File file) throws ProcessingException {
        final Trainer freshTrainer = new Trainer(null, file, getConfig());
        freshTrainer.reset();
        return freshTrainer;
    }

    /**
     * Returns a description consisting of the superclass description, the
     * tuner, the sentence iteration limit and the feedback flag.
     *
     * @return a textual representation of this object
     */
    @Override
    public String toString() {
        final ToStringBuilder description = new ToStringBuilder(this);
        description.appendSuper(super.toString());
        description.append("tuner", this.tuner);
        description.append("sentence iterations", this.sentenceIterations);
        description.append("feedback", this.feedback);
        return description.toString();
    }

    /**
     * Stores the evaluated extractions of an iteration in a file and records
     * the sentence metrics of the evaluation pass, if there are any.
     *
     * @param outDir the directory to store the file in
     * @param baseName the base name of the file
     * @param iteration the number of the current iteration
     * @param evaluated the evaluated extractions to store
     * @param sentenceMetrics the sentence metrics to record as evaluation
     *        metrics; nothing is recorded if <code>null</code>
     * @throws IOException if an I/O error occurs while storing
     */
    private void serializeExtractions(File outDir, String baseName, int iteration, EvaluatedExtractionContainer evaluated, FMetrics sentenceMetrics) throws IOException {
        final FieldContainer extractionStore = FieldContainer.createFieldContainer(getConfig());
        evaluated.storeEntries(extractionStore);
        Util.LOG.info("Stored results of training + evaluation run in "
                + extractionStore.storeInFile(outDir, baseName, Extractor.EXT_EXTRACTIONS, getConfig()));

        if (sentenceMetrics == null) {
            return;
        }
        updateSentenceMetricsStore(sentenceMetrics, false, iteration);
    }

    /**
     * Writes the metrics of all evaluated iterations to the writer and stores
     * the sentence metrics in a file, if a sentence metrics store exists.
     *
     * @param outDir the directory to store the sentence metrics file in
     * @param baseName the base name of the sentence metrics file
     * @param out the writer the evaluation metrics are written to; it is
     *        flushed but not closed
     * @param results the results whose metrics are written
     * @throws IOException if an I/O error occurs while writing
     */
    private void serializeMetrics(File outDir, String baseName, Writer out, Results results) throws IOException {
        final FieldContainer metricsStore = FieldContainer.createFieldContainer(getConfig());
        final boolean severalIterations = this.tuner.getTuneIterations() > 1;

        for (Iterator<Integer> it = results.iterations(); it.hasNext();) {
            final Integer iteration = it.next();
            final EvaluatedExtractionContainer container = results.getEvaluated(iteration);
            // The iteration number is only included if there can be several iterations.
            if (severalIterations) {
                metricsStore.backgroundMap().put(KEY_ITERATION, iteration);
            }
            container.viewMetrics().storeEntries(metricsStore);
        }
        metricsStore.store(out);
        out.flush();

        if (this.sentenceMetricsStore == null) {
            return;
        }
        this.sentenceMetricsStore.storeInFile(outDir, baseName, ExtractorBase.CONFIG_SENTENCE, getConfig());
    }

    /**
     * Stores the collected training metrics, one file per container, using the
     * extension "train". If there are several containers, a letter ('a', 'b',
     * ...) is appended to the base name to tell the files apart.
     *
     * @param outDir the directory to store the files in
     * @param baseName the base name of the files
     * @param stores the containers to store; nothing is done if <code>null</code>
     * @param accuracies determines the number of containers that are stored
     * @throws IOException if an I/O error occurs while storing
     */
    private void serializeTrainingMetrics(File outDir, String baseName, FieldContainer[] stores, AccuracyView[] accuracies) throws IOException {
        if (stores == null) {
            return;
        }

        final boolean severalStores = accuracies.length > 1;
        for (int index = 0; index < accuracies.length; index++) {
            final String name;
            if (severalStores) {
                name = baseName + (char) ('a' + index);
            } else {
                name = baseName;
            }
            stores[index].storeInFile(outDir, name, "train", getConfig());
        }
    }

    /**
     * Lets the tuner split the given files into a training set and an
     * evaluation set, trains on the former and extracts from the latter.
     * Training is repeated for up to as many iterations as the tuner allows;
     * after each iteration the tuner decides whether to continue and whether an
     * extraction pass is due. If evaluation is enabled, extractions are
     * evaluated against the corresponding answer keys, otherwise they are
     * stored as prediction files.
     *
     * @param strArr the names of the files to train on and to extract from
     * @param file the directory used to resolve the file names
     * @param file2 the output directory; it is also set as default directory
     *        and handed to {@link #initTrainer(File)}
     * @param str the base name used for the files that are stored
     * @param writer the evaluation metrics are written to this writer if
     *        evaluation is enabled; it is flushed but not closed
     * @return the evaluated extractions of each evaluated iteration, or
     *         <code>null</code> if evaluation is disabled
     * @throws IOException if an I/O error occurs
     * @throws ProcessingException if an error occurs during processing; a
     *         {@link ParsingException} is thrown if a document cannot be parsed
     */
    public Results trainAndEval(String[] strArr, File file, File file2, String str, Writer writer) throws IOException, ProcessingException {
        IOUtils.setDefaultDirectory(file2);
        LinkedList trainFiles = new LinkedList();
        LinkedList evalFiles = new LinkedList();
        this.tuner.selectFiles(strArr, trainFiles, evalFiles);

        final long trainStart = System.currentTimeMillis();
        FieldContainer[] trainingStores = null;
        AccuracyView[] globalAccuracies = null;
        final boolean severalIterations = this.tuner.getTuneIterations() > 1;
        boolean continueTraining = true;
        this.tuner.reset();
        final Trainer trainer = initTrainer(file2);
        Extractor extractor = null;
        FMetrics sentenceMetrics = null;
        this.sentenceMetricsStore = trainer.isSentenceFiltering()
                ? FieldContainer.createFieldContainer(getConfig()) : null;
        final Results results = this.evaluate ? new Results() : null;

        for (int iteration = 1; iteration <= this.tuner.getTuneIterations() && continueTraining; iteration++) {
            if (this.tuner.getTuneIterations() > 1) {
                Util.LOG.debug("Starting TUNE iteration " + iteration + "/" + this.tuner.getTuneIterations() + " (will stop if no improvement for " + this.tuner.getTuneStop() + " iterations)");
            }

            // --- training pass ---
            trainer.resetGlobalAccuracy();
            if (trainer.isSentenceFiltering()) {
                sentenceMetrics = new FMetrics();
            }
            for (Iterator trainIter = trainFiles.iterator(); trainIter.hasNext();) {
                final File trainFile = IOUtils.resolveFilename(file, (String) trainIter.next());
                Util.LOG.debug("Starting to train " + trainFile);
                try {
                    final Accuracy[] docAccuracies = trainer.train(DOMUtils.readDocument(trainFile, (Configuration) getConfig()), trainFile, AnswerBuilder.readCorrespondingAnswerKeys(trainer.getTargetStructure(), trainFile, getConfig()));
                    if (docAccuracies != null) {
                        globalAccuracies = trainer.viewGlobalAccuracy();
                        if (trainingStores == null) {
                            // Created lazily: the number of stores is only known now.
                            trainingStores = new FieldContainer[globalAccuracies.length];
                            for (int k = 0; k < globalAccuracies.length; k++) {
                                trainingStores[k] = FieldContainer.createFieldContainer(getConfig());
                            }
                        }
                        for (int k = 0; k < globalAccuracies.length; k++) {
                            // Global accuracy fields first, then the fields of this document.
                            final FieldMap fields = globalAccuracies[k].storeFields();
                            fields.putAll(docAccuracies[k].storeFields());
                            fields.put(Prediction.KEY_SOURCE, IOUtils.getBaseName(trainFile));
                            if (severalIterations) {
                                fields.put(KEY_ITERATION, Integer.valueOf(iteration));
                            }
                            trainingStores[k].add(fields);
                        }
                    }
                    Util.LOG.info("Trained " + trainFile);
                    if (trainer.isSentenceFiltering()) {
                        final FMetricsView trainedSentences = trainer.evaluateSentenceFiltering();
                        Util.LOG.debug("Evaluated sentence filtering for trained document: " + trainedSentences);
                        sentenceMetrics.update(trainedSentences);
                    }
                } catch (DocumentException e) {
                    throw new ParsingException("Error while parsing " + trainFile + ": " + e.toString(), e);
                }
            }
            checkSentenceTraining(trainer, sentenceMetrics, iteration);

            // Let the tuner decide, based on the global accuracies, whether to go on.
            if (globalAccuracies != null) {
                final double[] accuracyValues = new double[globalAccuracies.length];
                for (int k = 0; k < accuracyValues.length; k++) {
                    accuracyValues[k] = globalAccuracies[k].getAccuracy();
                }
                continueTraining = this.tuner.continueTraining(accuracyValues, iteration);
            }

            // --- extraction (and evaluation) pass, if due in this iteration ---
            if (this.tuner.shouldEvaluate(continueTraining, iteration)) {
                Util.LOG.info("Finished training using " + trainer.toString() + "; " + Util.showDuration(trainStart));
                final long evalStart = System.currentTimeMillis();
                int evalCount = 0;
                if (extractor == null) {
                    extractor = initExtractor(trainer);
                }
                final EvaluatedExtractionContainer evaluated = this.evaluate
                        ? new EvaluatedExtractionContainer(extractor.getTargetStructure(), getConfig()) : null;
                if (extractor.isSentenceFiltering()) {
                    sentenceMetrics = new FMetrics();
                }
                for (Iterator evalIter = evalFiles.iterator(); evalIter.hasNext();) {
                    final File evalFile = IOUtils.resolveFilename(file, (String) evalIter.next());
                    evalCount++;
                    Util.LOG.debug("Starting to extract and evaluate file #" + evalCount + ": " + evalFile);
                    try {
                        final Document document = DOMUtils.readDocument(evalFile, (Configuration) getConfig());
                        final ExtractionContainer extractions = extractor.extract(document, evalFile);
                        if (this.evaluate) {
                            final ExtractionContainer answerKeys = AnswerBuilder.readCorrespondingAnswerKeys(trainer.getTargetStructure(), evalFile, getConfig());
                            if (this.feedback) {
                                // Feedback: train on the document after extracting from it.
                                trainer.train(document, evalFile, answerKeys);
                            }
                            if (extractor.isSentenceFiltering()) {
                                final FMetricsView evaluatedSentences = extractor.evaluateSentenceFiltering(answerKeys);
                                Util.LOG.debug("Evaluated sentence filtering for current document: " + evaluatedSentences);
                                sentenceMetrics.update(evaluatedSentences);
                            }
                            evaluated.evaluateBatch(extractions, answerKeys, IOUtils.getBaseName(evalFile));
                            Util.LOG.info("Extracted and evaluated " + evalFile + ", interim results: " + evaluated.viewMetrics().viewAll());
                            if (this.feedback && (evalCount % MEASURE_FEEDBACK == 0 || !evalIter.hasNext())) {
                                // 1-based number of the group of MEASURE_FEEDBACK documents just completed.
                                final int group = (int) Math.ceil(evalCount / (double) MEASURE_FEEDBACK);
                                if (this.feedbackAverages.size() < group) {
                                    this.feedbackAverages.add(new MultiFMetrics(true));
                                }
                                this.feedbackAverages.get(group - 1).update(evaluated.viewMetrics());
                            }
                        } else {
                            storePredictions(evalFile, file2, extractions);
                        }
                    } catch (DocumentException e) {
                        // Same message format as for training files so the failing file is named.
                        throw new ParsingException("Error while parsing " + evalFile + ": " + e.toString(), e);
                    }
                }
                if (this.evaluate) {
                    serializeExtractions(file2, str, iteration, evaluated, sentenceMetrics);
                    results.addEvaluated(iteration, evaluated);
                }
                Util.LOG.info("Finished extraction and evaluation using " + extractor.toString() + "; " + Util.showDuration(evalStart));
            }
            this.lastUsedTuneIteration = Math.max(this.lastUsedTuneIteration, iteration);
        }

        serializeTrainingMetrics(file2, str, trainingStores, globalAccuracies);
        if (this.evaluate) {
            serializeMetrics(file2, str, writer, results);
        }
        // No extractor exists if no extraction pass was ever due.
        if (extractor != null) {
            extractor.destroy();
        }
        return results;
    }

    /**
     * Writes the extractions predicted for a file to a prediction file. The file
     * is placed in the output directory or next to the processed file, depending
     * on the configured setting. The writer is closed even if storing fails.
     */
    private void storePredictions(File evalFile, File outDir, ExtractionContainer extractions) throws IOException, ProcessingException {
        final File targetDir = this.storePredsInOutDir ? outDir : evalFile.getParentFile();
        final File predFile = new File(targetDir, IOUtils.getBaseName(evalFile) + '.' + this.predExtension);
        final FieldContainer predStore = FieldContainer.createFieldContainer(getConfig());
        extractions.storeEntries(predStore);
        final Writer predWriter = IOUtils.openWriter(predFile, (Configuration) getConfig());
        try {
            predStore.store(predWriter);
            predWriter.flush();
        } finally {
            IOUtils.tryToClose(predWriter);
        }
        Util.LOG.info("Stored predicted extractions in " + predFile);
    }

    /**
     * Adds the fields of the given sentence metrics to the sentence metrics
     * store, tagged with their type and the number of the iteration. The store
     * must have been created before this method is called.
     *
     * @param metricsView the sentence metrics to record
     * @param training whether the metrics were measured while training
     *        (stored as {@link #TYPE_TRAIN}) or while evaluating ({@link #TYPE_EVAL})
     * @param iteration the number of the current iteration
     */
    private void updateSentenceMetricsStore(FMetricsView metricsView, boolean training, int iteration) {
        final FieldMap fields = metricsView.storeFields();
        fields.put(KEY_TYPE, training ? TYPE_TRAIN : TYPE_EVAL);
        // Integer.valueOf replaces the deprecated Integer(int) constructor.
        fields.put(KEY_ITERATION, Integer.valueOf(iteration));
        this.sentenceMetricsStore.add(fields);
    }
}
