/*
 * Decompiled with CFR 0.152.
 */
package biz.k11i.xgboost.learner;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import net.jafama.FastMath;

/**
 * Post-processing step applied to raw model outputs, looked up by XGBoost objective name.
 *
 * <p>The base class is the identity transform; it is registered for {@code rank:pairwise},
 * {@code binary:logitraw} and {@code reg:linear}. Nested subclasses implement the sigmoid
 * ({@code binary:logistic}), softmax ({@code multi:softprob}) and arg-max
 * ({@code multi:softmax}) transforms.
 *
 * <p>NOTE(review): the registry is a plain {@link HashMap} that {@link #register} and
 * {@link #useFastMathExp} mutate without synchronization; confirm those are only called before
 * the registry is read concurrently via {@link #fromName}.
 */
public class ObjFunction implements Serializable {
    /** Objective name to shared transform instance. */
    private static final Map<String, ObjFunction> FUNCTIONS = new HashMap<>();

    static {
        register("rank:pairwise", new ObjFunction());
        register("binary:logistic", new RegLossObjLogistic());
        register("binary:logitraw", new ObjFunction());
        register("multi:softmax", new SoftmaxMultiClassObjClassify());
        register("multi:softprob", new SoftmaxMultiClassObjProb());
        register("reg:linear", new ObjFunction());
    }

    /**
     * Returns the transform registered under an objective name.
     *
     * @param name objective name, e.g. {@code "binary:logistic"}
     * @return the shared registered instance
     * @throws IllegalArgumentException if nothing (or {@code null}) is registered under
     *     {@code name}
     */
    public static ObjFunction fromName(String name) {
        ObjFunction result = FUNCTIONS.get(name);
        if (result == null) {
            throw new IllegalArgumentException(name + " is not supported objective function.");
        }
        return result;
    }

    /**
     * Registers a transform under an objective name, replacing any previous registration.
     *
     * @param name objective name
     * @param objFunction transform to return from {@link #fromName} for {@code name}
     */
    public static void register(String name, ObjFunction objFunction) {
        FUNCTIONS.put(name, objFunction);
    }

    /**
     * Selects whether the {@code binary:logistic} and {@code multi:softprob} transforms use
     * Jafama's {@code FastMath.exp} ({@code true}) or {@link Math#exp} ({@code false}).
     *
     * <p>This replaces the registry entries. Instances already obtained from {@link #fromName}
     * keep the implementation they were created with.
     *
     * @param useJafama {@code true} to register the Jafama-based implementations
     */
    public static void useFastMathExp(boolean useJafama) {
        if (useJafama) {
            register("binary:logistic", new RegLossObjLogistic_Jafama());
            register("multi:softprob", new SoftmaxMultiClassObjProb_Jafama());
        } else {
            register("binary:logistic", new RegLossObjLogistic());
            register("multi:softprob", new SoftmaxMultiClassObjProb());
        }
    }

    /**
     * Transforms a vector of raw predictions. The base implementation returns {@code preds}
     * unchanged; subclasses may overwrite its elements in place.
     *
     * @param preds raw predictions
     * @return the transformed predictions
     */
    public float[] predTransform(float[] preds) {
        return preds;
    }

    /**
     * Transforms a single raw prediction. The base implementation returns it unchanged.
     *
     * @param pred raw prediction
     * @return the transformed prediction
     */
    public float predTransform(float pred) {
        return pred;
    }

    /** {@code multi:softprob}: converts per-class scores into probabilities with a softmax. */
    static class SoftmaxMultiClassObjProb extends ObjFunction {

        /**
         * Applies a softmax to {@code preds} in place and returns the same array. An empty array
         * is returned unchanged.
         */
        @Override
        public float[] predTransform(float[] preds) {
            if (preds.length == 0) {
                return preds;
            }
            // Shift by the maximum so every exp() argument is <= 0 and cannot overflow.
            float max = preds[0];
            for (int i = 1; i < preds.length; i++) {
                max = Math.max(preds[i], max);
            }
            // Accumulate the normalizer in double to reduce rounding error over many classes.
            double sum = 0.0;
            for (int i = 0; i < preds.length; i++) {
                preds[i] = exp(preds[i] - max);
                sum += preds[i];
            }
            float norm = (float) sum;
            for (int i = 0; i < preds.length; i++) {
                preds[i] /= norm;
            }
            return preds;
        }

        /** Not supported: a softmax needs the full vector of per-class scores. */
        @Override
        public float predTransform(float pred) {
            throw new UnsupportedOperationException();
        }

        /** Exponential used by {@link #predTransform(float[])}; overridable hook. */
        float exp(float x) {
            return (float) Math.exp(x);
        }
    }

    /** {@link SoftmaxMultiClassObjProb} computing exponentials with Jafama's FastMath. */
    static class SoftmaxMultiClassObjProb_Jafama extends SoftmaxMultiClassObjProb {
        @Override
        float exp(float x) {
            return (float) FastMath.exp(x);
        }
    }

    /** {@code multi:softmax}: reduces per-class scores to the index of the best class. */
    static class SoftmaxMultiClassObjClassify extends ObjFunction {

        /**
         * Returns a one-element array holding, as a float, the index of the largest score (the
         * first one on ties). {@code preds} is not modified.
         *
         * @throws IllegalArgumentException if {@code preds} is empty
         */
        @Override
        public float[] predTransform(float[] preds) {
            if (preds.length == 0) {
                throw new IllegalArgumentException("preds must not be empty");
            }
            int maxIndex = 0;
            float max = preds[0];
            for (int i = 1; i < preds.length; i++) {
                // Strict comparison keeps the first index on ties.
                if (preds[i] > max) {
                    maxIndex = i;
                    max = preds[i];
                }
            }
            return new float[] {maxIndex};
        }

        /** Not supported: arg-max needs the full vector of per-class scores. */
        @Override
        public float predTransform(float pred) {
            throw new UnsupportedOperationException();
        }
    }

    /** {@code binary:logistic}: maps raw margins through the logistic sigmoid. */
    static class RegLossObjLogistic extends ObjFunction {

        /** Applies {@link #sigmoid(float)} to every element in place and returns the same array. */
        @Override
        public float[] predTransform(float[] preds) {
            for (int i = 0; i < preds.length; i++) {
                preds[i] = sigmoid(preds[i]);
            }
            return preds;
        }

        @Override
        public float predTransform(float pred) {
            return sigmoid(pred);
        }

        /** Logistic function {@code 1 / (1 + e^-x)}; overridable hook for the exp implementation. */
        float sigmoid(float x) {
            return 1.0f / (1.0f + (float) Math.exp(-x));
        }
    }

    /** {@link RegLossObjLogistic} computing the sigmoid with Jafama's FastMath. */
    static class RegLossObjLogistic_Jafama extends RegLossObjLogistic {
        // The signature must match RegLossObjLogistic#sigmoid(float) exactly (enforced by
        // @Override): a double-typed overload would never be reached from predTransform.
        @Override
        float sigmoid(float x) {
            return (float) (1.0 / (1.0 + FastMath.exp(-x)));
        }
    }
}

