/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.classifiers.functions.Logistic;
import weka.core.Attribute;
import weka.core.AttributeStats;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;

/**
 * Meta-classifier that wraps a base classifier ({@link Logistic} by default) for binary
 * class problems and selects a probability threshold for a designated class value so that
 * a chosen performance measure (see {@link #TAGS_MEASURE}) is maximized. The predicted
 * probability of the designated class is then remapped so that the selected threshold
 * corresponds to 0.5. Optionally ({@link #RANGE_BOUNDS}) the range of thresholds observed
 * during selection is stretched to [0, 1].
 */
public class ThresholdSelector
extends RandomizableSingleClassifierEnhancer
implements OptionHandler,
Drawable {
    static final long serialVersionUID = -1795038053239867444L;

    /** Range correction: leave the base classifier's probability range untouched. */
    public static final int RANGE_NONE = 0;
    /** Range correction: map the min/max threshold seen during selection onto 0/1. */
    public static final int RANGE_BOUNDS = 1;
    public static final Tag[] TAGS_RANGE = new Tag[]{new Tag(RANGE_NONE, "No range correction"), new Tag(RANGE_BOUNDS, "Correct based on min/max observed")};

    /** Evaluation mode: select the threshold on the full training set. */
    public static final int EVAL_TRAINING_SET = 2;
    /** Evaluation mode: select the threshold on a single hold-out fold. */
    public static final int EVAL_TUNED_SPLIT = 1;
    /** Evaluation mode: select the threshold on cross-validated predictions. */
    public static final int EVAL_CROSS_VALIDATION = 0;
    public static final Tag[] TAGS_EVAL = new Tag[]{new Tag(EVAL_TRAINING_SET, "Entire training set"), new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"), new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")};

    /** Designated class: first class value. */
    public static final int OPTIMIZE_0 = 0;
    /** Designated class: second class value. */
    public static final int OPTIMIZE_1 = 1;
    /** Designated class: least frequent class value. */
    public static final int OPTIMIZE_LFREQ = 2;
    /** Designated class: most frequent class value. */
    public static final int OPTIMIZE_MFREQ = 3;
    /** Designated class: first value named like "yes"/"pos..."/"1", else least frequent. */
    public static final int OPTIMIZE_POS_NAME = 4;
    public static final Tag[] TAGS_OPTIMIZE = new Tag[]{new Tag(OPTIMIZE_0, "First class value"), new Tag(OPTIMIZE_1, "Second class value"), new Tag(OPTIMIZE_LFREQ, "Least frequent class value"), new Tag(OPTIMIZE_MFREQ, "Most frequent class value"), new Tag(OPTIMIZE_POS_NAME, "Class value named: \"yes\", \"pos(itive)\",\"1\"")};

    /** Measures that can be optimized; each maps to a column of the threshold curve. */
    public static final int FMEASURE = 1;
    public static final int ACCURACY = 2;
    public static final int TRUE_POS = 3;
    public static final int TRUE_NEG = 4;
    public static final int TP_RATE = 5;
    public static final int PRECISION = 6;
    public static final int RECALL = 7;
    public static final Tag[] TAGS_MEASURE = new Tag[]{new Tag(FMEASURE, "FMEASURE"), new Tag(ACCURACY, "ACCURACY"), new Tag(TRUE_POS, "TRUE_POS"), new Tag(TRUE_NEG, "TRUE_NEG"), new Tag(TP_RATE, "TP_RATE"), new Tag(PRECISION, "PRECISION"), new Tag(RECALL, "RECALL")};

    /** Upper end of the probability range used when remapping predictions. */
    protected double m_HighThreshold = 1.0;
    /** Lower end of the probability range used when remapping predictions. */
    protected double m_LowThreshold = 0.0;
    /** Selected (or manual) threshold; maps to 0.5 in the output distribution. */
    protected double m_BestThreshold = -Double.MAX_VALUE;
    /** Best measure value found; -Double.MAX_VALUE means "no model built yet". */
    protected double m_BestValue = -Double.MAX_VALUE;
    protected int m_NumXValFolds = 3;
    /** Index of the class value whose probability is thresholded. */
    protected int m_DesignatedClass = 0;
    /** One of the OPTIMIZE_* constants. */
    protected int m_ClassMode = OPTIMIZE_POS_NAME;
    /** One of the EVAL_* constants. */
    protected int m_EvalMode = EVAL_TUNED_SPLIT;
    /** One of the RANGE_* constants. */
    protected int m_RangeMode = RANGE_NONE;
    /** One of the measure constants (FMEASURE .. RECALL). */
    int m_nMeasure = FMEASURE;
    protected boolean m_manualThreshold = false;
    protected double m_manualThresholdValue = -1.0;
    /** A curve optimum must exceed this value to replace the default threshold. */
    protected static final double MIN_VALUE = 0.05;

    /** Creates a selector wrapping a default {@link Logistic} classifier. */
    public ThresholdSelector() {
        this.m_Classifier = new Logistic();
    }

    protected String defaultClassifierString() {
        return "weka.classifiers.functions.Logistic";
    }

    /**
     * Collects predictions of the base classifier, which are used to build the threshold curve.
     *
     * <ul>
     *   <li>{@link #EVAL_TUNED_SPLIT}: shuffles (seeded with m_Seed) and stratifies a copy of the
     *       data. Predictions come from the first fold whose training and evaluation parts both
     *       contain the designated class. If no fold qualifies, the last fold is used.</li>
     *   <li>{@link #EVAL_TRAINING_SET}: trains and evaluates on the same instances.</li>
     *   <li>{@link #EVAL_CROSS_VALIDATION}: delegates to EvaluationUtils cross-validation.</li>
     * </ul>
     *
     * @param instances the training data
     * @param mode one of the EVAL_* constants
     * @param numFolds number of folds (ignored for {@link #EVAL_TRAINING_SET})
     * @return the collected predictions
     * @throws Exception if the base classifier fails
     * @throws RuntimeException if {@code mode} is not a known evaluation mode
     */
    protected FastVector getPredictions(Instances instances, int mode, int numFolds) throws Exception {
        EvaluationUtils eu = new EvaluationUtils();
        eu.setSeed(this.m_Seed);
        switch (mode) {
            case EVAL_TUNED_SPLIT: {
                Instances trainData = null;
                Instances evalData = null;
                Instances data = new Instances(instances);
                Random random = new Random(this.m_Seed);
                data.randomize(random);
                data.stratify(numFolds);
                for (int subsetIndex = 0; subsetIndex < numFolds; ++subsetIndex) {
                    trainData = data.trainCV(numFolds, subsetIndex, random);
                    evalData = data.testCV(numFolds, subsetIndex);
                    if (this.checkForInstance(trainData) && this.checkForInstance(evalData)) {
                        break;
                    }
                }
                return eu.getTrainTestPredictions(this.m_Classifier, trainData, evalData);
            }
            case EVAL_TRAINING_SET: {
                return eu.getTrainTestPredictions(this.m_Classifier, instances, instances);
            }
            case EVAL_CROSS_VALIDATION: {
                return eu.getCVPredictions(this.m_Classifier, instances, numFolds);
            }
            default: {
                throw new RuntimeException("Unrecognized evaluation mode");
            }
        }
    }

    public String measureTipText() {
        return "Sets the measure for determining the threshold.";
    }

    /** Sets the measure to optimize; tags from any other tag set are ignored. */
    public void setMeasure(SelectedTag newMeasure) {
        if (newMeasure.getTags() == TAGS_MEASURE) {
            this.m_nMeasure = newMeasure.getSelectedTag().getID();
        }
    }

    public SelectedTag getMeasure() {
        return new SelectedTag(this.m_nMeasure, TAGS_MEASURE);
    }

    /**
     * Builds the threshold curve for the designated class and records the threshold of the
     * curve point that maximizes the selected measure (first point wins ties). The result is
     * stored only if the optimum exceeds {@link #MIN_VALUE}. With {@link #RANGE_BOUNDS}, the
     * minimum and maximum thresholds over ALL curve points become the remapping range.
     *
     * @param predictions predictions produced by {@link #getPredictions}
     * @throws IllegalStateException if the configured measure id is unknown
     */
    protected void findThreshold(FastVector predictions) {
        Instances curve = new ThresholdCurve().getCurve(predictions, this.m_DesignatedClass);
        if (curve.numInstances() == 0) {
            return;
        }
        // index2 is only used for ACCURACY (TP + TN); -1 means "single column measure".
        int index1;
        int index2 = -1;
        switch (this.m_nMeasure) {
            case FMEASURE: {
                index1 = curve.attribute("FMeasure").index();
                break;
            }
            case ACCURACY: {
                index1 = curve.attribute("True Positives").index();
                index2 = curve.attribute("True Negatives").index();
                break;
            }
            case TRUE_POS: {
                index1 = curve.attribute("True Positives").index();
                break;
            }
            case TRUE_NEG: {
                index1 = curve.attribute("True Negatives").index();
                break;
            }
            case TP_RATE: {
                index1 = curve.attribute("True Positive Rate").index();
                break;
            }
            case PRECISION: {
                index1 = curve.attribute("Precision").index();
                break;
            }
            case RECALL: {
                index1 = curve.attribute("Recall").index();
                break;
            }
            default: {
                throw new IllegalStateException("Unrecognized measure: " + this.m_nMeasure);
            }
        }
        int indexThreshold = curve.attribute("Threshold").index();
        Instance maxInst = curve.instance(0);
        double maxValue = measureValue(maxInst, index1, index2);
        double low = 1.0;
        double high = 0.0;
        for (int i = 0; i < curve.numInstances(); ++i) {
            Instance current = curve.instance(i);
            if (i > 0) {
                double currentValue = measureValue(current, index1, index2);
                if (currentValue > maxValue) {
                    maxInst = current;
                    maxValue = currentValue;
                }
            }
            // Every curve point (including the first) contributes to the range, so the
            // selected threshold always lies within [low, high].
            if (this.m_RangeMode == RANGE_BOUNDS) {
                double thresh = current.value(indexThreshold);
                low = Math.min(low, thresh);
                high = Math.max(high, thresh);
            }
        }
        if (maxValue > MIN_VALUE) {
            this.m_BestThreshold = maxInst.value(indexThreshold);
            this.m_BestValue = maxValue;
        }
        if (this.m_RangeMode == RANGE_BOUNDS) {
            this.m_LowThreshold = low;
            this.m_HighThreshold = high;
        }
    }

    /** Value of the optimized measure at one curve point (sum of two columns for ACCURACY). */
    private static double measureValue(Instance point, int index1, int index2) {
        return index2 >= 0 ? point.value(index1) + point.value(index2) : point.value(index1);
    }

    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>(5);
        newVector.addElement(new Option("\tThe class for which threshold is determined. Valid values are:\n\t1, 2 (for first and second classes, respectively), 3 (for whichever\n\tclass is least frequent), and 4 (for whichever class value is most\n\tfrequent), and 5 (for the first class named any of \"yes\",\"pos(itive)\"\n\t\"1\", or method 3 if no matches). (default 5).", "C", 1, "-C <integer>"));
        newVector.addElement(new Option("\tNumber of folds used for cross validation. If just a\n\thold-out set is used, this determines the size of the hold-out set\n\t(default 3).", "X", 1, "-X <number of folds>"));
        newVector.addElement(new Option("\tSets whether confidence range correction is applied. This\n\tcan be used to ensure the confidences range from 0 to 1.\n\tUse 0 for no range correction, 1 for correction based on\n\tthe min/max values seen during threshold selection\n\t(default 0).", "R", 1, "-R <integer>"));
        newVector.addElement(new Option("\tSets the evaluation mode. Use 0 for\n\tevaluation using cross-validation,\n\t1 for evaluation using hold-out set,\n\tand 2 for evaluation on the\n\ttraining data (default 1).", "E", 1, "-E <integer>"));
        newVector.addElement(new Option("\tMeasure used for evaluation (default is FMEASURE).\n", "M", 1, "-M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL]"));
        newVector.addElement(new Option("\tSet a manual threshold to use. This option overrides\n\tautomatic selection and options pertaining to\n\tautomatic selection will be ignored.\n\t(default -1, i.e. do not use a manual threshold).", "manual", 1, "-manual <real>"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement((Option)enu.nextElement());
        }
        return newVector.elements();
    }

    /**
     * Parses options; every option that is absent falls back to its default. This includes
     * {@code -manual}, whose default (-1) disables the manual threshold.
     *
     * @param options the option array (consumed options are removed by Utils.getOption)
     * @throws Exception if an option value is malformed or out of range
     */
    public void setOptions(String[] options) throws Exception {
        // Reset to "no manual threshold" when -manual is missing, so a previously configured
        // value does not silently survive a fresh setOptions call. Negative values disable it.
        String manualS = Utils.getOption("manual", options);
        this.setManualThresholdValue(manualS.length() > 0 ? Double.parseDouble(manualS) : -1.0);
        String classString = Utils.getOption('C', options);
        if (classString.length() != 0) {
            // -C is 1-based on the command line, tag ids are 0-based.
            this.setDesignatedClass(new SelectedTag(Integer.parseInt(classString) - 1, TAGS_OPTIMIZE));
        } else {
            this.setDesignatedClass(new SelectedTag(OPTIMIZE_POS_NAME, TAGS_OPTIMIZE));
        }
        String modeString = Utils.getOption('E', options);
        if (modeString.length() != 0) {
            this.setEvaluationMode(new SelectedTag(Integer.parseInt(modeString), TAGS_EVAL));
        } else {
            this.setEvaluationMode(new SelectedTag(EVAL_TUNED_SPLIT, TAGS_EVAL));
        }
        String rangeString = Utils.getOption('R', options);
        if (rangeString.length() != 0) {
            this.setRangeCorrection(new SelectedTag(Integer.parseInt(rangeString), TAGS_RANGE));
        } else {
            this.setRangeCorrection(new SelectedTag(RANGE_NONE, TAGS_RANGE));
        }
        String measureString = Utils.getOption('M', options);
        if (measureString.length() != 0) {
            this.setMeasure(new SelectedTag(measureString, TAGS_MEASURE));
        } else {
            this.setMeasure(new SelectedTag(FMEASURE, TAGS_MEASURE));
        }
        String foldsString = Utils.getOption('X', options);
        if (foldsString.length() != 0) {
            this.setNumXValFolds(Integer.parseInt(foldsString));
        } else {
            this.setNumXValFolds(3);
        }
        super.setOptions(options);
    }

    /** Returns the current options, padded with empty strings to a fixed length. */
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[superOptions.length + 12];
        int current = 0;
        if (this.m_manualThreshold) {
            options[current++] = "-manual";
            options[current++] = "" + this.getManualThresholdValue();
        }
        options[current++] = "-C";
        options[current++] = "" + (this.m_ClassMode + 1);
        options[current++] = "-X";
        options[current++] = "" + this.getNumXValFolds();
        options[current++] = "-E";
        options[current++] = "" + this.m_EvalMode;
        options[current++] = "-R";
        options[current++] = "" + this.m_RangeMode;
        options[current++] = "-M";
        options[current++] = "" + this.getMeasure().getSelectedTag().getReadable();
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        current += superOptions.length;
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /** Base classifier capabilities, restricted to binary class attributes. */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.disableAllClassDependencies();
        result.enable(Capabilities.Capability.BINARY_CLASS);
        return result;
    }

    /**
     * Chooses the designated class, selects a threshold (unless a manual one is set), and
     * trains the base classifier on all instances with a known class.
     *
     * @param instances the training data (not modified; a copy is used)
     * @throws Exception if capabilities are not met or the class selection mode is unknown
     */
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);
        instances = new Instances(instances);
        instances.deleteWithMissingClass();
        AttributeStats stats = instances.attributeStats(instances.classIndex());
        this.m_BestThreshold = this.m_manualThreshold ? this.m_manualThresholdValue : 0.5;
        this.m_BestValue = MIN_VALUE;
        this.m_HighThreshold = 1.0;
        this.m_LowThreshold = 0.0;
        if (stats.distinctCount != 2) {
            System.err.println("Couldn't find examples of both classes. No adjustment.");
            this.m_Classifier.buildClassifier(instances);
            return;
        }
        switch (this.m_ClassMode) {
            case OPTIMIZE_0: {
                this.m_DesignatedClass = 0;
                break;
            }
            case OPTIMIZE_1: {
                this.m_DesignatedClass = 1;
                break;
            }
            case OPTIMIZE_POS_NAME: {
                Attribute cAtt = instances.classAttribute();
                boolean found = false;
                for (int i = 0; i < cAtt.numValues() && !found; ++i) {
                    String name = cAtt.value(i).toLowerCase();
                    if (name.startsWith("yes") || name.equals("1") || name.startsWith("pos")) {
                        found = true;
                        this.m_DesignatedClass = i;
                    }
                }
                if (found) {
                    break;
                }
                // Intentional fall-through: no "positive"-looking name, use least frequent.
            }
            case OPTIMIZE_LFREQ: {
                this.m_DesignatedClass = stats.nominalCounts[0] > stats.nominalCounts[1] ? 1 : 0;
                break;
            }
            case OPTIMIZE_MFREQ: {
                this.m_DesignatedClass = stats.nominalCounts[0] > stats.nominalCounts[1] ? 0 : 1;
                break;
            }
            default: {
                throw new Exception("Unrecognized class value selection mode");
            }
        }
        if (this.m_manualThreshold) {
            this.m_Classifier.buildClassifier(instances);
            return;
        }
        if (stats.nominalCounts[this.m_DesignatedClass] == 1) {
            System.err.println("Only 1 positive found: optimizing on training data");
            this.findThreshold(this.getPredictions(instances, EVAL_TRAINING_SET, 0));
        } else {
            // Cannot have more folds than designated-class examples.
            int numFolds = Math.min(this.m_NumXValFolds, stats.nominalCounts[this.m_DesignatedClass]);
            this.findThreshold(this.getPredictions(instances, this.m_EvalMode, numFolds));
            // Training-set evaluation already trained m_Classifier on all instances
            // (presumably inside EvaluationUtils.getTrainTestPredictions); other modes did not.
            if (this.m_EvalMode != EVAL_TRAINING_SET) {
                this.m_Classifier.buildClassifier(instances);
            }
        }
    }

    /** Returns true if {@code data} contains at least one instance of the designated class. */
    private boolean checkForInstance(Instances data) throws Exception {
        for (int i = 0; i < data.numInstances(); ++i) {
            if ((int)data.instance(i).classValue() == this.m_DesignatedClass) {
                return true;
            }
        }
        return false;
    }

    /**
     * Remaps the designated class probability piecewise-linearly.
     * [m_LowThreshold, m_BestThreshold] maps to [0, 0.5], and (m_BestThreshold, m_HighThreshold]
     * maps to (0.5, 1]. The result is clipped to [0, 1]. Degenerate (zero or negative) spans
     * are handled explicitly instead of dividing by them, which previously could yield NaN.
     * For a binary class the other entry becomes 1 - p.
     *
     * @param instance the instance to classify
     * @return the adjusted class distribution
     * @throws Exception if the base classifier fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] pred = this.m_Classifier.distributionForInstance(instance);
        double prob = pred[this.m_DesignatedClass];
        if (prob > this.m_BestThreshold) {
            double span = this.m_HighThreshold - this.m_BestThreshold;
            // No room above the threshold: anything above it is as positive as possible.
            prob = span > 0.0 ? 0.5 + (prob - this.m_BestThreshold) / (span * 2.0) : 1.0;
        } else {
            double span = this.m_BestThreshold - this.m_LowThreshold;
            if (span > 0.0) {
                prob = (prob - this.m_LowThreshold) / (span * 2.0);
            } else {
                // No room below the threshold: exactly at it is the 0.5 boundary, below it is 0.
                prob = prob < this.m_BestThreshold ? 0.0 : 0.5;
            }
        }
        if (prob < 0.0) {
            prob = 0.0;
        } else if (prob > 1.0) {
            prob = 1.0;
        }
        pred[this.m_DesignatedClass] = prob;
        if (pred.length == 2) {
            pred[(this.m_DesignatedClass + 1) % 2] = 1.0 - prob;
        }
        return pred;
    }

    public String globalInfo() {
        return "A metaclassifier that selecting a mid-point threshold on the probability output by a Classifier. The midpoint threshold is set so that a given performance measure is optimized. Currently this is the F-measure. Performance is measured either on the training data, a hold-out set or using cross-validation. In addition, the probabilities returned by the base learner can have their range expanded so that the output probabilities will reside between 0 and 1 (this is useful if the scheme normally produces probabilities in a very narrow range).";
    }

    public String designatedClassTipText() {
        return "Sets the class value for which the optimization is performed. The options are: pick the first class value; pick the second class value; pick whichever class is least frequent; pick whichever class value is most frequent; pick the first class named any of \"yes\",\"pos(itive)\", \"1\", or the least frequent if no matches).";
    }

    public SelectedTag getDesignatedClass() {
        return new SelectedTag(this.m_ClassMode, TAGS_OPTIMIZE);
    }

    /** Sets the designated class selection mode; tags from other tag sets are ignored. */
    public void setDesignatedClass(SelectedTag newMethod) {
        if (newMethod.getTags() == TAGS_OPTIMIZE) {
            this.m_ClassMode = newMethod.getSelectedTag().getID();
        }
    }

    public String evaluationModeTipText() {
        return "Sets the method used to determine the threshold/performance curve. The options are: perform optimization based on the entire training set (may result in overfitting); perform an n-fold cross-validation (may be time consuming); perform one fold of an n-fold cross-validation (faster but likely less accurate).";
    }

    /** Sets the evaluation mode; tags from other tag sets are ignored. */
    public void setEvaluationMode(SelectedTag newMethod) {
        if (newMethod.getTags() == TAGS_EVAL) {
            this.m_EvalMode = newMethod.getSelectedTag().getID();
        }
    }

    public SelectedTag getEvaluationMode() {
        return new SelectedTag(this.m_EvalMode, TAGS_EVAL);
    }

    public String rangeCorrectionTipText() {
        return "Sets the type of prediction range correction performed. The options are: do not do any range correction; expand predicted probabilities so that the minimum probability observed during the optimization maps to 0, and the maximum maps to 1 (values outside this range are clipped to 0 and 1).";
    }

    /** Sets the range correction mode; tags from other tag sets are ignored. */
    public void setRangeCorrection(SelectedTag newMethod) {
        if (newMethod.getTags() == TAGS_RANGE) {
            this.m_RangeMode = newMethod.getSelectedTag().getID();
        }
    }

    public SelectedTag getRangeCorrection() {
        return new SelectedTag(this.m_RangeMode, TAGS_RANGE);
    }

    public String numXValFoldsTipText() {
        return "Sets the number of folds used during full cross-validation and tuned fold evaluation. This number will be automatically reduced if there are insufficient positive examples.";
    }

    public int getNumXValFolds() {
        return this.m_NumXValFolds;
    }

    /**
     * @param newNumFolds number of folds, at least 2
     * @throws IllegalArgumentException if {@code newNumFolds < 2}
     */
    public void setNumXValFolds(int newNumFolds) {
        if (newNumFolds < 2) {
            throw new IllegalArgumentException("Number of folds must be greater than 1");
        }
        this.m_NumXValFolds = newNumFolds;
    }

    /** Delegates to the base classifier if it is Drawable, otherwise returns 0. */
    public int graphType() {
        if (this.m_Classifier instanceof Drawable) {
            return ((Drawable)((Object)this.m_Classifier)).graphType();
        }
        return 0;
    }

    /** Delegates to the base classifier if it is Drawable, otherwise throws. */
    public String graph() throws Exception {
        if (this.m_Classifier instanceof Drawable) {
            return ((Drawable)((Object)this.m_Classifier)).graph();
        }
        throw new Exception("Classifier: " + this.getClassifierSpec() + " cannot be graphed");
    }

    public String manualThresholdValueTipText() {
        return "Sets a manual threshold value to use. If this is set (non-negative value between 0 and 1), then all options pertaining to automatic threshold selection are ignored. ";
    }

    /**
     * Sets the manual threshold. A value in [0, 1] enables it, and a negative value disables it.
     *
     * @param threshold the threshold value
     * @throws IllegalArgumentException if {@code threshold > 1}
     */
    public void setManualThresholdValue(double threshold) throws Exception {
        this.m_manualThresholdValue = threshold;
        if (threshold >= 0.0 && threshold <= 1.0) {
            this.m_manualThreshold = true;
        } else {
            this.m_manualThreshold = false;
            if (threshold >= 0.0) {
                throw new IllegalArgumentException("Threshold must be in the range 0..1.");
            }
        }
    }

    public double getManualThresholdValue() {
        return this.m_manualThresholdValue;
    }

    /** Describes the selected threshold and settings, followed by the base classifier. */
    public String toString() {
        if (this.m_BestValue == -Double.MAX_VALUE) {
            return "ThresholdSelector: No model built yet.";
        }
        StringBuilder result = new StringBuilder("Threshold Selector.\nClassifier: ")
            .append(this.m_Classifier.getClass().getName()).append("\n");
        result.append("Index of designated class: ").append(this.m_DesignatedClass).append("\n");
        if (this.m_manualThreshold) {
            result.append("User supplied threshold: ").append(this.m_BestThreshold).append("\n");
        } else {
            result.append("Evaluation mode: ");
            switch (this.m_EvalMode) {
                case EVAL_CROSS_VALIDATION: {
                    result.append(this.m_NumXValFolds).append("-fold cross-validation");
                    break;
                }
                case EVAL_TUNED_SPLIT: {
                    result.append("tuning on 1/").append(this.m_NumXValFolds).append(" of the data");
                    break;
                }
                default: {
                    result.append("tuning on the training data");
                }
            }
            result.append("\n");
            result.append("Threshold: ").append(this.m_BestThreshold).append("\n");
            result.append("Best value: ").append(this.m_BestValue).append("\n");
            if (this.m_RangeMode == RANGE_BOUNDS) {
                result.append("Expanding range [").append(this.m_LowThreshold).append(",")
                    .append(this.m_HighThreshold).append("] to [0, 1]\n");
            }
            result.append("Measure: ").append(this.getMeasure().getSelectedTag().getReadable()).append("\n");
        }
        result.append(this.m_Classifier.toString());
        return result.toString();
    }

    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.43 $");
    }

    public static void main(String[] argv) {
        ThresholdSelector.runClassifier(new ThresholdSelector(), argv);
    }
}

