/*
 * Decompiled with CFR 0.152.
 */
package hex.Infogram;

import hex.Infogram.InfogramModel;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelBuilderHelper;
import hex.SplitFrame;
import hex.schemas.DRFV3;
import hex.schemas.DeepLearningV3;
import hex.schemas.GBMV3;
import hex.schemas.GLMV3;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.DoubleStream;
import water.DKV;
import water.Iced;
import water.Key;
import water.Keyed;
import water.Scope;
import water.api.SchemaServer;
import water.api.schemas3.ModelParametersSchemaV3;
import water.fvec.Frame;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.TwoDimTable;

/**
 * Static helpers used by the Infogram model builder: selecting predictors, preparing training
 * frames, building sub-model parameters/builders, and assembling the CMI/relevance result frame.
 */
public class InfogramUtils {

    /**
     * Returns the names of the columns of {@code train} that may be used as predictors.
     *
     * <p>Starting from all column names of {@code train}, removes the columns reported by
     * {@code parms.getNonPredictors()}, then any {@code parms._protected_columns}, then
     * {@code foldColumnName} when it is non-null. Names not present in the frame are ignored.
     *
     * @param parms          infogram parameters supplying non-predictor and protected columns
     * @param train          frame whose column names are filtered
     * @param foldColumnName fold column to exclude; may be {@code null}
     * @return the remaining column names, in the frame's original column order
     */
    public static String[] extractPredictors(InfogramModel.InfogramParameters parms, Frame train, String foldColumnName) {
        List<String> colNames = new ArrayList<>(Arrays.asList(train.names()));
        for (String nonPred : parms.getNonPredictors()) {
            colNames.remove(nonPred);
        }
        if (parms._protected_columns != null) {
            for (String protectedCol : parms._protected_columns) {
                colNames.remove(protectedCol);
            }
        }
        if (foldColumnName != null) {
            colNames.remove(foldColumnName);
        }
        return colNames.toArray(new String[0]);
    }

    /**
     * Selects at most {@code parms._top_n_features} predictors by training one model with
     * {@code parms._algorithm} and taking the leading rows of its variable-importance table.
     *
     * <p>If {@code _top_n_features} is not smaller than the number of eligible predictors, the
     * eligible predictors are returned unchanged and no model is trained. Side effect: sets
     * {@code parms._infogram_algorithm_parameters._train} to the key of the extracted training
     * frame; the frame and the model are registered with {@link Scope} for cleanup.
     *
     * @param parms              infogram parameters (algorithm, sub-model parameters, top-n)
     * @param trainFrame         full training frame
     * @param eligiblePredictors candidate predictor column names
     * @return the selected predictor names, in variable-importance row order
     * @throws IllegalStateException if the trained model reports no variable importances
     */
    public static String[] extractTopKPredictors(InfogramModel.InfogramParameters parms, Frame trainFrame, String[] eligiblePredictors) {
        if (parms._top_n_features >= eligiblePredictors.length) {
            return eligiblePredictors;
        }
        Frame topTrain = extractTrainingFrame(parms, eligiblePredictors, 1.0, trainFrame);
        Scope.track(new Frame[]{topTrain});
        parms._infogram_algorithm_parameters._train = topTrain._key;
        Model.Parameters[] modelParams = buildModelParameters(new Frame[]{topTrain},
                parms._infogram_algorithm_parameters, 1, parms._algorithm);
        ModelBuilder[] builders = ModelBuilderHelper.trainModelsParallel(buildModelBuilders(modelParams), 1);
        Model builtModel = builders[0].get();
        Scope.track_generic(builtModel);
        TwoDimTable varImp = builtModel._output.getVariableImportances();
        if (varImp == null) {
            // Fail with context instead of a bare NullPointerException.
            throw new IllegalStateException("Infogram: the " + parms._algorithm
                    + " model did not produce variable importances; cannot select the top "
                    + parms._top_n_features + " predictors.");
        }
        String[] rowHeaders = varImp.getRowHeaders();
        // Never read past the end of the importance table (System.arraycopy would throw).
        return Arrays.copyOf(rowHeaders, Math.min(parms._top_n_features, rowHeaders.length));
    }

    /**
     * Returns the index of the first {@code null} entry in {@code generatedFrameKeys}.
     *
     * @param generatedFrameKeys key slots, filled from index 0 upward
     * @return index of the first {@code null} slot, or {@code -1} if every slot is non-null
     */
    public static int findstart(Key<Frame>[] generatedFrameKeys) {
        for (int index = 0; index < generatedFrameKeys.length; ++index) {
            if (generatedFrameKeys[index] == null) {
                return index;
            }
        }
        return -1;
    }

    /**
     * Builds a new frame, stored in DKV, with the sensitive predictors plus the applicable
     * non-predictor columns (response, weights, fold, ...) of {@code trainFrame}.
     *
     * <p>When {@code dataFraction < 1}, {@code trainFrame} is first split by {@link SplitFrame}
     * and only the first split (a {@code dataFraction} share of the rows) is used. The other split
     * is removed.
     *
     * <p>Side effects on {@code parms}: when an internal CV weights column is present, it is
     * re-added under the name {@code "infogram_internal_cv_weights_"} and
     * {@code parms._weights_column} is set to that name. {@code parms._fold_column} is reset to
     * {@code null} when it is unset, absent from the frame, or when CV weights are in use.
     *
     * @param parms               infogram parameters (mutated as described above)
     * @param sensitivePredictors predictor columns to copy first; may be {@code null}
     * @param dataFraction        fraction of rows to keep; values {@code >= 1} keep all rows
     * @param trainFrame          source training frame
     * @return the newly created frame (already put into DKV)
     */
    public static Frame extractTrainingFrame(InfogramModel.InfogramParameters parms, String[] sensitivePredictors, double dataFraction, Frame trainFrame) {
        if (dataFraction < 1.0) {
            // Use the caller-supplied fraction (previously parms._data_fraction was used here,
            // which could disagree with the value that triggered the split).
            SplitFrame sf = new SplitFrame(trainFrame, new double[]{dataFraction, 1.0 - dataFraction},
                    new Key[]{Key.make("ig_train_" + trainFrame._key), Key.make("ig_discard" + trainFrame._key)});
            sf.exec().get();
            Key<Frame>[] ksplits = sf._destination_frames;
            trainFrame = (Frame) DKV.getGet(ksplits[0]);
            // Keyed.remove deletes the discarded frame together with its parts;
            // DKV.remove alone would only drop the frame header and leave its Vecs behind.
            Keyed.remove(ksplits[1]);
        }
        Frame extractedFrame = new Frame(Key.make());
        if (sensitivePredictors != null) {
            for (String colName : sensitivePredictors) {
                extractedFrame.add(colName, trainFrame.vec(colName));
            }
        }
        String[] nonPredictors = parms.getNonPredictors();
        List<String> colNames = Arrays.asList(trainFrame.names());
        boolean cvWeightsPresent = parms._weights_column != null
                && colNames.contains("__internal_cv_weights__")
                && (parms._weights_column.equals("__internal_cv_weights__")
                    || parms._weights_column.equals("infogram_internal_cv_weights_"));
        for (String nonPredName : nonPredictors) {
            boolean isCvWeightName = "__internal_cv_weights__".equals(nonPredName)
                    || "infogram_internal_cv_weights_".equals(nonPredName);
            if (isCvWeightName && colNames.contains("__internal_cv_weights__")) {
                // Re-expose the internal CV weights under an infogram-specific name.
                String cvWeightName = "infogram_internal_cv_weights_";
                extractedFrame.add(cvWeightName, trainFrame.vec("__internal_cv_weights__"));
                parms._weights_column = cvWeightName;
                continue;
            }
            if (nonPredName.equals(parms._fold_column)) {
                // The fold column is kept only when present and CV weights are not in use.
                if (colNames.contains(parms._fold_column) && !cvWeightsPresent) {
                    extractedFrame.add(nonPredName, trainFrame.vec(nonPredName));
                }
                continue;
            }
            if (colNames.contains(nonPredName)) {
                extractedFrame.add(nonPredName, trainFrame.vec(nonPredName));
            }
        }
        if (parms._fold_column == null || !colNames.contains(parms._fold_column) || cvWeightsPresent) {
            parms._fold_column = null;
        }
        DKV.put(extractedFrame);
        return extractedFrame;
    }

    /**
     * Builds one human-readable description per sub-model: one for each predictor in
     * {@code topKPredictors}, followed by one final entry.
     *
     * <p>Without sensitive attributes the entries read "Model built missing predictor X" and the
     * final entry is the full model. With sensitive attributes they read "Model built with
     * sensitive_features and predictor X" and the final entry is the sensitive-features-only model.
     *
     * @param topKPredictors       predictors, one sub-model each
     * @param sensitive_attributes sensitive attributes, or {@code null} if none
     * @return array of length {@code topKPredictors.length + 1} with every slot filled
     */
    public static String[] generateModelDescription(String[] topKPredictors, String[] sensitive_attributes) {
        int numPred = topKPredictors.length;
        String[] modelNames = new String[numPred + 1];
        // Previously numPred was length-1, which skipped the last predictor and left the final slot null.
        if (sensitive_attributes == null) {
            for (int index = 0; index < numPred; ++index) {
                modelNames[index] = "Model built missing predictor " + topKPredictors[index];
            }
            modelNames[numPred] = "Full model built with all predictors";
        } else {
            for (int index = 0; index < numPred; ++index) {
                modelNames[index] = "Model built with sensitive_features and predictor " + topKPredictors[index];
            }
            modelNames[numPred] = "Model built with sensitive_features only";
        }
        return modelNames;
    }

    /**
     * Creates {@code numModels} parameter objects for {@code algoName}, each copied from
     * {@code infoParams} via the algorithm's REST schema, with {@code _ignored_columns} cleared and
     * {@code _train} set to the key of the matching entry of {@code trainingFrames}.
     *
     * <p>{@code AUTO} is treated as {@code gbm}; {@code xgboost} is resolved by name so that
     * this class does not need a compile-time dependency on XGBoost.
     *
     * @param trainingFrames one training frame per model (at least {@code numModels} entries)
     * @param infoParams     template parameters to copy
     * @param numModels      number of parameter objects to create
     * @param algoName       algorithm to build parameters for
     * @return freshly created parameter objects
     * @throws UnsupportedOperationException for an algorithm not handled here
     */
    public static Model.Parameters[] buildModelParameters(Frame[] trainingFrames, Model.Parameters infoParams, int numModels, InfogramModel.InfogramParameters.Algorithm algoName) {
        ModelParametersSchemaV3 paramsSchema;
        switch (algoName) {
            case glm:
                paramsSchema = new GLMV3.GLMParametersV3();
                break;
            case AUTO:
            case gbm:
                paramsSchema = new GBMV3.GBMParametersV3();
                break;
            case drf:
                paramsSchema = new DRFV3.DRFParametersV3();
                break;
            case deeplearning:
                paramsSchema = new DeepLearningV3.DeepLearningParametersV3();
                break;
            case xgboost: {
                Model.Parameters params = ModelBuilder.makeParameters("XGBoost");
                paramsSchema = (ModelParametersSchemaV3) SchemaServer.schema((Iced) params);
                break;
            }
            default:
                throw new UnsupportedOperationException("Unknown algo: " + algoName);
        }
        Model.Parameters[] modelParams = new Model.Parameters[numModels];
        for (int index = 0; index < numModels; ++index) {
            modelParams[index] = (Model.Parameters) paramsSchema.fillFromImpl(infoParams).createAndFillImpl();
            modelParams[index]._ignored_columns = null;
            modelParams[index]._train = trainingFrames[index]._key;
        }
        return modelParams;
    }

    /**
     * Creates one {@link ModelBuilder} per parameter object via {@link ModelBuilder#make}.
     *
     * @param modelParams parameters, one per builder
     * @return builders in the same order as {@code modelParams}
     */
    public static ModelBuilder[] buildModelBuilders(Model.Parameters[] modelParams) {
        ModelBuilder[] modelBuilders = new ModelBuilder[modelParams.length];
        for (int index = 0; index < modelParams.length; ++index) {
            modelBuilders[index] = ModelBuilder.make(modelParams[index]);
        }
        return modelBuilders;
    }

    /**
     * Assembles the six-column result frame (stored in DKV) from per-predictor arrays.
     *
     * <p>Columns, in order: {@code column}, {@code admissible}, {@code admissible_index}, then
     * {@code total_information}/{@code net_information} when {@code buildCore} is true or
     * {@code relevance_index}/{@code safety_index} otherwise, and finally {@code cmi_raw}.
     * The input arrays are expected to have the same length — TODO confirm with callers.
     *
     * @return the newly created frame (already put into DKV)
     */
    public static Frame generateCMIRelevance(String[] allPredictorNames, double[] admissible, double[] admissibleIndex, double[] relevance, double[] cmi, double[] cmiRaw, boolean buildCore) {
        Vec.VectorGroup vg = Vec.VectorGroup.VG_LEN1;
        Vec vName = Vec.makeVec(allPredictorNames, vg.addVec());
        Vec vAdm = Vec.makeVec(admissible, vg.addVec());
        Vec vAdmIndex = Vec.makeVec(admissibleIndex, vg.addVec());
        Vec vRel = Vec.makeVec(relevance, vg.addVec());
        Vec vCMI = Vec.makeVec(cmi, vg.addVec());
        Vec vCMIRaw = Vec.makeVec(cmiRaw, vg.addVec());
        String[] columnNames = buildCore
                ? new String[]{"column", "admissible", "admissible_index", "total_information", "net_information", "cmi_raw"}
                : new String[]{"column", "admissible", "admissible_index", "relevance_index", "safety_index", "cmi_raw"};
        Frame cmiRelFrame = new Frame(Key.make(), columnNames, new Vec[]{vName, vAdm, vAdmIndex, vRel, vCMI, vCMIRaw});
        DKV.put(cmiRelFrame);
        return cmiRelFrame;
    }

    /**
     * Calls {@code DKV.remove} on each key in order, stopping at the first {@code null} entry
     * (slots are filled from index 0, see {@link #findstart}).
     *
     * @param generatedFrameKeys keys to remove
     */
    public static void removeFromDKV(Key<Frame>[] generatedFrameKeys) {
        for (Key<Frame> oneFrameKey : generatedFrameKeys) {
            if (oneFrameKey == null) {
                break;
            }
            DKV.remove(oneFrameKey);
        }
    }

    /**
     * Turns raw per-model CMI values into normalized CMI values.
     *
     * <p>The last element of {@code cmiRaw} is the reference value. For every other index i, the
     * value is replaced <b>in place</b> by {@code max(0, ref - cmiRaw[i])} when {@code buildCore}
     * is true, or by {@code max(0, cmiRaw[i] - ref)} otherwise. The returned array holds those
     * values divided by their maximum (all zeros when the maximum is 0). The last element of
     * {@code cmiRaw} is left unchanged.
     *
     * @param cmiRaw    raw values, reference last; modified in place
     * @param buildCore selects the direction of the difference
     * @return normalized values, of length {@code cmiRaw.length - 1} (empty for empty input)
     */
    public static double[] calculateFinalCMI(double[] cmiRaw, boolean buildCore) {
        if (cmiRaw.length == 0) {
            // Previously this threw NegativeArraySizeException.
            return new double[0];
        }
        int lastInd = cmiRaw.length - 1;
        double reference = cmiRaw[lastInd];
        double maxCMI = 0.0;
        for (int index = 0; index < lastInd; ++index) {
            double diff = buildCore ? reference - cmiRaw[index] : cmiRaw[index] - reference;
            cmiRaw[index] = Math.max(0.0, diff);
            if (cmiRaw[index] > maxCMI) {
                maxCMI = cmiRaw[index];
            }
        }
        double scale = maxCMI == 0.0 ? 0.0 : 1.0 / maxCMI;
        double[] cmi = new double[lastInd];
        for (int index = 0; index < lastInd; ++index) {
            cmi[index] = cmiRaw[index] * scale;
        }
        return cmi;
    }

    /**
     * Returns a new frame (stored in DKV) built by copying {@code base}. It then removes the
     * columns named in {@code removeFeatures} and appends the columns named in {@code addFeatures},
     * taken from {@code featureFrame}.
     *
     * @param base           frame to copy
     * @param featureFrame   source of the columns to add
     * @param removeFeatures column names to drop; may be {@code null}
     * @param addFeatures    column names to append; may be {@code null}
     * @return the new frame
     */
    public static Frame subtractAdd2Frame(Frame base, Frame featureFrame, String[] removeFeatures, String[] addFeatures) {
        Frame newFrame = new Frame(base);
        if (removeFeatures != null) {
            for (String removeEle : removeFeatures) {
                newFrame.remove(removeEle);
            }
        }
        if (addFeatures != null) {
            for (String addEle : addFeatures) {
                newFrame.add(addEle, featureFrame.vec(addEle));
            }
        }
        DKV.put(newFrame);
        return newFrame;
    }

    /**
     * Reads the validation admissible-score frame of {@code infoModel} and records its values.
     * Column 5 goes into {@code cmiRaw[foldIndex]} and column 0 (as strings) is appended to
     * {@code columns}. The frame is removed afterwards.
     *
     * <p>NOTE(review): the column indices assume the frame follows the layout built by
     * {@link #generateCMIRelevance} (0 = "column", 5 = "cmi_raw"); confirm.
     */
    public static void extractInfogramInfo(InfogramModel infoModel, double[][] cmiRaw, List<List<String>> columns, int foldIndex) {
        Frame validFrame = (Frame) DKV.getGet(((InfogramModel.InfogramModelOutput) infoModel._output)._admissible_score_key_valid);
        cmiRaw[foldIndex] = vec2array(validFrame.vec(5));
        String[] oneColumn = strVec2array(validFrame.vec(0));
        columns.add(new ArrayList<>(Arrays.asList(oneColumn)));
        validFrame.remove();
    }

    /** Copies a numeric Vec into a {@code double[]} (row i to element i). */
    static double[] vec2array(Vec v) {
        assert v.length() < Integer.MAX_VALUE;
        int len = (int) v.length();
        double[] array = new double[len];
        for (int i = 0; i < len; ++i) {
            array[i] = v.at(i);
        }
        return array;
    }

    /** Copies a string Vec into a {@code String[]}; rows for which {@code atStr} yields null stay {@code null}. */
    static String[] strVec2array(Vec v) {
        assert v.length() < Integer.MAX_VALUE;
        int len = (int) v.length();
        BufferedString bs = new BufferedString();
        String[] array = new String[len];
        for (int i = 0; i < len; ++i) {
            BufferedString s = v.atStr(bs, i);
            if (s != null) {
                array[i] = s.toString();
            }
        }
        return array;
    }
}

