/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.support_vector_machine;

import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.support_vector_machine.Kernel;
import org.dmg.pmml.support_vector_machine.LinearKernel;
import org.dmg.pmml.support_vector_machine.PolynomialKernel;
import org.dmg.pmml.support_vector_machine.RadialBasisKernel;
import org.dmg.pmml.support_vector_machine.SigmoidKernel;
import org.jpmml.evaluator.UnsupportedElementException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;

/**
 * Static helpers for evaluating Support Vector Machine kernel functions.
 *
 * <p>Both vectors are passed as {@code Object}. They must either both be {@code float[]}
 * or both be {@code double[]}, and they must have the same length. Otherwise an
 * {@link IllegalArgumentException} is thrown.
 */
public final class KernelUtil {

    private KernelUtil() {
    }

    /**
     * Evaluates the kernel function K(input, vector) for the given kernel element.
     *
     * @param kernel the kernel element; one of {@link LinearKernel}, {@link PolynomialKernel},
     *     {@link RadialBasisKernel} or {@link SigmoidKernel}
     * @param valueFactory factory used to create the result value
     * @param input the input vector ({@code float[]} or {@code double[]})
     * @param vector the other vector (same array type and length as {@code input})
     * @return the kernel value
     * @throws UnsupportedElementException if {@code kernel} is not one of the supported types
     * @throws IllegalArgumentException if the vectors have incompatible types or lengths
     */
    public static <V extends Number> Value<V> evaluate(Kernel kernel, ValueFactory<V> valueFactory, Object input, Object vector) {
        if (kernel instanceof LinearKernel) {
            return KernelUtil.evaluateLinearKernel((LinearKernel)kernel, valueFactory, input, vector);
        }
        if (kernel instanceof PolynomialKernel) {
            return KernelUtil.evaluatePolynomialKernel((PolynomialKernel)kernel, valueFactory, input, vector);
        }
        if (kernel instanceof RadialBasisKernel) {
            return KernelUtil.evaluateRadialBasisKernel((RadialBasisKernel)kernel, valueFactory, input, vector);
        }
        if (kernel instanceof SigmoidKernel) {
            return KernelUtil.evaluateSigmoidKernel((SigmoidKernel)kernel, valueFactory, input, vector);
        }
        throw new UnsupportedElementException((PMMLObject)kernel);
    }

    /**
     * Evaluates the linear kernel {@code K(x, y) = <x, y>} (the dot product).
     *
     * @param linearKernel the kernel element (it has no parameters that affect the result)
     * @param valueFactory factory used to create the result value
     * @param input the input vector
     * @param vector the other vector
     * @return the dot product of {@code input} and {@code vector}
     * @throws IllegalArgumentException if the vectors have incompatible types or lengths
     */
    public static <V extends Number> Value<V> evaluateLinearKernel(LinearKernel linearKernel, ValueFactory<V> valueFactory, Object input, Object vector) {
        return valueFactory.newValue(KernelUtil.dotProduct(input, vector));
    }

    /**
     * Evaluates the polynomial kernel {@code K(x, y) = (gamma * <x, y> + coef0) ^ degree}.
     *
     * @param polynomialKernel the kernel element supplying {@code gamma}, {@code coef0} and {@code degree}
     * @param valueFactory factory used to create the result value
     * @param input the input vector
     * @param vector the other vector
     * @return the kernel value
     * @throws IllegalArgumentException if the vectors have incompatible types or lengths
     */
    public static <V extends Number> Value<V> evaluatePolynomialKernel(PolynomialKernel polynomialKernel, ValueFactory<V> valueFactory, Object input, Object vector) {
        Value<V> result = valueFactory.newValue(KernelUtil.dotProduct(input, vector));

        // Each step is skipped only when it is the identity operation for that operator.
        Double gamma = polynomialKernel.getGamma();
        if (gamma != 1.0) {
            result.multiply(gamma);
        }

        // 0 (not 1) is the additive identity; a coef0 of 1 must still be added.
        Double coef0 = polynomialKernel.getCoef0();
        if (coef0 != 0.0) {
            result.add(coef0);
        }

        Double degree = polynomialKernel.getDegree();
        if (degree != 1.0) {
            result.power(degree);
        }

        return result;
    }

    /**
     * Evaluates the radial basis kernel {@code K(x, y) = exp(-gamma * ||x - y||^2)}.
     *
     * @param radialBasisKernel the kernel element supplying {@code gamma}
     * @param valueFactory factory used to create the result value
     * @param input the input vector
     * @param vector the other vector
     * @return the kernel value
     * @throws IllegalArgumentException if the vectors have incompatible types or lengths
     */
    public static <V extends Number> Value<V> evaluateRadialBasisKernel(RadialBasisKernel radialBasisKernel, ValueFactory<V> valueFactory, Object input, Object vector) {
        Value<V> result = valueFactory.newValue(KernelUtil.negativeSquaredDistance(input, vector));

        // Multiplying by 1 is the identity, so it is skipped.
        Double gamma = radialBasisKernel.getGamma();
        if (gamma != 1.0) {
            result.multiply(gamma);
        }

        result.exp();

        return result;
    }

    /**
     * Evaluates the sigmoid kernel {@code K(x, y) = tanh(gamma * <x, y> + coef0)}.
     *
     * @param sigmoidKernel the kernel element supplying {@code gamma} and {@code coef0}
     * @param valueFactory factory used to create the result value
     * @param input the input vector
     * @param vector the other vector
     * @return the kernel value
     * @throws IllegalArgumentException if the vectors have incompatible types or lengths
     */
    public static <V extends Number> Value<V> evaluateSigmoidKernel(SigmoidKernel sigmoidKernel, ValueFactory<V> valueFactory, Object input, Object vector) {
        Value<V> result = valueFactory.newValue(KernelUtil.dotProduct(input, vector));

        Double gamma = sigmoidKernel.getGamma();
        if (gamma != 1.0) {
            result.multiply(gamma);
        }

        // 0 (not 1) is the additive identity; a coef0 of 1 must still be added.
        Double coef0 = sigmoidKernel.getCoef0();
        if (coef0 != 0.0) {
            result.add(coef0);
        }

        result.tanh();

        return result;
    }

    /**
     * Computes the dot product of two vectors of the same primitive array type.
     *
     * @return a {@link Float} for {@code float[]} inputs, or a {@link Double} for {@code double[]} inputs
     */
    private static Number dotProduct(Object left, Object right) {
        if (left instanceof float[] && right instanceof float[]) {
            return Float.valueOf(KernelUtil.dotProduct((float[])left, (float[])right));
        }
        if (left instanceof double[] && right instanceof double[]) {
            return Double.valueOf(KernelUtil.dotProduct((double[])left, (double[])right));
        }
        throw KernelUtil.incompatibleVectors(left, right);
    }

    private static float dotProduct(float[] left, float[] right) {
        KernelUtil.checkSameLength(left.length, right.length);

        float sum = 0.0f;
        for (int i = 0; i < left.length; ++i) {
            sum += left[i] * right[i];
        }
        return sum;
    }

    private static double dotProduct(double[] left, double[] right) {
        KernelUtil.checkSameLength(left.length, right.length);

        double sum = 0.0;
        for (int i = 0; i < left.length; ++i) {
            sum += left[i] * right[i];
        }
        return sum;
    }

    /**
     * Computes {@code -||left - right||^2} for two vectors of the same primitive array type.
     *
     * @return a {@link Float} for {@code float[]} inputs, or a {@link Double} for {@code double[]} inputs
     */
    private static Number negativeSquaredDistance(Object left, Object right) {
        if (left instanceof float[] && right instanceof float[]) {
            return Float.valueOf(-KernelUtil.squaredDistance((float[])left, (float[])right));
        }
        if (left instanceof double[] && right instanceof double[]) {
            return Double.valueOf(-KernelUtil.squaredDistance((double[])left, (double[])right));
        }
        throw KernelUtil.incompatibleVectors(left, right);
    }

    private static float squaredDistance(float[] left, float[] right) {
        KernelUtil.checkSameLength(left.length, right.length);

        float sum = 0.0f;
        for (int i = 0; i < left.length; ++i) {
            float diff = left[i] - right[i];
            sum += diff * diff;
        }
        return sum;
    }

    private static double squaredDistance(double[] left, double[] right) {
        KernelUtil.checkSameLength(left.length, right.length);

        double sum = 0.0;
        for (int i = 0; i < left.length; ++i) {
            double diff = left[i] - right[i];
            sum += diff * diff;
        }
        return sum;
    }

    /** Throws {@link IllegalArgumentException} if the two vector lengths differ. */
    private static void checkSameLength(int leftLength, int rightLength) {
        if (leftLength != rightLength) {
            throw new IllegalArgumentException("Vector length mismatch: " + leftLength + " != " + rightLength);
        }
    }

    /** Builds the exception reported when the two vectors are not both float[] or both double[]. */
    private static IllegalArgumentException incompatibleVectors(Object left, Object right) {
        return new IllegalArgumentException("Expected two float[] or two double[] vectors, got " + KernelUtil.describe(left) + " and " + KernelUtil.describe(right));
    }

    /** Returns a short, null-safe type description for use in error messages. */
    private static String describe(Object value) {
        return (value != null) ? value.getClass().getSimpleName() : "null";
    }
}

