package de.fu_berlin.ties.classify.winnow;

import de.fu_berlin.ties.ProcessingException;
import de.fu_berlin.ties.TiesConfiguration;
import de.fu_berlin.ties.classify.PredictionDistribution;
import de.fu_berlin.ties.classify.TrainableClassifier;
import de.fu_berlin.ties.classify.feature.Feature;
import de.fu_berlin.ties.classify.feature.FeatureSet;
import de.fu_berlin.ties.classify.feature.FeatureTransformer;
import de.fu_berlin.ties.classify.feature.FeatureVector;
import de.fu_berlin.ties.util.Util;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.TreeSet;
import org.apache.commons.collections.LRUMap;
import org.apache.commons.lang.builder.ToStringBuilder;

/* loaded from: input_file:de/fu_berlin/ties/classify/winnow/Winnow.class */
/**
 * Winnow classifier: a linear classifier that keeps one multiplicative weight
 * per (feature, class) pair and only supports error-driven training
 * ({@code trainOnError}); plain {@code train} is rejected.
 *
 * <p>Weights are stored in a size-bounded {@link LRUMap} keyed by the
 * feature's {@code hashCode()}, so features whose hash codes collide share
 * their weights. In <em>balanced</em> mode each feature carries a positive and
 * a negative weight per class and contributes their difference to the score.
 *
 * <p>Access to the weight store from {@link #doClassify} and
 * {@link #trainOnErrorHook} is synchronized on this instance.
 */
public class Winnow extends TrainableClassifier {

    /** Whether balanced Winnow (separate positive and negative weights) is used. */
    private final boolean balanced;

    /** Factor (&gt; 1) by which a weight is multiplied when promoted. */
    private final float promotion;

    /** Factor (in ]0, 1[) by which a weight is multiplied when demoted. */
    private final float demotion;

    /**
     * Relative thickness (in [0, 1[) of the band around the threshold used
     * to decide which classes to adjust during training.
     */
    private final float thresholdThickness;

    /**
     * Maps {@code Integer} feature hash codes to {@code float[]} weight
     * arrays: one weight per class in {@code getAllClasses()} order, followed
     * by the negative weights in the same order when balanced.
     */
    private final LRUMap store;

    /**
     * Creates an instance configured from {@link TiesConfiguration#CONF}.
     *
     * @param allValidClasses the set of all classes to distinguish
     * @throws IllegalArgumentException if a configured parameter is out of range
     */
    public Winnow(Set allValidClasses) throws IllegalArgumentException {
        this(allValidClasses, TiesConfiguration.CONF);
    }

    /**
     * Creates an instance whose feature transformer and parameters are read
     * from the given configuration.
     *
     * @param allValidClasses the set of all classes to distinguish
     * @param config the configuration to use
     * @throws IllegalArgumentException if a configured parameter is out of range
     */
    public Winnow(Set allValidClasses, TiesConfiguration config)
            throws IllegalArgumentException {
        this(allValidClasses, FeatureTransformer.createTransformer(config), config);
    }

    /**
     * Creates an instance using the given feature transformer and reading all
     * Winnow parameters ({@code classifier.winnow.*}) from the configuration.
     *
     * @param allValidClasses the set of all classes to distinguish
     * @param transformer the feature transformer passed to the superclass
     * @param config the configuration to read parameters from
     * @throws IllegalArgumentException if a configured parameter is out of range
     */
    public Winnow(Set allValidClasses, FeatureTransformer transformer,
            TiesConfiguration config) throws IllegalArgumentException {
        this(allValidClasses, transformer,
                config.getBoolean("classifier.winnow.balanced"),
                config.getFloat("classifier.winnow.promotion"),
                config.getFloat("classifier.winnow.demotion"),
                config.getFloat("classifier.winnow.threshold.thickness"),
                config.getInt("classifier.winnow.features"));
    }

    /**
     * Creates an instance with explicit parameters.
     *
     * @param allValidClasses the set of all classes to distinguish (copied
     *        into a {@link TreeSet})
     * @param transformer the feature transformer passed to the superclass
     * @param balanced whether to use balanced Winnow
     * @param promotion promotion factor, must be &gt; 1
     * @param demotion demotion factor, must be in ]0, 1[
     * @param thresholdThickness threshold thickness, must be in [0, 1[
     * @param featureNum maximum number of features kept in the LRU weight store
     * @throws IllegalArgumentException if a factor is out of range or NaN
     */
    public Winnow(Set allValidClasses, FeatureTransformer transformer,
            boolean balanced, float promotion, float demotion,
            float thresholdThickness, int featureNum)
            throws IllegalArgumentException {
        super(new TreeSet(allValidClasses), transformer);
        // Comparisons are negated so that NaN (for which every comparison is
        // false) is rejected instead of silently accepted.
        if (!(promotion > 1.0)) {
            throw new IllegalArgumentException(
                    "Promotion factor must be > 1: " + promotion);
        }
        if (!(demotion > 0.0 && demotion < 1.0)) {
            throw new IllegalArgumentException(
                    "Demotion factor must be in ]0, 1[ range:" + demotion);
        }
        if (!(thresholdThickness >= 0.0 && thresholdThickness < 1.0)) {
            throw new IllegalArgumentException(
                    "Threshold thickness must be in [0, 1[ range: " + thresholdThickness);
        }
        this.balanced = balanced;
        this.promotion = promotion;
        this.demotion = demotion;
        this.thresholdThickness = thresholdThickness;
        this.store = new LRUMap(featureNum);
    }

    /**
     * Promotes and/or demotes the weights of a single feature, creating its
     * weight array with {@link #initWeightArray()} if it is not yet stored.
     *
     * @param feature the feature whose weights to adjust
     * @param directions one entry per class in {@code getAllClasses()} order:
     *        negative demotes, positive promotes, zero leaves unchanged (in
     *        balanced mode the negative weight moves the opposite way)
     * @throws IllegalArgumentException if {@code directions} does not have
     *         exactly one entry per class
     */
    protected void adjustWeights(Feature feature, short[] directions) {
        final int size = getAllClasses().size();
        if (directions.length != size) {
            throw new IllegalArgumentException("Array of directions has "
                    + directions.length + " members instead of one for each of the "
                    + size + " classes");
        }
        final Integer key = storeKey(feature);
        float[] weights = (float[]) this.store.get(key);
        if (weights == null) {
            weights = initWeightArray();
            this.store.put(key, weights);
        }
        for (int i = 0; i < size; i++) {
            if (directions[i] < 0) {
                weights[i] *= this.demotion;
                if (this.balanced) {
                    weights[i + size] *= this.promotion;
                }
            } else if (directions[i] > 0) {
                weights[i] *= this.promotion;
                if (this.balanced) {
                    weights[i + size] *= this.demotion;
                }
            }
        }
    }

    /**
     * Decides which classes to adjust after a classification. The target
     * class is promoted if its raw score does not exceed the major threshold;
     * every other class is demoted if its raw score exceeds the minor
     * threshold.
     *
     * @param distribution the distribution returned by {@link #doClassify}
     * @param targetClass the correct class
     * @param classesToPromote receives the classes to promote
     * @param classesToDemote receives the classes to demote
     */
    protected void chooseClassesToAdjust(WinnowDistribution distribution,
            String targetClass, Set classesToPromote, Set classesToDemote) {
        final float minor = minorThreshold(distribution.getThreshold(),
                distribution.getRawThreshold());
        final float major = majorThreshold(distribution.getThreshold(),
                distribution.getRawThreshold());
        for (final Iterator iter = distribution.iterator(); iter.hasNext();) {
            final WinnowPrediction prediction = (WinnowPrediction) iter.next();
            if (targetClass.equals(prediction.getType())) {
                if (prediction.getRawScore() <= major) {
                    classesToPromote.add(prediction.getType());
                }
            } else if (prediction.getRawScore() > minor) {
                classesToDemote.add(prediction.getType());
            }
        }
    }

    /**
     * Converts a sigmoid value into a confidence by dividing it by the sum of
     * the sigmoid values of all predicted classes.
     *
     * @param sigmoid the sigmoid value of a class
     * @param sum the sum of sigmoid values
     * @return {@code sigmoid / sum}
     */
    protected double confidence(float sigmoid, float sum) {
        return sigmoid / sum;
    }

    /**
     * Returns the score contributed by a feature that has no stored weights:
     * 0 in balanced mode (its equal initial positive and negative weights
     * cancel out), {@link #initWeight()} otherwise.
     *
     * @return the default weight of an unseen feature
     */
    protected float defaultWeight() {
        return this.balanced ? 0.0f : initWeight();
    }

    /**
     * Classifies a feature vector. Scores are accumulated per class in
     * {@code getAllClasses()} order; one prediction is created for each
     * candidate class, with confidences normalized over the candidates.
     *
     * @param featureVector the features to classify
     * @param set the candidate classes; must be a subset of all classes
     * @return a {@link WinnowDistribution} of predictions for the candidates
     * @throws IllegalArgumentException if a candidate is not a known class
     */
    @Override // de.fu_berlin.ties.classify.TrainableClassifier
    protected PredictionDistribution doClassify(FeatureVector featureVector, Set set) {
        if (!getAllClasses().containsAll(set)) {
            throw new IllegalArgumentException("Candidate classes " + set
                    + " must be a subset of all classes " + getAllClasses());
        }
        final FeatureSet features = featureSet(featureVector);
        final float[] scores = initScores();
        synchronized (this) {
            for (final Iterator iter = features.iterator(); iter.hasNext();) {
                updateScores((Feature) iter.next(), scores);
            }
        }
        final float rawThreshold = rawThreshold(features);
        final float threshold = threshold(rawThreshold);

        // scores[] follows getAllClasses() order (see updateScores), so each
        // candidate must be looked up by its index in that order rather than
        // by the iteration order of the candidate set.
        final float[] sigmoids = new float[scores.length];
        float sigmoidSum = 0.0f;
        int index = 0;
        for (final Iterator iter = getAllClasses().iterator(); iter.hasNext(); index++) {
            if (set.contains(iter.next())) {
                sigmoids[index] = sigmoid(scores[index], threshold, rawThreshold);
                sigmoidSum += sigmoids[index];
            }
        }

        final WinnowDistribution result = new WinnowDistribution(threshold, rawThreshold);
        index = 0;
        for (final Iterator iter = getAllClasses().iterator(); iter.hasNext(); index++) {
            final String className = (String) iter.next();
            if (set.contains(className)) {
                result.add(new WinnowPrediction(className,
                        confidence(sigmoids[index], sigmoidSum),
                        scores[index], sigmoids[index]));
            }
        }
        return result;
    }

    /**
     * Always fails: Winnow supports only error-driven training.
     *
     * @param featureVector ignored
     * @param str ignored
     * @throws UnsupportedOperationException always
     */
    @Override // de.fu_berlin.ties.classify.TrainableClassifier
    protected void doTrain(FeatureVector featureVector, String str)
            throws UnsupportedOperationException {
        throw new UnsupportedOperationException(
                "Winnow supports only error-driven training -- call trainOnError instead of train");
    }

    /**
     * Returns the argument as a {@link FeatureSet}, copying it into a new set
     * if it is not one already.
     *
     * @param featureVector the features to convert
     * @return the features as a set
     */
    protected FeatureSet featureSet(FeatureVector featureVector) {
        if (featureVector instanceof FeatureSet) {
            return (FeatureSet) featureVector;
        }
        final FeatureSet result = new FeatureSet();
        result.addAll(featureVector);
        return result;
    }

    /** @return the demotion factor */
    public float getDemotion() {
        return this.demotion;
    }

    /** @return the promotion factor */
    public float getPromotion() {
        return this.promotion;
    }

    /** @return whether balanced Winnow is used */
    public boolean isBalanced() {
        return this.balanced;
    }

    /**
     * Creates a zero-initialized score array with one entry per class.
     *
     * @return a new score array
     */
    protected float[] initScores() {
        return new float[getAllClasses().size()];
    }

    /** @return the threshold thickness */
    public float getThresholdThickness() {
        return this.thresholdThickness;
    }

    /**
     * Returns the initial value of every newly created weight.
     *
     * @return 1
     */
    protected float initWeight() {
        return 1.0f;
    }

    /**
     * Creates the weight array for a newly seen feature: one weight per class,
     * doubled in balanced mode, each set to {@link #initWeight()}.
     *
     * @return a new weight array
     */
    protected float[] initWeightArray() {
        final int classCount = getAllClasses().size();
        final float[] weights = new float[this.balanced ? classCount * 2 : classCount];
        Arrays.fill(weights, initWeight());
        return weights;
    }

    /**
     * Returns the upper edge of the thick threshold.
     *
     * @param threshold the threshold
     * @param rawThreshold the raw threshold
     * @return {@code threshold + thickness * rawThreshold}
     */
    protected float majorThreshold(float threshold, float rawThreshold) {
        return threshold + (getThresholdThickness() * rawThreshold);
    }

    /**
     * Returns the lower edge of the thick threshold. (Public in this build;
     * the decompiler notes its access was widened from protected.)
     *
     * @param threshold the threshold
     * @param rawThreshold the raw threshold
     * @return {@code threshold - thickness * rawThreshold}
     */
    public float minorThreshold(float threshold, float rawThreshold) {
        return threshold - (getThresholdThickness() * rawThreshold);
    }

    /**
     * Returns the raw threshold: the number of features in the set.
     *
     * @param featureSet the features being classified
     * @return the feature count as a float
     */
    protected float rawThreshold(FeatureSet featureSet) {
        return featureSet.size();
    }

    /**
     * Maps a score to ]0, 1[. Balanced mode uses the logistic function of
     * {@code (score - threshold) / rawThreshold}; normal mode uses
     * {@code 1 / (1 + threshold / score)}.
     * NOTE(review): in balanced mode a raw threshold of 0 (empty feature set)
     * divides by zero — confirm whether callers can pass empty vectors.
     *
     * @param score the activation value of a class
     * @param threshold the threshold
     * @param rawThreshold the raw threshold
     * @return the sigmoid value
     * @throws IllegalArgumentException if {@code score <= 0} in normal mode
     */
    protected float sigmoid(float score, float threshold, float rawThreshold)
            throws IllegalArgumentException {
        if (this.balanced) {
            return (float) (1.0d / (1.0d + Math.exp((threshold - score) / rawThreshold)));
        }
        if (score <= 0.0f) {
            throw new IllegalArgumentException(
                    "Activation value must be positive in normal Winnow: " + score);
        }
        return 1.0f / (1.0f + (threshold / score));
    }

    /**
     * Returns the threshold: raw threshold times {@link #defaultWeight()}.
     *
     * @param rawThreshold the raw threshold
     * @return the threshold
     */
    protected float threshold(float rawThreshold) {
        return rawThreshold * defaultWeight();
    }

    /**
     * Adjusts weights after a classification: chooses classes via
     * {@link #chooseClassesToAdjust} and updates the weights of every feature
     * accordingly (demotion takes precedence if a class is in both sets).
     *
     * @param predictionDistribution the distribution from {@link #doClassify};
     *        must be a {@link WinnowDistribution}
     * @param featureVector the classified features
     * @param targetClass the correct class
     * @param candidateClasses the candidate classes (not used here)
     * @return always {@code true}
     * @throws ProcessingException declared by the superclass hook
     */
    @Override // de.fu_berlin.ties.classify.TrainableClassifier
    protected boolean trainOnErrorHook(PredictionDistribution predictionDistribution,
            FeatureVector featureVector, String targetClass, Set candidateClasses)
            throws ProcessingException {
        final FeatureSet features = featureSet(featureVector);
        final Set toPromote = new HashSet();
        final Set toDemote = new HashSet();
        chooseClassesToAdjust((WinnowDistribution) predictionDistribution,
                targetClass, toPromote, toDemote);
        if (toPromote.isEmpty() && toDemote.isEmpty()) {
            return true;
        }
        Util.LOG.debug("Promoting classes: " + toPromote + "; demoting classes: " + toDemote);

        // one direction per class, in the same order as the weight arrays
        final short[] directions = new short[getAllClasses().size()];
        int index = 0;
        for (final Iterator iter = getAllClasses().iterator(); iter.hasNext(); index++) {
            final Object className = iter.next();
            if (toDemote.contains(className)) {
                directions[index] = -1;
            } else if (toPromote.contains(className)) {
                directions[index] = 1;
            }
            // otherwise stays 0 (array default): leave unchanged
        }
        synchronized (this) {
            for (final Iterator iter = features.iterator(); iter.hasNext();) {
                adjustWeights((Feature) iter.next(), directions);
            }
        }
        return true;
    }

    /**
     * Returns a string representation including the parameters and the
     * number of stored features.
     *
     * @return a string representation of this classifier
     */
    @Override // de.fu_berlin.ties.classify.TrainableClassifier
    public String toString() {
        return new ToStringBuilder(this)
                .appendSuper(super.toString())
                .append("balanced", this.balanced)
                .append("promotion", this.promotion)
                .append("demotion", this.demotion)
                .append("threshold thickness", this.thresholdThickness)
                .append("no. of features", this.store.size())
                .toString();
    }

    /**
     * Adds the contribution of one feature to the per-class scores: its stored
     * weights (minus the negative weights in balanced mode), or
     * {@link #defaultWeight()} for every class if it has no stored weights.
     *
     * @param feature the feature to add
     * @param scores per-class scores in {@code getAllClasses()} order, updated
     *        in place
     * @throws IllegalArgumentException if {@code scores} does not have exactly
     *         one entry per class
     */
    protected void updateScores(Feature feature, float[] scores) {
        final int size = getAllClasses().size();
        if (scores.length != size) {
            throw new IllegalArgumentException("Array of scores has "
                    + scores.length + " members instead of one for each of the "
                    + size + " classes");
        }
        final float[] weights = (float[]) this.store.get(storeKey(feature));
        if (weights == null) {
            final float defaultWeight = defaultWeight();
            if (defaultWeight != 0.0f) {
                for (int i = 0; i < size; i++) {
                    scores[i] += defaultWeight;
                }
            }
            return;
        }
        for (int i = 0; i < size; i++) {
            scores[i] += weights[i];
            if (this.balanced) {
                scores[i] -= weights[i + size];
            }
        }
    }

    /**
     * Returns the weight-store key of a feature (its hash code).
     *
     * @param feature the feature
     * @return the boxed hash code
     */
    private static Integer storeKey(Feature feature) {
        return Integer.valueOf(feature.hashCode());
    }
}
