/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.weka.ranking.dyad.learner.zeroshot.util;

import ai.libs.jaicore.ml.ranking.dyad.learner.util.DyadMinMaxScaler;
import org.nd4j.linalg.api.ndarray.INDArray;
import weka.core.Utils;

/**
 * Static helpers that turn numeric inputs into Weka option arrays and that
 * map scaled hyper-parameter values back to their original range.
 *
 * <p>All {@code map...ToWekaOptions} methods assemble a single option string
 * and hand it to {@link Utils#splitOptions(String)} for tokenization.
 */
public class ZeroShotUtil {

    /** Utility class; only static members, never instantiated. */
    private ZeroShotUtil() {
    }

    /**
     * Builds the option array for the string {@code "-C <c> -M <m>"}, where
     * {@code m} is rounded to the nearest {@code long} via {@link Math#round(double)}.
     *
     * @param c value emitted unchanged after {@code -C}
     * @param m value emitted after {@code -M} once rounded
     * @return the tokens produced by {@link Utils#splitOptions(String)}
     * @throws Exception propagated from {@link Utils#splitOptions(String)}
     */
    public static String[] mapJ48InputsToWekaOptions(double c, double m) throws Exception {
        StringBuilder optionString = new StringBuilder();
        optionString.append("-C ").append(c);
        optionString.append(" -M ").append(Math.round(m));
        return Utils.splitOptions(optionString.toString());
    }

    /**
     * Builds the option array for an SMO configuration with an RBF kernel.
     * Both inputs are exponents: the emitted values are {@code 10^cExp} and
     * {@code 10^rbfGammaExp}. The kernel specification is passed as one quoted
     * {@code -K} argument and always contains the fixed {@code -C 250007} setting.
     *
     * @param cExp        base-10 exponent of the value emitted after the outer {@code -C}
     * @param rbfGammaExp base-10 exponent of the value emitted after the kernel's {@code -G}
     * @return the tokens produced by {@link Utils#splitOptions(String)}
     * @throws Exception propagated from {@link Utils#splitOptions(String)}
     */
    public static String[] mapSMORBFInputsToWekaOptions(double cExp, double rbfGammaExp) throws Exception {
        double complexity = tenToThe(cExp);
        double gamma = tenToThe(rbfGammaExp);
        String kernelOption = " -K \"weka.classifiers.functions.supportVector.RBFKernel -C 250007 -G " + gamma + "\"";
        return Utils.splitOptions("-C " + complexity + kernelOption);
    }

    /**
     * Builds the option array for the string {@code "-L <l> -M <m> -N <n>"},
     * where {@code l = 10^lExp}, {@code m = 10^mExp} and {@code n} is rounded
     * to the nearest {@code long}.
     *
     * @param lExp base-10 exponent of the value emitted after {@code -L}
     * @param mExp base-10 exponent of the value emitted after {@code -M}
     * @param n    value emitted after {@code -N} once rounded
     * @return the tokens produced by {@link Utils#splitOptions(String)}
     * @throws Exception propagated from {@link Utils#splitOptions(String)}
     */
    public static String[] mapMLPInputsToWekaOptions(double lExp, double mExp, double n) throws Exception {
        StringBuilder optionString = new StringBuilder();
        optionString.append("-L ").append(tenToThe(lExp));
        optionString.append(" -M ").append(tenToThe(mExp));
        optionString.append(" -N ").append(Math.round(n));
        return Utils.splitOptions(optionString.toString());
    }

    /**
     * Builds the option array for the string
     * {@code " -I <i> -K <k> -M <m> -depth <depth>"}. {@code i}, {@code m} and
     * {@code depth} are rounded to the nearest integer; {@code k} is
     * {@code ceil(kNumAttributes * kFraction)}.
     *
     * @param i              value emitted after {@code -I} once rounded
     * @param kFraction      factor multiplied with {@code kNumAttributes} to obtain {@code -K}
     * @param m              value emitted after {@code -M} once rounded
     * @param depth          value emitted after {@code -depth} once rounded
     * @param kNumAttributes quantity that {@code kFraction} is applied to
     * @return the tokens produced by {@link Utils#splitOptions(String)}
     * @throws Exception propagated from {@link Utils#splitOptions(String)}
     */
    public static String[] mapRFInputsToWekaOptions(double i, double kFraction, double m, double depth, double kNumAttributes) throws Exception {
        int iterations = (int)Math.round(i);
        int featureCount = (int)Math.ceil(kNumAttributes * kFraction);
        int minInstances = (int)Math.round(m);
        int maxDepth = (int)Math.round(depth);

        // The leading blank before "-I" is part of the original option string.
        StringBuilder optionString = new StringBuilder();
        optionString.append(" -I ").append(iterations);
        optionString.append(" -K ").append(featureCount);
        optionString.append(" -M ").append(minInstances);
        optionString.append(" -depth ").append(maxDepth);
        return Utils.splitOptions(optionString.toString());
    }

    /**
     * Extracts the trailing {@code numHyperPars} columns of {@code parameters}
     * and rescales each entry {@code v} at position {@code i} to
     * {@code v * (max_i - min_i) + min_i}, with {@code min_i}/{@code max_i}
     * taken from {@code scaler.getStatsY()[i]}.
     *
     * <p>NOTE(review): the column indices are derived from
     * {@code parameters.length()}, which presumably means {@code parameters}
     * is expected to be a row vector - confirm against callers.
     *
     * @param parameters   array whose last {@code numHyperPars} columns are unscaled
     * @param scaler       source of the per-position min/max statistics
     * @param numHyperPars number of trailing columns to extract
     * @return the array obtained from {@code parameters.getColumns(...)} with rescaled entries
     */
    public static INDArray unscaleParameters(INDArray parameters, DyadMinMaxScaler scaler, int numHyperPars) {
        int[] hyperParColumns = new int[numHyperPars];
        int firstHyperParColumn = (int)parameters.length() - numHyperPars;
        for (int offset = 0; offset < numHyperPars; offset++) {
            hyperParColumns[offset] = firstHyperParColumn + offset;
        }

        INDArray result = parameters.getColumns(hyperParColumns);
        for (int pos = 0; (long)pos < result.length(); pos++) {
            double scaled = result.getDouble((long)pos);
            double upper = scaler.getStatsY()[pos].getMax();
            double lower = scaler.getStatsY()[pos].getMin();
            // Invert min-max scaling: v * (max - min) + min.
            result.putScalar((long)pos, scaled * (upper - lower) + lower);
        }
        return result;
    }

    /** Returns {@code 10} raised to {@code exponent}. */
    private static double tenToThe(double exponent) {
        return Math.pow(10.0, exponent);
    }
}

