/*
 * Decompiled with CFR 0.152.
 */
package hex.glm;

import hex.DataInfo;
import hex.glm.ComputationState;
import hex.glm.DispersionTask;
import hex.glm.GLMModel;
import hex.glm.GLMTask;
import hex.glm.TweedieEstimator;
import hex.glm.TweedieMLDispersionOnly;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.math3.special.Gamma;
import water.Job;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;

/**
 * Static helpers that estimate the dispersion parameter of a fitted GLM (gamma, Tweedie and
 * negative-binomial families), together with the small MRTasks those estimators run.
 */
public class DispersionUtils {
    /** Fraction of the bracket width used to place the interior points of the golden-section search. */
    private static final double GOLDEN_SECTION_FRACTION = 0.618;
    /** Smallest dispersion used as the lower end of a golden-section bracket. */
    private static final double MIN_DISPERSION = 1.0E-16;
    /** Interrupt / timeout checks are performed once every this many iterations. */
    private static final int STOP_CHECK_INTERVAL = 100;

    /**
     * Estimates the gamma dispersion parameter with Newton iterations {@code se <- se - numerator / denominator}.
     * The digamma/trigamma sums needed per iteration are computed over the active data by
     * {@code GLMTask.ComputeDiTriGammaTsk}.
     *
     * <p>A step that would make the estimate negative halves the current estimate instead. Iteration stops when
     * the step magnitude falls below {@code parms._dispersion_epsilon}, when the step is not finite (or its
     * denominator is zero), after {@code parms._max_iterations_dispersion} iterations, or when the job is
     * stopped / the runtime budget is exhausted (checked every {@value #STOP_CHECK_INTERVAL} iterations).
     *
     * @param mlCT  task holding sums ({@code _wsum}, {@code _sumlnyiOui}, {@code _sumyiOverui}) that stay constant across iterations
     * @param seOld starting estimate
     * @param beta  coefficients passed to the digamma/trigamma task
     * @param parms GLM parameters (iteration limit, epsilon, runtime budget)
     * @param state computation state providing the active data
     * @param job   owning job, polled for stop requests
     * @param model model whose output start time is used for the runtime budget
     * @return the last estimate reached (converged, or the best available when stopped early)
     */
    public static double estimateGammaMLSE(GLMTask.ComputeGammaMLSETsk mlCT, double seOld, double[] beta, GLMModel.GLMParameters parms, ComputationState state, Job job, GLMModel model) {
        final double constantValue = mlCT._wsum + mlCT._sumlnyiOui - mlCT._sumyiOverui;
        final DataInfo dinfo = state.activeData();
        final Frame adaptedF = dinfo._adaptedFrame;
        final long startTime = System.currentTimeMillis();
        final long timeLeft = remainingRuntimeMillis(parms, model, startTime);
        for (int index = 0; index < parms._max_iterations_dispersion; ++index) {
            GLMTask.ComputeDiTriGammaTsk ditrigammatsk = (GLMTask.ComputeDiTriGammaTsk) new GLMTask.ComputeDiTriGammaTsk(null, dinfo, job._key, beta, parms, seOld).doAll(adaptedF);
            double numerator = mlCT._wsum * Math.log(seOld) - ditrigammatsk._sumDigamma + constantValue;
            double denominator = mlCT._wsum / seOld - ditrigammatsk._sumTrigamma;
            double change = numerator / denominator;
            if (denominator == 0.0 || !Double.isFinite(change)) {
                return seOld;
            }
            if (Math.abs(change) < parms._dispersion_epsilon) {
                return seOld - change;
            }
            double se = seOld - change;
            // Never step into negative dispersion; back off by halving instead.
            seOld = se < 0.0 ? seOld * 0.5 : se;
            if (interruptedOrTimedOut(index, job, startTime, timeLeft)) {
                Log.warn("gamma dispersion parameter estimation was interrupted by user or due to time out.  Estimation process has not converged. Increase your max_runtime_secs if you have set maximum runtime for your model building process.");
                return seOld;
            }
        }
        Log.warn("gamma dispersion parameter estimation fails to converge within " + parms._max_iterations_dispersion + " iterations.  Increase max_iterations_dispersion or decrease dispersion_epsilon.");
        return seOld;
    }

    /**
     * Returns the milliseconds left of the {@code max_runtime_secs} budget at {@code now}, measured from the
     * model output's start time, or {@link Long#MAX_VALUE} when no budget is set.
     */
    private static long remainingRuntimeMillis(GLMModel.GLMParameters parms, GLMModel model, long now) {
        long modelBuiltTime = now - ((GLMModel.GLMOutput) model._output)._start_time;
        return parms._max_runtime_secs > 0.0
                ? (long) (parms._max_runtime_secs * 1000.0 - (double) modelBuiltTime)
                : Long.MAX_VALUE;
    }

    /**
     * Reports whether an iterative estimator should stop. This is only checked every
     * {@value #STOP_CHECK_INTERVAL} iterations, and is true if the job was stopped or the elapsed time since
     * {@code startTime} exceeds {@code timeLeft}.
     */
    private static boolean interruptedOrTimedOut(int index, Job job, long startTime, long timeLeft) {
        return index % STOP_CHECK_INTERVAL == 0
                && (job.stop_requested() || System.currentTimeMillis() - startTime > timeLeft);
    }

    /**
     * Creates a Vec of ones with the same row layout as {@code fr}, for use as unit weights.
     * The caller owns the returned Vec and must {@code remove()} it.
     */
    private static Vec makeOnesLike(Frame fr) {
        Vec ones = Vec.makeOne(fr.numRows());
        Vec compatible = fr.makeCompatible(new Frame(new Vec[]{ones}))[0];
        if (compatible != ones) {
            // makeCompatible produced a re-laid-out copy; the source Vec is no longer referenced.
            ones.remove();
        }
        return compatible;
    }

    /**
     * Computes the Tweedie log-likelihood of the response given predictions {@code mu} and dispersion {@code phi},
     * using {@code parms._tweedie_variance_power}. Rows are weighted by the weights column when one is configured,
     * otherwise by a temporary ones-Vec that is removed before returning.
     */
    private static double getTweedieLogLikelihood(GLMModel.GLMParameters parms, DataInfo dinfo, double phi, Vec mu) {
        final Frame adaptedFrame = dinfo._adaptedFrame;
        final boolean unweighted = parms._weights_column == null;
        final Vec weights = unweighted ? makeOnesLike(adaptedFrame) : adaptedFrame.vec(parms._weights_column);
        double llh;
        try {
            llh = new TweedieEstimator(parms._tweedie_variance_power, phi, false, false, false, false)
                    .compute(mu, adaptedFrame.vec(parms._response_column), weights)._loglikelihood;
        } finally {
            if (unweighted) {
                // The ones-Vec was created just for this call; without this it leaked on every invocation.
                weights.remove();
            }
        }
        Log.debug("Tweedie LogLikelihood(p=" + parms._tweedie_variance_power + ", phi=" + phi + ") = " + llh);
        return llh;
    }

    /**
     * Golden-section search for the dispersion that maximizes the Tweedie log-likelihood.
     *
     * <p>The already evaluated {@code (phis, logLikelihoods)} pairs are sorted by phi and scanned for a drop in
     * likelihood to bracket the maximum. If the likelihood is still increasing, the upper bound is doubled until
     * a drop is seen. The bracket is then shrunk by golden-section steps until it is narrower than
     * {@code parms._dispersion_epsilon}, the iteration budget is used up, or the job is stopped.
     *
     * @param logLikelihoods log-likelihoods evaluated at the corresponding entries of {@code phis}
     * @param phis           dispersions already evaluated (not modified)
     * @return midpoint of the final bracket
     */
    private static double goldenRatioDispersionSearch(GLMModel.GLMParameters parms, DataInfo dinfo, Vec mu, List<Double> logLikelihoods, List<Double> phis, Job job) {
        // Must be mutable: a lower bound may be prepended and new upper bounds appended below.
        // Collectors.toList() gives no mutability guarantee, so request an ArrayList explicitly.
        List<Double> sortedPhis = phis.stream().sorted().collect(Collectors.toCollection(ArrayList::new));
        List<Double> sortedLLHs = new ArrayList<>(sortedPhis.size());
        for (double phi : sortedPhis) {
            sortedLLHs.add(logLikelihoods.get(phis.indexOf(phi)));
        }

        // Look for a drop in likelihood among the already evaluated points.
        boolean increasing = true;
        double lowerBound = MIN_DISPERSION;
        double upperBound = sortedPhis.get(0);
        for (int i = 1; i < sortedPhis.size(); ++i) {
            upperBound = sortedPhis.get(i);
            if (sortedLLHs.get(i - 1) > sortedLLHs.get(i)) {
                increasing = false;
                if (i > 2) {
                    lowerBound = sortedPhis.get(i - 2);
                } else {
                    sortedPhis.add(0, lowerBound);
                    sortedLLHs.add(0, getTweedieLogLikelihood(parms, dinfo, lowerBound, mu));
                }
                break;
            }
        }

        // Still increasing: keep doubling the upper bound until the likelihood drops.
        int counter = sortedPhis.size();
        final int iterationsLeft = parms._max_iterations_dispersion - 10 * counter;
        for (; increasing && iterationsLeft > counter && !job.stop_requested(); ++counter) {
            upperBound *= 2.0;
            sortedPhis.add(upperBound);
            double newLLH = getTweedieLogLikelihood(parms, dinfo, upperBound, mu);
            Log.debug("Tweedie looking for the region containing the max. likelihood; upper bound = " + upperBound + "; llh = " + newLLH);
            sortedLLHs.add(newLLH);
            if (sortedLLHs.get(counter - 2) > sortedLLHs.get(counter - 1)) {
                if (counter > 3) {
                    lowerBound = sortedPhis.get(counter - 3);
                }
                Log.debug("Tweedie found the region containing the max. likelihood; phi lower bound = " + lowerBound + "; phi upper bound = " + upperBound);
                break;
            }
        }

        double d = (upperBound - lowerBound) * GOLDEN_SECTION_FRACTION;
        double lowPhi = lowerBound;
        double hiPhi = upperBound;
        double midLoPhi = sortedPhis.get(counter - 2);
        double midLoLLH = sortedLLHs.get(counter - 2);
        double midHiPhi = lowPhi + d;
        // The update rule below requires lowPhi < midLoPhi < midHiPhi < hiPhi. An already evaluated point is
        // reused only if it keeps that ordering; otherwise (e.g. it coincides with the upper bound) place it.
        if (!(midLoPhi > lowPhi && midLoPhi < midHiPhi)) {
            midLoPhi = hiPhi - d;
            midLoLLH = getTweedieLogLikelihood(parms, dinfo, midLoPhi, mu);
        }
        double midHiLLH = getTweedieLogLikelihood(parms, dinfo, midHiPhi, mu);
        while (counter < iterationsLeft) {
            Log.info("Tweedie golden-section search[iter=" + counter + ", phis=(" + lowPhi + ", " + midLoPhi + ", " + midHiPhi + ", " + hiPhi + "), likelihoods=(..., " + midLoLLH + ", " + midHiLLH + ", ...)]");
            if (job.stop_requested()) {
                return (hiPhi + lowPhi) / 2.0;
            }
            if (midHiLLH > midLoLLH) {
                lowPhi = midLoPhi;
            } else {
                hiPhi = midHiPhi;
            }
            d = (hiPhi - lowPhi) * GOLDEN_SECTION_FRACTION;
            if (hiPhi - lowPhi < parms._dispersion_epsilon) {
                return (hiPhi + lowPhi) / 2.0;
            }
            midLoPhi = hiPhi - d;
            midHiPhi = lowPhi + d;
            midLoLLH = getTweedieLogLikelihood(parms, dinfo, midLoPhi, mu);
            midHiLLH = getTweedieLogLikelihood(parms, dinfo, midHiPhi, mu);
            ++counter;
        }
        return (hiPhi + lowPhi) / 2.0;
    }

    /**
     * Returns the dispersion paired with the largest <em>finite</em> log-likelihood (first one on ties).
     * Falls back to the first dispersion when no log-likelihood is finite. NaN must never win here, yet
     * {@code Collections.max} over {@code Double} would pick it, since {@code Double.compareTo} orders NaN last.
     */
    private static double dispersionWithMaxLogLikelihood(List<Double> logLikelihoods, List<Double> dispersions) {
        int bestIndex = 0;
        double bestLLH = Double.NEGATIVE_INFINITY;
        boolean found = false;
        for (int i = 0; i < logLikelihoods.size(); i++) {
            double llh = logLikelihoods.get(i);
            if (Double.isFinite(llh) && (!found || llh > bestLLH)) {
                bestLLH = llh;
                bestIndex = i;
                found = true;
            }
        }
        return dispersions.get(bestIndex);
    }

    /**
     * Estimates the Tweedie dispersion parameter, keeping the variance power fixed, by Newton iterations on the
     * average log-likelihood computed by {@code DispersionTask.ComputeMaxSumSeriesTsk}.
     *
     * <p>Every 10 iterations (and at convergence) the full Tweedie log-likelihood is evaluated as a sanity check.
     * If it drops below the best value seen so far, the method switches to a golden-section search.
     * {@code tDispersion.cleanUp()} is invoked on every exit path.
     *
     * @return the dispersion with the highest finite average log-likelihood seen, the golden-section estimate,
     *         or {@code NaN} if the iteration got stuck in a numerically unstable region
     */
    public static double estimateTweedieDispersionOnly(GLMModel.GLMParameters parms, GLMModel model, Job job, double[] beta, DataInfo dinfo) {
        final long startTime = System.currentTimeMillis();
        final long timeLeft = remainingRuntimeMillis(parms, model, startTime);
        TweedieMLDispersionOnly tDispersion = new TweedieMLDispersionOnly(parms.train(), parms, model, beta, dinfo);
        DispersionTask.GenPrediction gPred = (DispersionTask.GenPrediction) new DispersionTask.GenPrediction(beta, model, dinfo).doAll(1, Vec.T_NUM, dinfo._adaptedFrame);
        Vec mu = Scope.track(gPred.outputFrame(Key.make(), new String[]{"prediction"}, null)).vec(0);
        double dispersionCurr = tDispersion._dispersionParameter;
        List<Double> loglikelihoodList = new ArrayList<>();
        List<Double> dispersionList = new ArrayList<>();
        double bestLogLikelihoodFromSanityCheck = getTweedieLogLikelihood(parms, dinfo, dispersionCurr, mu);
        List<Double> logLikelihoodSanityChecks = new ArrayList<>();
        List<Double> dispersionsSanityChecks = new ArrayList<>();
        logLikelihoodSanityChecks.add(bestLogLikelihoodFromSanityCheck);
        dispersionsSanityChecks.add(dispersionCurr);
        for (int index = 0; index < parms._max_iterations_dispersion; ++index) {
            tDispersion.updateDispersionP(dispersionCurr);
            DispersionTask.ComputeMaxSumSeriesTsk computeTask = new DispersionTask.ComputeMaxSumSeriesTsk(tDispersion, parms, true);
            computeTask.doAll(tDispersion._infoFrame);
            double logLLCurr = computeTask._logLL / (double) computeTask._nobsLL;
            loglikelihoodList.add(logLLCurr);
            dispersionList.add(dispersionCurr);
            if (index > 0) {
                boolean converged = Math.abs(logLLCurr - loglikelihoodList.get(index - 1)) < parms._dispersion_epsilon;
                if (index % 10 == 0 || converged) {
                    double newLogLikelihood = getTweedieLogLikelihood(parms, dinfo, dispersionCurr, mu);
                    logLikelihoodSanityChecks.add(newLogLikelihood);
                    dispersionsSanityChecks.add(dispersionCurr);
                    if (newLogLikelihood < bestLogLikelihoodFromSanityCheck) {
                        Log.info("Tweedie sanity check FAIL. Trying Golden-section search instead of Newton's method.");
                        tDispersion.cleanUp();
                        double dispersion = goldenRatioDispersionSearch(parms, dinfo, mu, logLikelihoodSanityChecks, dispersionsSanityChecks, job);
                        Log.info("Tweedie dispersion estimate = " + dispersion);
                        return dispersion;
                    }
                    // Unlike Math.max, a NaN likelihood never replaces the running best (which would disable later checks).
                    if (newLogLikelihood > bestLogLikelihoodFromSanityCheck) {
                        bestLogLikelihoodFromSanityCheck = newLogLikelihood;
                    }
                    Log.debug("Tweedie sanity check OK");
                }
                if (converged) {
                    tDispersion.cleanUp();
                    Log.info("last dispersion " + dispersionCurr);
                    return dispersionWithMaxLogLikelihood(loglikelihoodList, dispersionList);
                }
            }
            if (index >= 10 && loglikelihoodList.subList(index - 2, index + 1).stream().noneMatch(Double::isFinite)) {
                Log.warn("tweedie dispersion parameter estimation got stuck in numerically unstable region.");
                tDispersion.cleanUp();
                return Double.NaN;
            }
            double update = computeTask._dLogLL / computeTask._d2LogLL;
            double dispersionNew;
            if (Math.abs(update) < 0.001) {
                // Newton step is tiny: try to amplify it with a line search.
                update = dispersionLS(computeTask, tDispersion, parms);
                if (!Double.isFinite(update)) {
                    tDispersion.cleanUp(); // this exit path previously skipped the cleanup
                    Log.info("last dispersion " + dispersionCurr);
                    return dispersionWithMaxLogLikelihood(loglikelihoodList, dispersionList);
                }
                dispersionNew = dispersionCurr - update;
            } else {
                dispersionNew = dispersionCurr - update;
                if (dispersionNew < 0.0) {
                    dispersionNew = dispersionCurr * 0.5;
                }
                tDispersion.updateDispersionP(dispersionNew);
                DispersionTask.ComputeMaxSumSeriesTsk computeTaskNew = new DispersionTask.ComputeMaxSumSeriesTsk(tDispersion, parms, false);
                computeTaskNew.doAll(tDispersion._infoFrame);
                double logLLNext = computeTaskNew._logLL / (double) computeTaskNew._nobsLL;
                if (logLLNext <= logLLCurr) {
                    dispersionNew = dispersionCurr + parms._dispersion_learning_rate * update;
                }
            }
            dispersionCurr = dispersionNew < 0.0 ? dispersionCurr * 0.5 : dispersionNew;
            if (interruptedOrTimedOut(index, job, startTime, timeLeft)) {
                Log.warn("tweedie dispersion parameter estimation was interrupted by user or due to time out.  Estimation process has not converged. Increase your max_runtime_secs if you have set maximum runtime for your model building process.");
                tDispersion.cleanUp();
                Log.info("last dispersion " + dispersionCurr);
                return dispersionWithMaxLogLikelihood(loglikelihoodList, dispersionList);
            }
        }
        tDispersion.cleanUp();
        if (!dispersionList.isEmpty()) {
            Log.info("last dispersion " + dispersionCurr);
            return dispersionWithMaxLogLikelihood(loglikelihoodList, dispersionList);
        }
        return dispersionCurr;
    }

    /**
     * Moment-method estimate computed from weighted sums of {@code mu^2}, {@code (y - mu)^2}, {@code mu} and
     * the weights. {@code model}, {@code beta} and {@code dinfo} are not used by this method.
     */
    public static double estimateNegBinomialDispersionMomentMethod(GLMModel model, double[] beta, DataInfo dinfo, Vec weights, Vec response, Vec mu) {
        // Chunk order follows the doAll call below: cs[0] = mu, cs[1] = response, cs[2] = weights.
        class MomentMethodThetaEstimation extends MRTask<MomentMethodThetaEstimation> {
            double _muSqSum;
            double _sSqSum;
            double _muSum;
            double _wSum;

            @Override
            public void map(Chunk[] cs) {
                for (int i = 0; i < cs[0]._len; ++i) {
                    double w = cs[2].atd(i);
                    this._muSqSum += w * Math.pow(cs[0].atd(i), 2.0);
                    this._sSqSum += w * Math.pow(cs[1].atd(i) - cs[0].atd(i), 2.0);
                    this._muSum += w * cs[0].atd(i);
                    this._wSum += w;
                }
            }

            @Override
            public void reduce(MomentMethodThetaEstimation mrt) {
                this._muSqSum += mrt._muSqSum;
                this._sSqSum += mrt._sSqSum;
                this._muSum += mrt._muSum;
                this._wSum += mrt._wSum;
            }
        }
        MomentMethodThetaEstimation mm = new MomentMethodThetaEstimation().doAll(new Vec[]{mu, response, weights});
        // NOTE(review): the denominator subtracts the weighted MEAN of mu from a weighted SUM of squares;
        // a textbook moment estimator would subtract the sum (or use means throughout) -- confirm intent.
        return mm._muSqSum / (mm._sSqSum - mm._muSum / mm._wSum);
    }

    /**
     * Estimates the negative-binomial dispersion by Fisher scoring on the size parameter theta and returns
     * {@code 1 / theta}. The scheme (initial value {@code n / sum(w * (y/mu - 1)^2)}, {@code abs} before each
     * step, truncation at zero) appears to mirror R's {@code MASS::theta.ml}.
     *
     * <p>Temporary data created here (unit weights when the data is unweighted, and the prediction frame) is
     * removed before returning. If theta goes negative it is truncated at zero as the warning states, which
     * yields an infinite dispersion.
     */
    public static double estimateNegBinomialDispersionFisherScoring(GLMModel.GLMParameters parms, GLMModel model, double[] beta, DataInfo dinfo) {
        final Frame adaptedFrame = dinfo._adaptedFrame;
        final boolean createdWeights = !dinfo._weights;
        final Vec weights = createdWeights ? makeOnesLike(adaptedFrame) : dinfo.getWeightsVec();
        Frame muFrame = null;
        try {
            double nRows = weights == null ? (double) adaptedFrame.numRows() : weights.mean() * (double) weights.length();
            DispersionTask.GenPrediction gPred = (DispersionTask.GenPrediction) new DispersionTask.GenPrediction(beta, model, dinfo).doAll(1, Vec.T_NUM, adaptedFrame);
            muFrame = gPred.outputFrame(Key.make(), new String[]{"prediction"}, null);
            Vec mu = muFrame.vec(0);
            Vec response = adaptedFrame.vec(dinfo.responseChunkId(0));
            // theta here is the NB size parameter; the returned dispersion is its reciprocal.
            double theta = nRows / new CalculateInitialTheta().doAll(new Vec[]{mu, response, weights})._theta0;
            double delta = 1.0;
            for (int i = 0; i < parms._max_iterations_dispersion && !(Math.abs(delta) < parms._dispersion_epsilon); ++i) {
                theta = Math.abs(theta);
                CalculateNegativeBinomialScoreAndInfo si = new CalculateNegativeBinomialScoreAndInfo(theta).doAll(new Vec[]{mu, response, weights});
                delta = si._score / si._info;
                theta += delta;
            }
            boolean converged = Math.abs(delta) < parms._dispersion_epsilon;
            if (theta < 0.0) {
                Log.warn("Dispersion estimate truncated at zero.");
                theta = 0.0;
            }
            if (!converged) {
                Log.warn("Iteration limit reached.");
            }
            return 1.0 / theta;
        } finally {
            if (muFrame != null) {
                muFrame.remove();
            }
            if (createdWeights && weights != null) {
                weights.remove();
            }
        }
    }

    /**
     * Line search used when the Newton step for the Tweedie dispersion is tiny. Starting from the Newton step
     * {@code dLogLL / d2LogLL}, the step is doubled for as long as the average log-likelihood at
     * {@code dispersion - step} keeps increasing. The first step is compared against negative infinity.
     *
     * <p>Side effect: leaves {@code tDispersion} updated to the last dispersion evaluated.
     *
     * @return the last step that improved the objective (the initial step if none did), or {@code NaN} if the
     *         initial step is not finite
     */
    public static double dispersionLS(DispersionTask.ComputeMaxSumSeriesTsk computeTsk, TweedieMLDispersionOnly tDispersion, GLMModel.GLMParameters parms) {
        final double dispersionCurr = tDispersion._dispersionParameter;
        double step = computeTsk._dLogLL / computeTsk._d2LogLL;
        double acceptedStep = step;
        double currObj = Double.NEGATIVE_INFINITY;
        for (int index = 0; index < parms._max_iterations_dispersion; ++index) {
            if (!Double.isFinite(step)) {
                // Non-finite initial step: signal failure. Later overflow from doubling: keep the last good step.
                return index == 0 ? Double.NaN : acceptedStep;
            }
            tDispersion.updateDispersionP(dispersionCurr - step);
            DispersionTask.ComputeMaxSumSeriesTsk computeTskNew = (DispersionTask.ComputeMaxSumSeriesTsk) new DispersionTask.ComputeMaxSumSeriesTsk(tDispersion, parms, false).doAll(tDispersion._infoFrame);
            double newObj = computeTskNew._logLL / (double) computeTskNew._nobsLL;
            if (!(newObj > currObj)) {
                // Previously the failing (doubled) step was returned here instead of the last improving one.
                return acceptedStep;
            }
            currObj = newObj;
            acceptedStep = step;
            step *= 2.0;
        }
        return acceptedStep;
    }

    /**
     * Overwrites each entry of {@code targetCoeffs} with {@code targetCoeffs[i] - sourceCoeffs[i]} for every index
     * of {@code targetCoeffs}, and returns {@code targetCoeffs}.
     * NOTE(review): despite its name, this does not zero anything; it subtracts in place.
     */
    public static double[] makeZeros(double[] sourceCoeffs, double[] targetCoeffs) {
        for (int valInd = 0; valInd < targetCoeffs.length; ++valInd) {
            targetCoeffs[valInd] -= sourceCoeffs[valInd];
        }
        return targetCoeffs;
    }

    /**
     * Accumulates {@code sum(w * (y / mu - 1)^2)} over chunks ordered (mu, response, weights); used to seed the
     * Fisher-scoring iteration.
     */
    static class CalculateInitialTheta
    extends MRTask<CalculateInitialTheta> {
        double _theta0;

        CalculateInitialTheta() {
        }

        @Override
        public void map(Chunk[] cs) {
            for (int i = 0; i < cs[0]._len; ++i) {
                double mu = cs[0].atd(i);
                double y = cs[1].atd(i);
                double w = cs[2].atd(i);
                this._theta0 += w * Math.pow(y / mu - 1.0, 2.0);
            }
        }

        @Override
        public void reduce(CalculateInitialTheta mrt) {
            this._theta0 += mrt._theta0;
        }
    }

    /**
     * Accumulates the weighted score ({@code _score}) and information ({@code _info}) terms for the
     * negative-binomial size parameter {@code theta}, over chunks ordered (mu, response, weights).
     */
    static class CalculateNegativeBinomialScoreAndInfo
    extends MRTask<CalculateNegativeBinomialScoreAndInfo> {
        double _score;
        double _info;
        double _theta;

        CalculateNegativeBinomialScoreAndInfo(double theta) {
            this._theta = theta;
        }

        @Override
        public void map(Chunk[] cs) {
            for (int i = 0; i < cs[0]._len; ++i) {
                double mu = cs[0].atd(i);
                double y = cs[1].atd(i);
                double w = cs[2].atd(i);
                this._score += w * (Gamma.digamma(this._theta + y) - Gamma.digamma(this._theta) + Math.log(this._theta) + 1.0 - Math.log(this._theta + mu) - (y + this._theta) / (mu + this._theta));
                this._info += w * (-Gamma.trigamma(this._theta + y) + Gamma.trigamma(this._theta) - 1.0 / this._theta + 2.0 / (mu + this._theta) - (y + this._theta) / Math.pow(mu + this._theta, 2.0));
            }
        }

        @Override
        public void reduce(CalculateNegativeBinomialScoreAndInfo mrt) {
            this._score += mrt._score;
            this._info += mrt._info;
        }
    }

    /**
     * Accumulates gradient, Hessian and log-likelihood terms of the negative-binomial likelihood with respect
     * to {@code theta}, over chunks ordered (mu, response, weights). {@code theta} must be positive.
     */
    static class NegativeBinomialGradientAndHessian
    extends MRTask<NegativeBinomialGradientAndHessian> {
        double _grad;
        double _hess;
        double _theta;
        double _invTheta;
        double _invThetaSq;
        double _llh;

        NegativeBinomialGradientAndHessian(double theta) {
            assert (theta > 0.0);
            this._theta = theta;
            this._invTheta = 1.0 / theta;
            this._invThetaSq = this._invTheta * this._invTheta;
        }

        @Override
        public void map(Chunk[] cs) {
            for (int i = 0; i < cs[0]._len; ++i) {
                double mu = cs[0].atd(i);
                double y = cs[1].atd(i);
                double w = cs[2].atd(i);
                this._grad += w * (-mu * (y + this._invTheta) / (mu * this._theta + 1.0) + (y + (Math.log(mu * this._theta + 1.0) - Gamma.digamma(y + this._invTheta) + Gamma.digamma(this._invTheta)) * this._invTheta) * this._invTheta);
                this._hess += w * (mu * mu * (y + this._invTheta) / Math.pow(mu * this._theta + 1.0, 2.0) + (-y + 2.0 * mu / (mu * this._theta + 1.0) + (-2.0 * Math.log(mu * this._theta + 1.0) + 2.0 * Gamma.digamma(y + this._invTheta) - 2.0 * Gamma.digamma(this._invTheta) + (Gamma.trigamma(y + this._invTheta) - Gamma.trigamma(this._invTheta)) * this._invTheta) * this._invTheta) * this._invThetaSq);
                // NOTE(review): unlike _grad/_hess, the log-likelihood term is not multiplied by w -- confirm intent.
                this._llh += Gamma.logGamma(y + this._invTheta) - Gamma.logGamma(this._invTheta) - Gamma.logGamma(y + 1.0) + y * Math.log(this._theta * mu) - (y + this._invTheta) * Math.log(1.0 + this._theta * mu);
            }
        }

        @Override
        public void reduce(NegativeBinomialGradientAndHessian mrt) {
            this._grad += mrt._grad;
            this._hess += mrt._hess;
            this._llh += mrt._llh;
        }
    }
}

