/*
 * Decompiled with CFR 0.152.
 */
package hex.glm;

import hex.glm.GLM;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import water.MemoryManager;
import water.fvec.Frame;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/**
 * Static helpers used by GLM: column-index lookup on an adapted frame, merging of
 * scoring-history tables, and the gradient / objective contributions of the GAM
 * penalty matrices.
 */
public class GLMUtils {
    /**
     * Translates GAM column names into column indices of {@code adaptFrame}.
     *
     * <p>Each name is looked up in {@code adaptFrame.names()} (first match wins) and
     * {@code numOffset} is added to the position found. A name that is not present in the
     * frame yields {@code numOffset - 1}, because the lookup reports {@code -1}; callers
     * that cannot guarantee presence must check for that value.
     *
     * @param adaptFrame  frame whose column names are searched
     * @param gamColnames one array of column names per GAM column
     * @param numOffset   value added to every index found
     * @return a jagged array with the same shape as {@code gamColnames} holding the shifted indices
     */
    public static int[][] extractAdaptedFrameIndices(Frame adaptFrame, String[][] gamColnames, int numOffset) {
        // Fixed-size view is enough: the list is only searched, never modified.
        List<String> allColNames = Arrays.asList(adaptFrame.names());
        int numFrame = gamColnames.length;
        int[][] gamColIndices = new int[numFrame][];
        for (int frameNum = 0; frameNum < numFrame; ++frameNum) {
            String[] names = gamColnames[frameNum];
            int[] indices = MemoryManager.malloc4(names.length);
            for (int index = 0; index < names.length; ++index) {
                indices[index] = numOffset + allColNames.indexOf(names[index]);
            }
            gamColIndices[frameNum] = indices;
        }
        return gamColIndices;
    }

    /**
     * Copies a gradient-info object so that the copy does not share its gradient array
     * with the original.
     *
     * @param ginfo gradient info to copy
     * @return a new {@code GLMGradientInfo} with the same likelihood and objective value
     *         and a cloned gradient array
     */
    public static GLM.GLMGradientInfo copyGInfo(GLM.GLMGradientInfo ginfo) {
        return new GLM.GLMGradientInfo(ginfo._likelihood, ginfo._objVal, ginfo._gradient.clone());
    }

    /**
     * Builds a "Scoring History" table that contains every column of {@code glmSc1}
     * followed by those columns of {@code earlyStopSc2} whose lower-cased header is not
     * already among the GLM headers, then fills it via
     * {@link #combineTableContents(TwoDimTable, TwoDimTable, TwoDimTable, List, List, int)}.
     *
     * <p>The iteration column of {@code glmSc1} is the one named {@code "iteration"} or,
     * failing that, {@code "iterations"}.
     *
     * @param glmSc1            GLM scoring history; supplies the row headers and leading columns
     * @param earlyStopSc2      early-stopping scoring history; supplies the extra columns
     * @param scoreIterationList iterations at which early-stop scoring happened; entries are
     *                          removed from this list as they are matched
     * @return the combined table
     */
    public static TwoDimTable combineScoringHistory(TwoDimTable glmSc1, TwoDimTable earlyStopSc2, List<Integer> scoreIterationList) {
        String[] esColTypes = earlyStopSc2.getColTypes();
        String[] esColFormats = earlyStopSc2.getColFormats();
        String[] esColHeaders = earlyStopSc2.getColHeaders();

        List<String> finalColHeaders = new ArrayList<String>(Arrays.asList(glmSc1.getColHeaders()));
        List<String> finalColTypes = new ArrayList<String>(Arrays.asList(glmSc1.getColTypes()));
        List<String> finalColFormats = new ArrayList<String>(Arrays.asList(glmSc1.getColFormats()));

        int indexOfIter = finalColHeaders.indexOf("iteration");
        if (indexOfIter < 0) {
            indexOfIter = finalColHeaders.indexOf("iterations");
        }

        List<Integer> earlyStopColIndices = new ArrayList<Integer>();
        for (int col = 0; col < esColHeaders.length; ++col) {
            String colName = esColHeaders[col];
            // Locale.ROOT keeps the comparison stable regardless of the JVM default locale
            // (e.g. Turkish lower-cases 'I' to a dotless 'i').
            if (finalColHeaders.contains(colName.toLowerCase(Locale.ROOT))) {
                continue;
            }
            finalColHeaders.add(colName);
            finalColTypes.add(esColTypes[col]);
            finalColFormats.add(esColFormats[col]);
            earlyStopColIndices.add(col);
        }

        int tableSize = finalColHeaders.size();
        TwoDimTable res = new TwoDimTable("Scoring History", "", glmSc1.getRowHeaders(),
                finalColHeaders.toArray(new String[tableSize]),
                finalColTypes.toArray(new String[tableSize]),
                finalColFormats.toArray(new String[tableSize]), "");
        return GLMUtils.combineTableContents(glmSc1, earlyStopSc2, res, earlyStopColIndices, scoreIterationList, indexOfIter);
    }

    /**
     * Fills {@code combined} row by row: the leading columns are copied from
     * {@code glmSc1}; the trailing columns are copied from the next unused row of
     * {@code earlyStopSc2} whenever the GLM row's iteration number occurs in
     * {@code scoreIterationList}.
     *
     * <p>Side effect: every matched iteration is removed from {@code scoreIterationList}
     * (first occurrence), so the caller's list is modified.
     *
     * @param glmSc1             GLM scoring history
     * @param earlyStopSc2       early-stopping scoring history; its rows are consumed in order
     * @param combined           destination table, already sized for both sets of columns
     * @param earlyStopColIndices column indices of {@code earlyStopSc2} to append
     * @param scoreIterationList iterations for which an early-stop row exists
     * @param indexOfIter        column of {@code glmSc1} holding the iteration number
     * @return {@code combined}
     */
    public static TwoDimTable combineTableContents(TwoDimTable glmSc1, TwoDimTable earlyStopSc2, TwoDimTable combined, List<Integer> earlyStopColIndices, List<Integer> scoreIterationList, int indexOfIter) {
        int rowSize = glmSc1.getRowDim();
        int rowSize2 = earlyStopSc2.getRowDim();
        int glmColSize = glmSc1.getColDim();
        int earlyStopColSize = earlyStopColIndices.size();
        int sc2RowIndex = 0;
        for (int rowIndex = 0; rowIndex < rowSize; ++rowIndex) {
            for (int colIndex = 0; colIndex < glmColSize; ++colIndex) {
                combined.set(rowIndex, colIndex, glmSc1.get(rowIndex, colIndex));
            }
            if (sc2RowIndex >= rowSize2) {
                continue; // every early-stop row has already been placed
            }
            int glmSc1Iteration = ((Integer) glmSc1.get(rowIndex, indexOfIter)).intValue();
            // One scan answers both "is it present" and "where".
            int sc2Index = scoreIterationList.indexOf(glmSc1Iteration);
            if (sc2Index < 0) {
                continue;
            }
            // sc2Index is an int, so this is remove-by-position, not remove-by-value.
            scoreIterationList.remove(sc2Index);
            for (int colIndex = 0; colIndex < earlyStopColSize; ++colIndex) {
                combined.set(rowIndex, glmColSize + colIndex,
                        earlyStopSc2.get(sc2RowIndex, earlyStopColIndices.get(colIndex).intValue()));
            }
            ++sc2RowIndex;
        }
        return combined;
    }

    /**
     * Adds the GAM penalty contribution to {@code gradient} in place.
     *
     * <p>For GAM column {@code g} and penalty row {@code i} the amount added to the
     * gradient entry of coefficient {@code i} is
     * {@code 2 * beta_i * P[g][i][i] + sum over j != i of beta_j * P[g][i][j]}.
     *
     * <p>When {@code activeCols} is non-null, {@code beta} and {@code gradient} are taken
     * to be indexed by position within {@code activeCols}; each coefficient index from
     * {@code gamBetaIndices} is translated with {@code ArrayUtils.find}. Coefficients
     * that are not found (negative position) are skipped: they receive no gradient update
     * and contribute nothing to other entries, instead of triggering an out-of-bounds
     * access. NOTE(review): this relies on {@code ArrayUtils.find} returning a negative
     * value for a missing element — confirm against {@code ArrayUtils}.
     *
     * @param gradient       gradient to update in place
     * @param penalty_mat    one square penalty matrix per GAM column
     * @param gamBetaIndices coefficient indices of each GAM column
     * @param beta           current coefficients
     * @param activeCols     active coefficient indices, or {@code null} if all are active
     */
    public static void updateGradGam(double[] gradient, double[][][] penalty_mat, int[][] gamBetaIndices, double[] beta, int[] activeCols) {
        boolean filtered = activeCols != null;
        int numGamCol = gamBetaIndices.length;
        for (int gamColInd = 0; gamColInd < numGamCol; ++gamColInd) {
            double[][] penalty = penalty_mat[gamColInd];
            int penaltyMatSize = penalty.length;

            // Translate every coefficient index once per GAM column rather than once per
            // penalty-matrix cell.
            int[] betaPositions = new int[penaltyMatSize];
            for (int k = 0; k < penaltyMatSize; ++k) {
                int fullIndex = gamBetaIndices[gamColInd][k];
                betaPositions[k] = filtered ? ArrayUtils.find(activeCols, fullIndex) : fullIndex;
            }

            for (int betaInd = 0; betaInd < penaltyMatSize; ++betaInd) {
                int currentBetaIndex = betaPositions[betaInd];
                if (filtered && currentBetaIndex < 0) {
                    continue; // coefficient is not active: no slot in beta/gradient
                }
                double tempGrad = 2.0 * beta[currentBetaIndex] * penalty[betaInd][betaInd];
                for (int rowInd = 0; rowInd < penaltyMatSize; ++rowInd) {
                    if (rowInd == betaInd) {
                        continue;
                    }
                    int currBetaInd = betaPositions[rowInd];
                    if (filtered && currBetaInd < 0) {
                        continue; // inactive coefficient contributes nothing
                    }
                    tempGrad += beta[currBetaInd] * penalty[betaInd][rowInd];
                }
                gradient[currentBetaIndex] += tempGrad;
            }
        }
    }

    /**
     * Adds the GAM penalty contribution to a multinomial gradient in place.
     *
     * <p>{@code beta} and {@code gradient} are indexed {@code [coefficient][class]}; the
     * number of classes is read from {@code beta[0].length}. For each class, GAM column
     * and coefficient row the update is the same as in
     * {@link #updateGradGam(double[], double[][][], int[][], double[], int[])}: the
     * diagonal penalty term is doubled, the off-diagonal terms are not.
     *
     * @param gradient       gradient to update in place
     * @param penaltyMat     one square penalty matrix per GAM column
     * @param gamBetaIndices coefficient indices of each GAM column
     * @param beta           current coefficients
     */
    public static void updateGradGamMultinomial(double[][] gradient, double[][][] penaltyMat, int[][] gamBetaIndices, double[][] beta) {
        int numClass = beta[0].length;
        int numGamCol = gamBetaIndices.length;
        for (int classInd = 0; classInd < numClass; ++classInd) {
            for (int gamInd = 0; gamInd < numGamCol; ++gamInd) {
                int[] betaIndices = gamBetaIndices[gamInd];
                int numKnots = betaIndices.length;
                for (int rowInd = 0; rowInd < numKnots; ++rowInd) {
                    double[] penaltyRow = penaltyMat[gamInd][rowInd];
                    int betaIndR = betaIndices[rowInd];
                    double temp = 0.0;
                    for (int colInd = 0; colInd < numKnots; ++colInd) {
                        int betaIndC = betaIndices[colInd];
                        if (betaIndC == betaIndR) {
                            temp += 2.0 * penaltyRow[colInd] * beta[betaIndC][classInd];
                        } else {
                            temp += penaltyRow[colInd] * beta[betaIndC][classInd];
                        }
                    }
                    gradient[betaIndR][classInd] += temp;
                }
            }
        }
    }

    /**
     * Sums, over all GAM columns, the value of
     * {@code ArrayUtils.innerProductPartial(beta, idx, ArrayUtils.multArrVecPartial(P, beta, idx))}
     * where {@code P} is the column's penalty matrix and {@code idx} its coefficient indices.
     * Presumably this is the quadratic form {@code beta' P beta} restricted to {@code idx};
     * verify against {@code ArrayUtils}.
     *
     * @param beta          coefficients
     * @param penaltyMatrix one penalty matrix per GAM column
     * @param gamColIndices coefficient indices of each GAM column
     * @return the accumulated smoothness value
     */
    public static double calSmoothNess(double[] beta, double[][][] penaltyMatrix, int[][] gamColIndices) {
        double smoothval = 0.0;
        for (int gamCol = 0; gamCol < gamColIndices.length; ++gamCol) {
            int[] indices = gamColIndices[gamCol];
            double[] penalized = ArrayUtils.multArrVecPartial(penaltyMatrix[gamCol], beta, indices);
            smoothval += ArrayUtils.innerProductPartial(beta, indices, penalized);
        }
        return smoothval;
    }

    /**
     * Multi-class variant: applies
     * {@link #calSmoothNess(double[], double[][][], int[][])} to every row of
     * {@code beta} (here indexed {@code [class][coefficient]}) and adds up the results.
     *
     * @param beta          one coefficient vector per class
     * @param penaltyMatrix one penalty matrix per GAM column
     * @param gamColIndices coefficient indices of each GAM column
     * @return the accumulated smoothness value over all classes
     */
    public static double calSmoothNess(double[][] beta, double[][][] penaltyMatrix, int[][] gamColIndices) {
        double smoothval = 0.0;
        for (double[] classBeta : beta) {
            smoothval += GLMUtils.calSmoothNess(classBeta, penaltyMatrix, gamColIndices);
        }
        return smoothval;
    }
}

