package de.fu_berlin.ties.combi;

import de.fu_berlin.ties.ContextMap;
import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.Prediction;
import de.fu_berlin.ties.classify.PredictionComparator;
import de.fu_berlin.ties.classify.PredictionDistribution;
import de.fu_berlin.ties.classify.Probability;
import de.fu_berlin.ties.extract.amend.BeginEndReextractor;
import de.fu_berlin.ties.extract.amend.FinalReextractor;
import de.fu_berlin.ties.extract.reestimate.LengthFilter;
import de.fu_berlin.ties.extract.reestimate.Reestimator;
import de.fu_berlin.ties.text.TokenDetails;
import de.fu_berlin.ties.util.Util;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;

/* loaded from: input_file:de/fu_berlin/ties/combi/BeginEndStrategy.class */
/**
 * Combination strategy that uses two classifiers: the first one predicts where an
 * extraction begins (classes {@code "B-<type>"}), the second one predicts where an
 * extraction ends (classes {@code "E-<type>"}). Both classifiers share the negative
 * class {@code "A"} ("no begin" resp. "no end" at this token).
 *
 * <p>If the configuration key {@code combination.begin-end.level2} is true, the positive
 * predictions of both classifiers are additionally collected per token index and handed
 * to a {@link BeginEndReextractor} via {@link #contextForReextractor()}.
 */
public class BeginEndStrategy extends CombinationStrategy {
    /** Character terminating {@link #BEGIN_PREFIX} and {@link #END_PREFIX} (not referenced in this class). */
    private static final char PREFIX_TERMINATOR = '-';
    /** Prefix of the positive classes of the begin classifier. */
    private static final String BEGIN_PREFIX = "B-";
    /** Prefix of the positive classes of the end classifier. */
    private static final String END_PREFIX = "E-";
    /** Negative class shared by both classifiers. */
    private static final String OTHER = "A";

    /** Unmodifiable classes of the begin classifier: {@link #OTHER} plus {@code "B-" + type} per type. */
    private final SortedSet<String> beginClasses;
    /** Unmodifiable classes of the end classifier: {@link #OTHER} plus {@code "E-" + type} per type. */
    private final SortedSet<String> endClasses;
    /** {@code {beginClasses, endClasses}}, in classifier order. */
    private final SortedSet[] allClasses;
    /** Configuration, passed on to the {@link BeginEndReextractor}. */
    private final TiesConfiguration config;
    /** Whether level-2 re-extraction ({@code combination.begin-end.level2}) is enabled. */
    private final boolean reextract;
    /** Positive begin predictions collected per token index; {@code null} unless {@link #reextract}. */
    private BeginEndReextractor.PositivePredictionsMap beginMap;
    /** Positive end predictions collected per token index; {@code null} unless {@link #reextract}. */
    private BeginEndReextractor.PositivePredictionsMap endMap;

    /**
     * Creates a new instance.
     *
     * @param set the type names handed to the superclass; for each of them a begin class
     * ({@code "B-" + type}) and an end class ({@code "E-" + type}) is created
     * @param tiesConfiguration the configuration; {@code combination.begin-end.level2} is
     * read to decide whether positive predictions are collected for re-extraction
     */
    public BeginEndStrategy(final Set<String> set, final TiesConfiguration tiesConfiguration) {
        super(set);
        this.config = tiesConfiguration;
        this.reextract = tiesConfiguration.getBoolean("combination.begin-end.level2");
        if (this.reextract) {
            this.beginMap = new BeginEndReextractor.PositivePredictionsMap();
            this.endMap = new BeginEndReextractor.PositivePredictionsMap();
        } else {
            this.beginMap = null;
            this.endMap = null;
        }

        final TreeSet<String> begins = new TreeSet<String>();
        final TreeSet<String> ends = new TreeSet<String>();
        begins.add(OTHER);
        ends.add(OTHER);
        for (final String type : set) {
            begins.add(BEGIN_PREFIX + type);
            ends.add(END_PREFIX + type);
        }
        this.beginClasses = Collections.unmodifiableSortedSet(begins);
        this.endClasses = Collections.unmodifiableSortedSet(ends);
        this.allClasses = new SortedSet[] {this.beginClasses, this.endClasses};
    }

    /**
     * Returns the classes of both classifiers; identical to {@link #allClasses()}.
     * NOTE(review): the internal array itself is returned, so callers could replace its
     * elements -- callers are expected not to modify it.
     *
     * @return {@code {beginClasses, endClasses}}
     */
    @Override // de.fu_berlin.ties.combi.CombinationStrategy
    public Set[] activeClasses() {
        return this.allClasses;
    }

    /**
     * Returns the classes of both classifiers (begin classes first, end classes second).
     * NOTE(review): the internal array itself is returned, so callers could replace its
     * elements -- callers are expected not to modify it.
     *
     * @return {@code {beginClasses, endClasses}}
     */
    @Override // de.fu_berlin.ties.combi.CombinationStrategy
    public Set[] allClasses() {
        return this.allClasses;
    }

    /**
     * Hands the positive predictions collected since the last call over to the
     * re-extractor and starts collecting into fresh, empty maps.
     *
     * @return a context holding the begin map under
     * {@link BeginEndReextractor#CONTEXT_BEGIN_MAP} and the end map under
     * {@link BeginEndReextractor#CONTEXT_END_MAP}; {@code null} if re-extraction is disabled
     */
    @Override // de.fu_berlin.ties.combi.CombinationStrategy
    public ContextMap contextForReextractor() {
        if (!this.reextract) {
            return null;
        }
        final ContextMap context = new ContextMap();
        context.put(BeginEndReextractor.CONTEXT_BEGIN_MAP, this.beginMap);
        context.put(BeginEndReextractor.CONTEXT_END_MAP, this.endMap);
        // the handed-over maps now belong to the context; collect into new ones
        this.beginMap = new BeginEndReextractor.PositivePredictionsMap();
        this.endMap = new BeginEndReextractor.PositivePredictionsMap();
        return context;
    }

    /**
     * Creates the {@link BeginEndReextractor} if level-2 re-extraction is enabled.
     *
     * @param reestimator the last re-estimator of the chain (may be {@code null}); the
     * chain is searched via {@link Reestimator#getPrecedingReestimator()} for a
     * {@link LengthFilter}, which the re-extractor requires
     * @return the new re-extractor, or {@code null} if re-extraction is disabled
     * @throws ProcessingException if the re-extractor cannot be created
     * @throws IllegalArgumentException if re-extraction is enabled but the chain contains
     * no {@link LengthFilter}
     */
    @Override // de.fu_berlin.ties.combi.CombinationStrategy
    public FinalReextractor initReextractor(final Reestimator reestimator) throws ProcessingException {
        if (!this.reextract) {
            return null;
        }
        final LengthFilter lengthFilter = findLengthFilter(reestimator);
        if (lengthFilter == null) {
            throw new IllegalArgumentException("No LengthFilter found in re-estimator chain -- cannot init BeginEndReextractor");
        }
        return new BeginEndReextractor(getValidClasses(), this.config, lengthFilter);
    }

    /**
     * Returns the first {@link LengthFilter} found when walking the re-estimator chain
     * backwards from {@code first}, or {@code null} if there is none.
     */
    private static LengthFilter findLengthFilter(final Reestimator first) {
        Reestimator current = first;
        while (current != null) {
            if (current instanceof LengthFilter) {
                return (LengthFilter) current;
            }
            current = current.getPrecedingReestimator();
        }
        return null;
    }

    /**
     * Returns {@code true} iff the current state is not an end and has a type, i.e. an
     * extraction is still open.
     */
    @Override // de.fu_berlin.ties.combi.CombinationStrategy
    protected boolean resetHook() {
        return !state().isEnd() && state().getType() != null;
    }

    /**
     * Stores the positive predictions of one classifier for one token.
     *
     * <p>All predictions returned by the iterator before the first {@link #OTHER} are
     * stored, keyed by their type with {@code prefix} stripped, in iteration order.
     * The first prediction is included too -- it is the best one, which callers only
     * pass in when it is positive. (Presumably the iterator yields predictions in
     * descending order, so these are exactly the ones ranked above OTHER -- confirm.)
     *
     * @param predIter iterator over the classifier's predictions
     * @param predMap map to store the positive predictions in
     * @param prefix class prefix to strip ({@link #BEGIN_PREFIX} or {@link #END_PREFIX})
     * @param index the token index used as key in {@code predMap}
     * @throws IllegalStateException if no positive prediction precedes {@link #OTHER}
     */
    private void storePositivePredictions(final Iterator<Prediction> predIter,
            final BeginEndReextractor.PositivePredictionsMap predMap,
            final String prefix, final int index) throws IllegalStateException {
        final LinkedHashMap<String, Prediction> positivePreds =
            new LinkedHashMap<String, Prediction>();
        while (predIter.hasNext()) {
            final Prediction pred = predIter.next();
            final String predType = pred.getType();
            if (predType.equals(OTHER)) {
                break; // everything from here on is not a positive prediction
            }
            positivePreds.put(predType.substring(prefix.length()), pred);
        }
        if (positivePreds.isEmpty()) {
            throw new IllegalStateException("No positive predictions");
        }
        predMap.put(Integer.valueOf(index), positivePreds);
    }

    /**
     * Translates a state into the classes the two classifiers should predict for it.
     *
     * @param combinationState the state to translate
     * @return a two-element array: {@code "B-" + type} if the state is a begin, otherwise
     * {@link #OTHER}; then {@code "E-" + type} if the state is an end, otherwise {@link #OTHER}
     */
    @Override // de.fu_berlin.ties.combi.CombinationStrategy
    public String[] translateCurrentState(final CombinationState combinationState) throws IllegalArgumentException {
        final String beginClass = combinationState.isBegin()
            ? BEGIN_PREFIX + combinationState.getType() : OTHER;
        final String endClass = combinationState.isEnd()
            ? END_PREFIX + combinationState.getType() : OTHER;
        return new String[] {beginClass, endClass};
    }

    /**
     * Combines the best predictions of the begin and the end classifier into a new
     * state, taking the previous state ({@code state()}) into account.
     *
     * <ul>
     * <li>Begin and end of the same type: a state of that type that begins and ends
     * here; its probability combines both predictions via {@link Prediction#addProb}.</li>
     * <li>Begin, plus an end that does not close the currently open extraction: the
     * begin type wins, using the begin probability.</li>
     * <li>Begin of another type, plus an end that would close the currently open
     * extraction: the prediction ranked higher by {@link PredictionComparator} wins
     * (ties go to the end); the loser's flag is cleared.</li>
     * <li>Only begin: a begin of that type.</li>
     * <li>Only end: accepted if it closes the currently open extraction, otherwise
     * ignored.</li>
     * <li>Neither (or ignored end): the open extraction continues; if none is open,
     * {@link CombinationState#OUTSIDE} is returned.</li>
     * </ul>
     *
     * @param predictionDistributionArr the distributions of the begin (index 0) and end
     * (index 1) classifiers
     * @param tokenDetails details of the current token (not used here)
     * @return the new state
     * @throws IllegalArgumentException if not exactly two distributions are given or if
     * the derived type is not among {@code getValidClasses()}
     */
    @Override // de.fu_berlin.ties.combi.CombinationStrategy
    public CombinationState translateResult(final PredictionDistribution[] predictionDistributionArr,
            final TokenDetails tokenDetails) throws IllegalArgumentException {
        if (predictionDistributionArr.length != 2) {
            throw new IllegalArgumentException("Illegal number of classifiers: " + predictionDistributionArr.length + " instead of 2");
        }
        final Prediction beginPred = predictionDistributionArr[0].best();
        final Prediction endPred = predictionDistributionArr[1].best();
        boolean isBegin = !beginPred.getType().equals(OTHER);
        boolean isEnd = !endPred.getType().equals(OTHER);
        // type of the previous state (null if none) and whether it closed its extraction
        final String prevType = state().getType();
        final boolean prevWasEnd = state().isEnd();

        final String newType;
        final Probability newProb;

        if (isBegin) {
            final String beginType = beginPred.getType().substring(BEGIN_PREFIX.length());
            if (!isEnd) {
                newType = beginType;
                newProb = beginPred.getProbability();
            } else {
                final String endType = endPred.getType().substring(END_PREFIX.length());
                if (beginType.equals(endType)) {
                    // one-token extraction: combine the probabilities of both predictions
                    newType = beginType;
                    final Prediction combined = new Prediction(null, beginPred.getProbability());
                    combined.addProb(endPred.getProbability(), true);
                    newProb = combined.getProbability();
                } else if (!endType.equals(prevType) || prevWasEnd) {
                    // The end does not close the open extraction, so the begin wins.
                    // NOTE(review): the end flag stays set here, whereas the end-only
                    // branch below discards such an end -- confirm this is intended.
                    newType = beginType;
                    newProb = beginPred.getProbability();
                } else if (new PredictionComparator().compare(beginPred, endPred) <= 0) {
                    // conflict: closing the open extraction vs. starting another type
                    newType = endType;
                    isBegin = false;
                    Util.LOG.debug("BeginEnd: end classification " + endPred + " wins over begin classification " + beginPred + " due to higher/equal confidence");
                    newProb = endPred.getProbability();
                } else {
                    newType = beginType;
                    isEnd = false;
                    Util.LOG.debug("BeginEnd: begin classification " + beginPred + " wins over end classification " + endPred + " due to higher confidence");
                    newProb = beginPred.getProbability();
                }
            }
        } else if (isEnd) {
            final String endType = endPred.getType().substring(END_PREFIX.length());
            if (endType.equals(prevType) && !prevWasEnd) {
                // the end closes the currently open extraction
                newType = endType;
                newProb = endPred.getProbability();
            } else {
                // nothing of this type is open: ignore the end prediction
                isEnd = false;
                newType = prevWasEnd ? null : prevType;
                newProb = null;
            }
        } else {
            // no positive prediction: continue the open extraction (if any)
            newType = prevWasEnd ? null : prevType;
            newProb = null;
        }

        // a begin of a different type while an extraction of another type is still open
        final boolean beginInterruptsOpen = isBegin && !prevWasEnd && prevType != null
            && !prevType.equals(newType);

        if (newType == null) {
            return CombinationState.OUTSIDE;
        }
        if (!getValidClasses().contains(newType)) {
            throw new IllegalArgumentException("Type " + newType + " derived from predictions " + beginPred + " and " + endPred + " is invalid");
        }
        return new CombinationState(newType, isBegin, isEnd, newProb, beginInterruptsOpen);
    }

    /**
     * Delegates to the superclass. If re-extraction is enabled, it also records the
     * positive predictions of each classifier whose best prediction is positive, keyed
     * by {@code tokenDetails.getIndex()}.
     *
     * @param combinationState the new state
     * @param predictionDistributionArr the distributions of the begin (index 0) and end
     * (index 1) classifiers; {@code null} entries are skipped when recording
     * @param tokenDetails details of the current token
     * @throws IllegalArgumentException as thrown by the superclass
     */
    @Override // de.fu_berlin.ties.combi.CombinationStrategy
    public void updateState(final CombinationState combinationState,
            final PredictionDistribution[] predictionDistributionArr,
            final TokenDetails tokenDetails) throws IllegalArgumentException {
        super.updateState(combinationState, predictionDistributionArr, tokenDetails);
        if (!this.reextract) {
            return;
        }
        storeIfPositive(predictionDistributionArr[0], this.beginMap, BEGIN_PREFIX, tokenDetails);
        storeIfPositive(predictionDistributionArr[1], this.endMap, END_PREFIX, tokenDetails);
    }

    /**
     * Stores the positive predictions of {@code dist} in {@code predMap} if {@code dist}
     * is non-null and its best prediction is positive.
     */
    private void storeIfPositive(final PredictionDistribution dist,
            final BeginEndReextractor.PositivePredictionsMap predMap,
            final String prefix, final TokenDetails tokenDetails) {
        if (dist != null && !dist.best().getType().equals(OTHER)) {
            storePositivePredictions(dist.iterator(), predMap, prefix, tokenDetails.getIndex());
        }
    }
}
