/*
 * Decompiled with CFR 0.152.
 */
package hex.api;

import hex.DataInfo;
import hex.Model;
import hex.glm.GLMModel;
import hex.gram.Gram;
import hex.schemas.DataInfoFrameV3;
import hex.schemas.GLMModelV3;
import hex.schemas.GLMRegularizationPathV3;
import hex.schemas.GramV3;
import hex.schemas.MakeGLMModelV3;
import java.util.Arrays;
import java.util.HashMap;
import water.DKV;
import water.Key;
import water.MRTask;
import water.api.Handler;
import water.api.schemas3.KeyV3;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.InteractionWrappedVec;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.ArrayUtils;

/**
 * REST handler exposing GLM helper operations: building a GLM model from
 * user-supplied coefficients, extracting a regularization path, expanding a
 * frame through {@link DataInfo}, and computing a Gram matrix.
 */
public class MakeGLMModelHandler extends Handler {
    /** Upper bound on the numeric suffixes tried when looking for an unused Gram frame key. */
    private static final int MAX_KEY_ATTEMPTS = 1000;

    /**
     * Creates a new GLM model that copies the parameters of an existing model
     * but uses coefficient values supplied by the caller.
     *
     * @param version API version (not used by this method)
     * @param args    request schema; {@code args.model} identifies the source model,
     *                {@code args.names}/{@code args.beta} give coefficient names and values,
     *                {@code args.dest} optionally names the new model
     * @return schema filled from the newly created model
     * @throws IllegalArgumentException if the source model is missing, the number of
     *         supplied values differs from the number of model coefficients, more names
     *         than values are supplied, or a model coefficient ends up without a value
     */
    public GLMModelV3 make_model(int version, MakeGLMModelV3 args) {
        GLMModel model = (GLMModel)DKV.getGet(args.model.key());
        if (model == null) {
            throw new IllegalArgumentException("missing source model " + args.model);
        }
        GLMModel.GLMOutput output = (GLMModel.GLMOutput)model._output;
        boolean multiClass = output._multinomial || output._ordinal;
        CharSequence[] names = multiClass ? output.multiClassCoeffNames() : output.coefficientNames();
        HashMap<String, Double> coefs = model.coefficients();
        if (args.beta.length != names.length) {
            throw new IllegalArgumentException("model coefficient length " + names.length + " is different from coefficient provided by user " + args.beta.length + ".\n model coefficients needed are:\n" + String.join((CharSequence)"\n", names));
        }
        // Every supplied name must have a matching value; otherwise the loop below
        // would read past the end of args.beta.
        if (args.names.length > args.beta.length) {
            throw new IllegalArgumentException("number of coefficient names " + args.names.length + " exceeds number of coefficient values " + args.beta.length);
        }
        // Overlay the user-provided values on top of the model's own coefficients.
        for (int i = 0; i < args.names.length; ++i) {
            coefs.put(args.names[i], args.beta[i]);
        }
        double[] beta = (double[])model.beta().clone();
        for (int i = 0; i < beta.length; ++i) {
            Double value = coefs.get(names[i]);
            if (value == null) {
                // Fail with a descriptive message instead of an unboxing NullPointerException.
                throw new IllegalArgumentException("no value available for coefficient " + names[i]);
            }
            beta[i] = value;
        }
        GLMModel m = new GLMModel(args.dest != null ? args.dest.key() : Key.make(), (GLMModel.GLMParameters)model._parms, null, model._ymu, Double.NaN, Double.NaN, -1L);
        m.setInputParms(model._input_parms);
        // NOTE(review): this changes the predictor transform on the DataInfo returned by the
        // source model; whether that object is shared with the stored model is not visible here.
        DataInfo dinfo = model.dinfo();
        dinfo.setPredictorTransform(DataInfo.TransformType.NONE);
        m._output = new GLMModel.GLMOutput(model.dinfo(), output._names, output._column_types, output._domains, output.coefficientNames(), beta, output._binomial, output._multinomial, output._ordinal);
        DKV.put(m._key, m);
        GLMModelV3 res = new GLMModelV3();
        res.fillFromImpl(m);
        return res;
    }

    /**
     * Returns the regularization path of an existing GLM model.
     *
     * @param v2   API version (not used by this method)
     * @param args request schema; {@code args.model} identifies the model
     * @return a new schema filled from the model's regularization path
     * @throws IllegalArgumentException if the model cannot be found
     */
    public GLMRegularizationPathV3 extractRegularizationPath(int v2, GLMRegularizationPathV3 args) {
        GLMModel model = (GLMModel)DKV.getGet(args.model.key());
        if (model == null) {
            throw new IllegalArgumentException("missing source model " + args.model);
        }
        return (GLMRegularizationPathV3)new GLMRegularizationPathV3().fillFromImpl(model.getRegularizationPath());
    }

    /**
     * Expands the requested frame via {@link #oneHot} using all pairwise interactions of
     * {@code args.interactions} and stores the key of the resulting frame in {@code args.result}.
     * Rows flagged as bad are always skipped.
     *
     * @param version API version (not used by this method)
     * @param args    request schema; updated in place with the result key
     * @return the same {@code args} instance
     * @throws IllegalArgumentException if the frame cannot be found
     */
    public DataInfoFrameV3 getDataInfoFrame(int version, DataInfoFrameV3 args) {
        Frame fr = (Frame)DKV.getGet(args.frame.key());
        if (fr == null) {
            throw new IllegalArgumentException("no frame found");
        }
        Model.InteractionSpec spec = Model.InteractionSpec.allPairwise((String[])args.interactions);
        Frame expanded = MakeGLMModelHandler.oneHot(fr, spec, args.use_all, args.standardize, args.interactions_only, true);
        args.result = new KeyV3.FrameKeyV3((Key<Frame>)expanded._key);
        return args;
    }

    /**
     * Builds a frame holding the dense-row expansion that {@link DataInfo} produces for {@code fr}.
     *
     * @param fr               input frame
     * @param interactions     interaction specification handed to {@link DataInfo}
     * @param useAll           passed to {@link DataInfo}; when false one level per categorical
     *                         column is left out of the offset computation
     * @param standardize      selects STANDARDIZE rather than NONE as the predictor transform
     * @param interactionsOnly when true only the expanded interaction columns are emitted
     * @param skipMissing      when true rows reported as bad by {@link DataInfo} are dropped
     * @return the expanded frame
     * @throws IllegalArgumentException if {@code interactionsOnly} is set but there are no interactions
     */
    public static Frame oneHot(Frame fr, Model.InteractionSpec interactions, boolean useAll, boolean standardize, boolean interactionsOnly, final boolean skipMissing) {
        Frame res;
        DataInfo.TransformType predictorTransform = standardize ? DataInfo.TransformType.STANDARDIZE : DataInfo.TransformType.NONE;
        final DataInfo dinfo = new DataInfo(fr, null, 1, useAll, predictorTransform, DataInfo.TransformType.NONE, skipMissing, false, false, false, false, false, interactions);
        if (interactionsOnly) {
            if (dinfo._interactionVecs == null) {
                throw new IllegalArgumentException("no interactions");
            }
            final int interactionCount = dinfo._interactionVecs.length;
            // colIds[k]: offset of the k-th interaction column inside the expanded row;
            // offsetIds[k]: number of expanded columns that interaction contributes.
            final int[] colIds = new int[interactionCount];
            final int[] offsetIds = new int[interactionCount];
            int noutputs = 0;
            int idx = 0;
            for (int vecId : dinfo._interactionVecs) {
                int width = ((InteractionWrappedVec)dinfo._adaptedFrame.vec(vecId)).expandedLength();
                offsetIds[idx++] = width;
                noutputs += width;
            }
            String[] coefNames = dinfo.coefNames();
            String[] names = new String[noutputs];
            idx = 0;
            int offset = 0;
            int namesIdx = 0;
            // Walk the adapted frame, tracking the expanded-column offset, and record the
            // offset and coefficient names of every interaction column.
            for (int col = 0; col < dinfo._adaptedFrame.numCols(); ++col) {
                Vec vec = dinfo._adaptedFrame.vec(col);
                if (vec instanceof InteractionWrappedVec) {
                    colIds[idx] = offset;
                    for (int nid = 0; nid < offsetIds[idx]; ++nid) {
                        names[namesIdx++] = coefNames[offset++];
                    }
                    // Stop as soon as the last interaction column has been recorded.
                    if (++idx >= interactionCount) {
                        break;
                    }
                } else if (vec.isCategorical()) {
                    offset += vec.domain().length - (useAll ? 0 : 1);
                } else {
                    ++offset;
                }
            }
            res = ((MRTask)new MRTask(){

                @Override
                public void map(Chunk[] cs, NewChunk[] ncs) {
                    DataInfo.Row row = dinfo.newDenseRow();
                    for (int r = 0; r < cs[0]._len; ++r) {
                        row = dinfo.extractDenseRow(cs, r, row);
                        if (skipMissing && row.isBad()) {
                            continue;
                        }
                        // Copy only the expanded interaction columns, in order.
                        int out = 0;
                        for (int k = 0; k < colIds.length; ++k) {
                            int end = colIds[k] + offsetIds[k];
                            for (int src = colIds[k]; src < end; ++src) {
                                ncs[out++].addNum(row.get(src));
                            }
                        }
                    }
                }
            }.doAll(noutputs, (byte)3, dinfo._adaptedFrame)).outputFrame(Key.make(), names, null);
        } else {
            // One output column (type code 3) per expanded coefficient.
            byte[] types = new byte[dinfo.fullN()];
            Arrays.fill(types, (byte)3);
            res = ((MRTask)new MRTask(){

                @Override
                public void map(Chunk[] cs, NewChunk[] ncs) {
                    DataInfo.Row row = dinfo.newDenseRow();
                    for (int r = 0; r < cs[0]._len; ++r) {
                        row = dinfo.extractDenseRow(cs, r, row);
                        if (skipMissing && row.isBad()) {
                            continue;
                        }
                        for (int c = 0; c < ncs.length; ++c) {
                            ncs[c].addNum(row.get(c));
                        }
                    }
                }
            }.doAll(types, dinfo._adaptedFrame.vecs())).outputFrame(Key.make("OneHot" + Key.make().toString()), dinfo.coefNames(), null);
        }
        dinfo.dropInteractions();
        dinfo.remove();
        return res;
    }

    /**
     * Computes the Gram matrix of the frame referenced by {@code input.X}, optionally using
     * the column named by {@code input.W} as weights, and stores it as a new frame whose
     * columns are named after the coefficients plus {@code "Intercept"}.
     *
     * @param v2    API version (not used by this method)
     * @param input request schema; {@code destination_frame} is filled with the result key
     * @return the same {@code input} instance
     * @throws IllegalArgumentException if the frame or the weight column does not exist, or
     *         no unused destination key could be found
     */
    public GramV3 computeGram(int v2, GramV3 input) {
        if (DKV.get(input.X.key()) == null) {
            throw new IllegalArgumentException("Frame " + input.X.key() + " does not exist.");
        }
        Frame fr = (Frame)input.X.key().get();
        // Work on a shallow copy so removing the weight column leaves the original frame intact.
        Frame frcpy = new Frame((String[])fr._names.clone(), (Vec[])fr.vecs().clone());
        String wname = null;
        Vec weight = null;
        // A weight spec without a column name is treated the same as no weights at all.
        if (input.W != null && input.W.column_name != null && !input.W.column_name.isEmpty()) {
            wname = input.W.column_name;
            if (fr.find(wname) == -1) {
                throw new IllegalArgumentException("Did not find weight vector " + wname);
            }
            weight = frcpy.remove(wname);
        }
        DataInfo.TransformType predictorTransform = input.standardize ? DataInfo.TransformType.STANDARDIZE : DataInfo.TransformType.NONE;
        DataInfo dinfo = new DataInfo(frcpy, null, 0, input.use_all_factor_levels, predictorTransform, DataInfo.TransformType.NONE, input.skip_missing, false, !input.skip_missing, false, false, false, true);
        DKV.put(dinfo);
        double[][] gram;
        try {
            if (weight != null) {
                dinfo.setWeights(wname, weight);
            }
            Gram.GramTask gt = (Gram.GramTask)new Gram.GramTask(null, dinfo, false, true).doAll(dinfo._adaptedFrame);
            gram = gt._gram.getXX();
        } finally {
            // Always take the DataInfo back out of the store, even if the task fails.
            dinfo.remove();
        }
        String[] names = ArrayUtils.append(dinfo.coefNames(), "Intercept");
        Vec[] vecs = new Vec[gram.length];
        Key<Vec>[] keys = new Vec.VectorGroup().addVecs(vecs.length);
        for (int i = 0; i < vecs.length; ++i) {
            vecs[i] = Vec.makeVec(gram[i], keys[i]);
        }
        input.destination_frame = new KeyV3.FrameKeyV3();
        // Destination name: <frame name without ".hex">_gram[_<weight column>]
        String keyname = input.X.key().toString();
        if (keyname.endsWith(".hex")) {
            keyname = keyname.substring(0, keyname.lastIndexOf("."));
        }
        keyname = keyname + "_gram";
        if (weight != null) {
            keyname = keyname + "_" + wname;
        }
        Key<Frame> destKey = MakeGLMModelHandler.findUnusedFrameKey(keyname);
        input.destination_frame.fillFromImpl(destKey);
        DKV.put(new Frame(destKey, names, vecs));
        return input;
    }

    /**
     * Returns a key named {@code base} if nothing is stored under it, otherwise the first
     * unused key of the form {@code base_N} for N in [0, MAX_KEY_ATTEMPTS).
     */
    private static Key<Frame> findUnusedFrameKey(String base) {
        Key<Frame> key = Key.make(base);
        if (DKV.get(key) == null) {
            return key;
        }
        for (int cnt = 0; cnt < MAX_KEY_ATTEMPTS; ++cnt) {
            key = Key.make(base + "_" + cnt);
            if (DKV.get(key) == null) {
                return key;
            }
        }
        throw new IllegalArgumentException("unable to make unique key");
    }
}

