package de.fu_berlin.ties.extract;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.DirectoryProcessor;
import de.fu_berlin.ties.ParsingException;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.Prediction;
import de.fu_berlin.ties.eval.Accuracy;
import de.fu_berlin.ties.eval.FeatureCountView;
import de.fu_berlin.ties.eval.MultiFMetrics;
import de.fu_berlin.ties.eval.MultiFMetricsView;
import de.fu_berlin.ties.io.ExtensionFilter;
import de.fu_berlin.ties.io.FieldContainer;
import de.fu_berlin.ties.io.FieldMap;
import de.fu_berlin.ties.io.IOUtils;
import de.fu_berlin.ties.util.Util;
import de.fu_berlin.ties.xml.dom.DOMUtils;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.io.Writer;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.Set;
import org.apache.commons.configuration.Configuration;
import org.apache.commons.configuration.ConfigurationException;
import org.apache.commons.lang.builder.ToStringBuilder;
import org.dom4j.DocumentException;

/* loaded from: input_file:de/fu_berlin/ties/extract/TrainEval.class */
/**
 * Trains an {@link Extractor} on one part of a set of documents and evaluates
 * its extractions on the remaining documents, repeating this for a
 * configurable number of runs.
 *
 * <p>For each run a subdirectory named {@link #RUN_DIR} followed by the run
 * number is created below a common output directory. Training accuracies,
 * per-document evaluation results and metrics of that run are stored there.
 * Feature counts of all runs and metrics aggregated over all runs are stored
 * in the common output directory.
 *
 * <p>In <em>uniform</em> mode files alternate between the training and the
 * evaluation set, and exactly two runs are performed, swapping the roles of
 * both halves. Otherwise each run draws a random training set whose size
 * matches the configured training split.
 */
public class TrainEval extends DirectoryProcessor {
    /** Name passed to {@link IOUtils#createOutFile} for the common output directory. */
    public static final String OUTPUT_DIR = "eval";
    /** Prefix of the per-run subdirectories; the run number is appended. */
    public static final String RUN_DIR = "run";
    /** Configuration key of the list handed to the {@link ExtensionFilter} selecting input files. */
    public static final String CONFIG_FILE_EXT = "eval.files";
    /** Configuration key of the fraction (between 0 and 1) of files used for training. */
    public static final String CONFIG_TRAIN_SPLIT = "eval.train-split";
    /** Configuration key of the number of train/evaluate runs. */
    public static final String CONFIG_RUN = "eval.run";
    /** Configuration key of the flag enabling the uniform (alternating, two-run) split. */
    public static final String CONFIG_UNIFORM = "eval.uniform";
    /** Field name under which the run number is stored in feature-count entries. */
    public static final String KEY_RUN = "Run";

    /** Fraction of files used for training in each run. */
    private final float trainSplit;
    /** Number of train/evaluate runs. */
    private final int runs;
    /** Whether the uniform alternating split is used. */
    private final boolean uniform;

    /**
     * Bundles the outcome of a single train/evaluate run.
     */
    public static class Results {
        private final EvaluatedExtractionContainer evaluated;
        private final FeatureCountView trainFeatureCV;
        private final FeatureCountView extractFeatureCV;

        /**
         * Creates a new instance.
         *
         * @param evaluated the evaluated extractions of the run
         * @param trainFeatureCV feature counts collected by the trainer
         * @param extractFeatureCV feature counts collected by the extractor
         */
        public Results(EvaluatedExtractionContainer evaluated, FeatureCountView trainFeatureCV, FeatureCountView extractFeatureCV) {
            this.evaluated = evaluated;
            this.trainFeatureCV = trainFeatureCV;
            this.extractFeatureCV = extractFeatureCV;
        }

        /** @return the evaluated extractions of the run */
        public EvaluatedExtractionContainer getEvaluated() {
            return this.evaluated;
        }

        /** @return feature counts collected by the extractor */
        public FeatureCountView getExtractFeatureCV() {
            return this.extractFeatureCV;
        }

        /** @return feature counts collected by the trainer */
        public FeatureCountView getTrainFeatureCV() {
            return this.trainFeatureCV;
        }

        public String toString() {
            return new ToStringBuilder(this).append("training feature count", this.trainFeatureCV).append("extraction feature count", this.extractFeatureCV).toString();
        }
    }

    /**
     * Creates an instance configured from {@link TiesConfiguration#CONF}.
     *
     * @throws IllegalArgumentException if a configured value is invalid
     * @throws ClassCastException if a configured value has the wrong type
     * @throws NoSuchElementException if a required key is missing
     */
    public TrainEval() throws IllegalArgumentException, ClassCastException, NoSuchElementException {
        this(TiesConfiguration.CONF);
    }

    /**
     * Creates an instance reading {@link #CONFIG_FILE_EXT},
     * {@link #CONFIG_TRAIN_SPLIT}, {@link #CONFIG_RUN} and
     * {@link #CONFIG_UNIFORM} from the given configuration.
     *
     * @param tiesConfiguration the configuration to use
     * @throws IllegalArgumentException if a configured value is invalid
     * @throws ClassCastException if a configured value has the wrong type
     * @throws NoSuchElementException if a required key is missing
     */
    public TrainEval(TiesConfiguration tiesConfiguration) throws IllegalArgumentException, ClassCastException, NoSuchElementException {
        this(new ExtensionFilter((Set) new HashSet(tiesConfiguration.getList(CONFIG_FILE_EXT)), false), tiesConfiguration.getFloat(CONFIG_TRAIN_SPLIT), tiesConfiguration.getInt(CONFIG_RUN), tiesConfiguration.getBoolean(CONFIG_UNIFORM), tiesConfiguration);
    }

    /**
     * Creates a new instance.
     *
     * @param fileFilter selects the files to process
     * @param trainSplit fraction of files used for training (0 to 1);
     *        ignored in uniform mode, where 0.5 is used
     * @param runs number of runs (must be positive); ignored in uniform mode,
     *        where 2 runs are performed
     * @param uniform whether to use the uniform alternating split
     * @param config the configuration to use
     * @throws IllegalArgumentException if {@code trainSplit} is not in
     *         [0, 1] (including NaN) or {@code runs} is not positive, unless
     *         uniform mode is used
     */
    public TrainEval(FileFilter fileFilter, float trainSplit, int runs, boolean uniform, TiesConfiguration config) throws IllegalArgumentException {
        super(fileFilter, config);
        this.uniform = uniform;
        if (uniform) {
            this.trainSplit = 0.5f;
            this.runs = 2;
        } else {
            // Negated range check so that NaN is rejected as well:
            // every comparison involving NaN evaluates to false.
            if (!(trainSplit >= 0.0f && trainSplit <= 1.0f)) {
                throw new IllegalArgumentException("Train split is not a percentage: " + trainSplit);
            }
            if (runs <= 0) {
                throw new IllegalArgumentException("Number of runs is not positive: " + runs);
            }
            this.trainSplit = trainSplit;
            this.runs = runs;
        }
    }

    /** @return fraction of files used for evaluation ({@code 1 - trainSplit}) */
    public float getEvalSplit() {
        return 1.0f - this.trainSplit;
    }

    /** @return number of train/evaluate runs */
    public int getRuns() {
        return this.runs;
    }

    /** @return fraction of files used for training */
    public float getTrainSplit() {
        return this.trainSplit;
    }

    /**
     * Creates the extractor used for evaluation. Subclasses may override this
     * hook.
     *
     * @param trainer the trainer of the current run
     * @return the extractor to use
     */
    protected Extractor initExtractor(Trainer trainer) {
        return new Extractor((String) null, trainer);
    }

    /**
     * Creates the trainer used for a run. Subclasses may override this hook.
     *
     * @param file the run directory
     * @return the trainer to use
     * @throws ProcessingException if the trainer cannot be created
     */
    protected Trainer initTrainer(File file) throws ProcessingException {
        return new Trainer(null, file, getConfig());
    }

    /** @return whether the uniform alternating split is used */
    public boolean isUniform() {
        return this.uniform;
    }

    /**
     * Runs all train/evaluate runs on the given files and stores the
     * per-run feature counts and the metrics aggregated over all runs.
     *
     * @param files the files to use; if empty, nothing is done
     * @param contextMap passed on to {@link #trainAndEval}
     * @throws IOException if an output directory cannot be created or results
     *         cannot be written
     * @throws ProcessingException if training or evaluation fails
     */
    @Override // de.fu_berlin.ties.DirectoryProcessor
    public void process(File[] files, ContextMap contextMap) throws IOException, ProcessingException {
        if (files.length == 0) {
            Util.LOG.info("No files to process");
            return;
        }
        File outDir = IOUtils.createOutFile(files[0].getParentFile(), OUTPUT_DIR);
        createDirectory(outDir);

        // Storing the configuration is best effort: a failure is logged but
        // does not abort the evaluation.
        try {
            getConfig().save(new File(outDir, "config.cfg"));
        } catch (ConfigurationException e) {
            Util.LOG.error("Could not store configuration: ", e);
        }

        FieldContainer featureCounts = FieldContainer.createFieldContainer();
        MultiFMetrics overallMetrics = new MultiFMetrics(true);

        for (int run = 1; run <= this.runs; run++) {
            File runDir = new File(outDir, RUN_DIR + run);
            createDirectory(runDir);
            Util.LOG.info("Starting evaluation run " + run + "/" + this.runs);

            Results results = trainAndEval(files, contextMap, runDir, run);
            Integer runNumber = Integer.valueOf(run);

            FieldMap trainCounts = results.getTrainFeatureCV().storeFields();
            trainCounts.put(KEY_RUN, runNumber);
            trainCounts.put("Type", "training");
            featureCounts.add(trainCounts);

            FieldMap extractCounts = results.getExtractFeatureCV().storeFields();
            extractCounts.put(KEY_RUN, runNumber);
            extractCounts.put("Type", "extraction");
            featureCounts.add(extractCounts);

            overallMetrics.update(results.getEvaluated().viewMetrics());
        }
        storeValues(featureCounts, outDir, "feature-counts", null);

        FieldContainer metricsContainer = FieldContainer.createFieldContainer();
        overallMetrics.storeEntries(metricsContainer);
        storeValues(metricsContainer, outDir, MultiFMetrics.NAME_METRICS, null);
    }

    @Override // de.fu_berlin.ties.DirectoryProcessor
    public String toString() {
        return new ToStringBuilder(this).appendSuper(super.toString()).append("train split", this.trainSplit).append("runs", this.runs).append("uniform", this.uniform).toString();
    }

    /**
     * Creates a single directory.
     *
     * @param dir the directory to create
     * @throws IOException if {@link File#mkdir()} reports failure, e.g. because
     *         the directory already exists or its parent is missing
     */
    private static void createDirectory(File dir) throws IOException {
        if (!dir.mkdir()) {
            throw new IOException("Creating " + dir + " directory failed");
        }
    }

    /**
     * Splits the files into a training and an evaluation set.
     *
     * @param files all files
     * @param trainFiles receives the training files; must be empty
     * @param evalFiles receives the evaluation files; must be empty
     * @param run the 1-based run number (determines the halves in uniform mode)
     * @throws IllegalArgumentException if one of the lists is not empty
     */
    private final void selectFiles(File[] files, List<File> trainFiles, List<File> evalFiles, int run) throws IllegalArgumentException {
        if (!trainFiles.isEmpty() || !evalFiles.isEmpty()) {
            throw new IllegalArgumentException("Lists of train files and eval files must initially be empty");
        }
        if (this.uniform) {
            // Alternate files between both sets. The run number selects which
            // parity is evaluated, so runs 1 and 2 swap the halves. A negative
            // odd difference yields remainder -1 in Java, which is still != 0.
            for (int idx = 0; idx < files.length; idx++) {
                if ((idx - run) % 2 == 0) {
                    evalFiles.add(files[idx]);
                } else {
                    trainFiles.add(files[idx]);
                }
            }
        } else {
            int expectedTrain = Math.round(files.length * this.trainSplit);
            Random random = new Random();
            // First assign every file randomly according to the split...
            for (int idx = 0; idx < files.length; idx++) {
                if (random.nextFloat() < this.trainSplit) {
                    trainFiles.add(files[idx]);
                } else {
                    evalFiles.add(files[idx]);
                }
            }
            // ...then move randomly chosen files so that the training set has
            // exactly the expected size.
            int surplus = trainFiles.size() - expectedTrain;
            if (surplus > 0) {
                for (int moved = 0; moved < surplus; moved++) {
                    evalFiles.add(trainFiles.remove(random.nextInt(trainFiles.size())));
                }
                Util.LOG.debug("Reached expected split by moving " + surplus + " files from training to evaluation corpus");
            } else if (surplus < 0) {
                for (int moved = 0; moved > surplus; moved--) {
                    trainFiles.add(evalFiles.remove(random.nextInt(evalFiles.size())));
                }
                Util.LOG.debug("Reached expected split by moving " + (-surplus) + " files from evaluation to training corpus");
            }
        }
        Util.LOG.debug("Using " + trainFiles.size() + " files for " + "training, " + evalFiles.size() + " files for evaluation");
    }

    /**
     * Writes a field container to a file in the given directory. The writer
     * is closed even if storing fails.
     *
     * @param container the values to store
     * @param directory the directory to write to
     * @param baseName base name of the file
     * @param extension file extension, or {@code null} to use
     *        {@link FieldContainer#recommendedExtension()}
     * @return the written file
     * @throws IOException if writing fails
     */
    private File storeValues(FieldContainer container, File directory, String baseName, String extension) throws IOException {
        String ext = (extension != null) ? extension : FieldContainer.recommendedExtension();
        File target = new File(directory, baseName + '.' + ext);
        Writer writer = IOUtils.openWriter(target, (Configuration) getConfig());
        try {
            container.store(writer);
            writer.flush();
        } finally {
            writer.close();
        }
        return target;
    }

    /**
     * Performs a single run: selects training and evaluation files, trains on
     * the training files in random order, and then extracts from and evaluates
     * the evaluation files. Training accuracies, per-document results and
     * metrics are stored in the run directory.
     *
     * @param files all files
     * @param contextMap currently unused
     * @param file the run directory
     * @param i the 1-based run number
     * @return the results of this run
     * @throws IOException if reading input or writing results fails
     * @throws ProcessingException if training or extraction fails; a
     *         {@link ParsingException} naming the file if a document cannot
     *         be parsed
     */
    public Results trainAndEval(File[] files, ContextMap contextMap, File file, int i) throws IOException, ProcessingException {
        // NOTE(review): presumably directs relative output of trainer and
        // extractor into the run directory -- confirm against IOUtils.
        IOUtils.setDefaultDirectory(file);
        List<File> trainFiles = new LinkedList<File>();
        List<File> evalFiles = new LinkedList<File>();
        selectFiles(files, trainFiles, evalFiles, i);

        // Training phase: process training files in random order.
        long trainStart = System.currentTimeMillis();
        FieldContainer trainAccuracies = FieldContainer.createFieldContainer();
        Trainer trainer = initTrainer(file);
        Random random = new Random();
        while (!trainFiles.isEmpty()) {
            File trainFile = trainFiles.remove(random.nextInt(trainFiles.size()));
            Util.LOG.debug("Starting to train " + trainFile);
            try {
                Accuracy accuracy = trainer.train(DOMUtils.readDocument(trainFile, (Configuration) getConfig()), AnswerBuilder.readCorrespondingAnswerKeys(trainer.getTargetStructure(), trainFile, getConfig()));
                if (accuracy != null) {
                    FieldMap fields = accuracy.storeFields();
                    fields.put(Prediction.KEY_SOURCE, IOUtils.getBaseName(trainFile));
                    trainAccuracies.add(fields);
                }
                Util.LOG.info("Trained " + trainFile);
            } catch (DocumentException e) {
                throw new ParsingException("Error while parsing " + trainFile + ": " + e.toString(), e);
            }
        }
        if (trainAccuracies.size() > 0) {
            // NOTE(review): uses the extractions extension, like the "results"
            // file below -- confirm this is intended for accuracy data.
            storeValues(trainAccuracies, file, "train-accuracy", Extractor.EXT_EXTRACTIONS);
        }
        Util.LOG.info("Finished training (" + Util.showDuration(trainStart) + ")");

        // Evaluation phase: extract from each evaluation file and compare
        // against its answer keys.
        long evalStart = System.currentTimeMillis();
        Extractor extractor = initExtractor(trainer);
        EvaluatedExtractionContainer evaluated = new EvaluatedExtractionContainer(extractor.getTargetStructure(), getConfig());
        for (File evalFile : evalFiles) {
            Util.LOG.debug("Starting to extract and evaluate " + evalFile);
            try {
                evaluated.evaluateBatch(extractor.extract(DOMUtils.readDocument(evalFile, (Configuration) getConfig())), AnswerBuilder.readCorrespondingAnswerKeys(trainer.getTargetStructure(), evalFile, getConfig()), IOUtils.getBaseName(evalFile));
                MultiFMetricsView interim = evaluated.viewMetrics();
                Util.LOG.info("Extracted and evaluated " + evalFile + ", interim results: " + interim.viewAll());
                for (String type : interim.types()) {
                    Util.LOG.debug("Interim results for " + type + ": " + interim.view(type));
                }
            } catch (DocumentException e) {
                // Same contextual message as in the training phase so the
                // failing file can be identified.
                throw new ParsingException("Error while parsing " + evalFile + ": " + e.toString(), e);
            }
        }
        Util.LOG.info("Finished extraction and evaluation (" + Util.showDuration(evalStart) + ")");

        FieldContainer resultEntries = FieldContainer.createFieldContainer();
        evaluated.storeEntries(resultEntries);
        Util.LOG.info("Stored results of training + evaluation run in " + storeValues(resultEntries, file, "results", Extractor.EXT_EXTRACTIONS).getAbsolutePath());

        FieldContainer metricEntries = FieldContainer.createFieldContainer();
        evaluated.viewMetrics().storeEntries(metricEntries);
        storeValues(metricEntries, file, MultiFMetrics.NAME_METRICS, null);
        return new Results(evaluated, trainer.viewFeatureCount(), extractor.viewFeatureCount());
    }
}
