/*
 * Decompiled with CFR 0.152.
 */
package org.broadinstitute.hellbender.tools.copynumber.models;

import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.math3.exception.MaxCountExceededException;
import org.apache.commons.math3.special.Beta;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.broadinstitute.hellbender.tools.copynumber.formats.CopyNumberFormatsUtils;
import org.broadinstitute.hellbender.tools.copynumber.formats.records.AllelicCount;
import org.broadinstitute.hellbender.tools.copynumber.models.AlleleFractionGlobalParameters;
import org.broadinstitute.hellbender.tools.copynumber.models.AlleleFractionLikelihoods;
import org.broadinstitute.hellbender.tools.copynumber.models.AlleleFractionParameter;
import org.broadinstitute.hellbender.tools.copynumber.models.AlleleFractionSegmentedData;
import org.broadinstitute.hellbender.tools.copynumber.models.AlleleFractionState;
import org.broadinstitute.hellbender.utils.OptimizationUtils;
import org.broadinstitute.hellbender.utils.Utils;

/**
 * Computes a starting {@link AlleleFractionState} for the allele-fraction model.
 *
 * <p>The constructor performs all of the work: it seeds the per-segment minor-allele fractions
 * from the raw allelic counts, then alternates between re-estimating the three global parameters
 * (mean bias, bias variance, outlier probability) and the per-segment minor fractions until the
 * model log likelihood stops improving by more than {@link #LOG_LIKELIHOOD_CONVERGENCE_THRESHOLD}
 * or the iteration cap is hit. The result is retrieved with {@link #getInitializedState()}.
 */
final class AlleleFractionInitializer {
    private static final Logger logger = LogManager.getLogger(AlleleFractionInitializer.class);

    // Starting values of the global parameters, used before any optimization has been performed.
    private static final double INITIAL_OUTLIER_PROBABILITY = 0.01;
    private static final double INITIAL_MEAN_BIAS = 1.0;
    private static final double INITIAL_BIAS_VARIANCE = 0.05;
    private static final AlleleFractionGlobalParameters INITIAL_GLOBAL_PARAMETERS =
            new AlleleFractionGlobalParameters(INITIAL_MEAN_BIAS, INITIAL_BIAS_VARIANCE, INITIAL_OUTLIER_PROBABILITY);

    // Iteration stops once the log likelihood improves by no more than this amount...
    private static final double LOG_LIKELIHOOD_CONVERGENCE_THRESHOLD = 0.5;
    // ...or once the iteration counter reaches this value (see the note in the constructor).
    private static final int MAX_ITERATIONS = 50;

    // Upper bounds of the search ranges used when optimizing each global parameter.
    static final double MAX_REASONABLE_OUTLIER_PROBABILITY = 0.15;
    static final double MAX_REASONABLE_MEAN_BIAS = 5.0;
    static final double MAX_REASONABLE_BIAS_VARIANCE = 0.5;

    // A final estimate closer than this to its upper bound triggers a warning.
    private static final double EPSILON_FOR_NEAR_MAX_WARNING = 0.01;

    // Upper bound of the search range used when optimizing a segment's minor-allele fraction.
    private static final double MAX_MINOR_ALLELE_FRACTION = 0.5;

    private final AlleleFractionSegmentedData data;
    // Current estimates; both are replaced wholesale on every iteration of the constructor loop.
    private AlleleFractionGlobalParameters globalParameters;
    private AlleleFractionState.MinorFractions minorFractions;

    /**
     * Runs the iterative initialization on the given data.
     *
     * @param data segmented allelic-count data; must not be {@code null} (checked via {@code Utils.nonNull})
     */
    AlleleFractionInitializer(final AlleleFractionSegmentedData data) {
        this.data = Utils.nonNull(data);
        globalParameters = INITIAL_GLOBAL_PARAMETERS;
        minorFractions = calculateInitialMinorFractions(data);

        double previousIterationLogLikelihood;
        // Starting from negative infinity makes the first computed improvement infinite,
        // so a finite first log likelihood never satisfies the convergence test.
        double nextIterationLogLikelihood = Double.NEGATIVE_INFINITY;
        logger.info(String.format("Initializing allele-fraction model, iterating until log likelihood converges to within %s...",
                CopyNumberFormatsUtils.formatDouble(LOG_LIKELIHOOD_CONVERGENCE_THRESHOLD)));

        int iteration = 1;
        do {
            previousIterationLogLikelihood = nextIterationLogLikelihood;
            // All three estimates are evaluated before the field is reassigned, so each parameter is
            // optimized on its own with the other two held at their values from the previous iteration.
            globalParameters = new AlleleFractionGlobalParameters(
                    estimateMeanBias(), estimateBiasVariance(), estimateOutlierProbability());
            // The minor fractions are then re-estimated against the freshly updated global parameters.
            minorFractions = estimateMinorFractions();

            nextIterationLogLikelihood = AlleleFractionLikelihoods.logLikelihood(globalParameters, minorFractions, data);
            logger.info(String.format("Iteration %d, model log likelihood = %s...", iteration,
                    CopyNumberFormatsUtils.formatDouble(nextIterationLogLikelihood)));
            logger.info(globalParameters);
            iteration++;
            // NOTE: the counter starts at 1 and is incremented before the comparison, so the loop body
            // executes at most MAX_ITERATIONS - 1 times. This matches the existing behavior and is kept as is.
        } while (iteration < MAX_ITERATIONS
                && nextIterationLogLikelihood - previousIterationLogLikelihood > LOG_LIKELIHOOD_CONVERGENCE_THRESHOLD);

        warnIfNearMax(AlleleFractionParameter.MEAN_BIAS.name, globalParameters.getMeanBias(),
                MAX_REASONABLE_MEAN_BIAS, EPSILON_FOR_NEAR_MAX_WARNING);
        warnIfNearMax(AlleleFractionParameter.BIAS_VARIANCE.name, globalParameters.getBiasVariance(),
                MAX_REASONABLE_BIAS_VARIANCE, EPSILON_FOR_NEAR_MAX_WARNING);
        warnIfNearMax(AlleleFractionParameter.OUTLIER_PROBABILITY.name, globalParameters.getOutlierProbability(),
                MAX_REASONABLE_OUTLIER_PROBABILITY, EPSILON_FOR_NEAR_MAX_WARNING);
    }

    /**
     * Logs a warning when {@code value} lies within {@code epsilon} of {@code maxValue}.
     *
     * @param parameterName name of the parameter, used only in the log message
     * @param value         the estimated parameter value
     * @param maxValue      the upper bound the value is compared against
     * @param epsilon       the distance from {@code maxValue} below which the warning is emitted
     */
    private static void warnIfNearMax(final String parameterName, final double value, final double maxValue, final double epsilon) {
        if (maxValue - value < epsilon) {
            logger.warn(String.format("The maximum-likelihood estimate for the global parameter %s (%s) was near its boundary (%s), the model is likely not a good fit to the data!  Consider changing parameters for filtering homozygous sites.",
                    parameterName, CopyNumberFormatsUtils.formatDouble(value), CopyNumberFormatsUtils.formatDouble(maxValue)));
        }
    }

    /**
     * Packages the current estimates into a new {@link AlleleFractionState}.
     *
     * @return a state built from the current global parameters and minor fractions
     */
    AlleleFractionState getInitializedState() {
        return new AlleleFractionState(
                globalParameters.getMeanBias(),
                globalParameters.getBiasVariance(),
                globalParameters.getOutlierProbability(),
                minorFractions);
    }

    /**
     * Builds a first guess of the minor-allele fraction of every segment directly from the read counts,
     * without using the global parameters.
     *
     * @param data segmented allelic-count data
     * @return one minor-fraction value per segment, in segment-index order
     */
    private AlleleFractionState.MinorFractions calculateInitialMinorFractions(final AlleleFractionSegmentedData data) {
        final int numSegments = data.getNumSegments();
        final AlleleFractionState.MinorFractions result = new AlleleFractionState.MinorFractions(numSegments);
        for (int segmentIndex = 0; segmentIndex < numSegments; segmentIndex++) {
            double responsibilityWeightedMinorAlleleReadCount = 0.0;
            double responsibilityWeightedTotalReadCount = 0.0;
            for (final AllelicCount allelicCount : data.getIndexedAllelicCountsInSegment(segmentIndex)) {
                final int a = allelicCount.getAltReadCount();
                final int r = allelicCount.getRefReadCount();
                double altMinorResponsibility;
                try {
                    // Regularized incomplete beta function I_0.5(a + 1, r + 1): the mass of a
                    // Beta(a + 1, r + 1) distribution below 0.5, used here as the soft weight
                    // for "the alt allele is the minor allele" at this site.
                    altMinorResponsibility = Beta.regularizedBeta(0.5, a + 1, r + 1);
                } catch (final MaxCountExceededException e) {
                    // The numerical evaluation gave up; fall back to a hard assignment based on the raw counts.
                    altMinorResponsibility = a < r ? 1.0 : 0.0;
                }
                responsibilityWeightedMinorAlleleReadCount += altMinorResponsibility * a + (1.0 - altMinorResponsibility) * r;
                responsibilityWeightedTotalReadCount += a + r;
            }
            // The +1 / +2 offsets keep the ratio strictly between 0 and 1 and yield 0.5 for a segment with no reads.
            result.add((responsibilityWeightedMinorAlleleReadCount + 1.0) / (responsibilityWeightedTotalReadCount + 2.0));
        }
        return result;
    }

    /**
     * Re-estimates the outlier probability with the other global parameters and the minor fractions held
     * at their current values.
     *
     * @return the value returned by {@code OptimizationUtils.argmax} for the range
     *         {@code [0, MAX_REASONABLE_OUTLIER_PROBABILITY]}, called with the current estimate as its last
     *         argument (presumably the initial guess; confirm against {@code OptimizationUtils})
     */
    private double estimateOutlierProbability() {
        final Function<Double, Double> objective = outlierProbability -> AlleleFractionLikelihoods.logLikelihood(
                globalParameters.copyWithNewOutlierProbability(outlierProbability), minorFractions, data);
        return OptimizationUtils.argmax(objective, 0.0, MAX_REASONABLE_OUTLIER_PROBABILITY, globalParameters.getOutlierProbability());
    }

    /**
     * Re-estimates the mean bias with the other global parameters and the minor fractions held
     * at their current values.
     *
     * @return the value returned by {@code OptimizationUtils.argmax} for the range
     *         {@code [0, MAX_REASONABLE_MEAN_BIAS]}, called with the current estimate as its last argument
     *         (presumably the initial guess; confirm against {@code OptimizationUtils})
     */
    private double estimateMeanBias() {
        final Function<Double, Double> objective = meanBias -> AlleleFractionLikelihoods.logLikelihood(
                globalParameters.copyWithNewMeanBias(meanBias), minorFractions, data);
        return OptimizationUtils.argmax(objective, 0.0, MAX_REASONABLE_MEAN_BIAS, globalParameters.getMeanBias());
    }

    /**
     * Re-estimates the bias variance with the other global parameters and the minor fractions held
     * at their current values.
     *
     * @return the value returned by {@code OptimizationUtils.argmax} for the range
     *         {@code [0, MAX_REASONABLE_BIAS_VARIANCE]}, called with the current estimate as its last argument
     *         (presumably the initial guess; confirm against {@code OptimizationUtils})
     */
    private double estimateBiasVariance() {
        final Function<Double, Double> objective = biasVariance -> AlleleFractionLikelihoods.logLikelihood(
                globalParameters.copyWithNewBiasVariance(biasVariance), minorFractions, data);
        return OptimizationUtils.argmax(objective, 0.0, MAX_REASONABLE_BIAS_VARIANCE, globalParameters.getBiasVariance());
    }

    /**
     * Re-estimates the minor-allele fraction of a single segment using that segment's log likelihood
     * under the current global parameters.
     *
     * @param segment index of the segment
     * @return the value returned by {@code OptimizationUtils.argmax} for the range
     *         {@code [0, MAX_MINOR_ALLELE_FRACTION]}, called with the segment's current estimate as its
     *         last argument (presumably the initial guess; confirm against {@code OptimizationUtils})
     */
    private double estimateMinorFraction(final int segment) {
        final Function<Double, Double> objective = minorFraction -> AlleleFractionLikelihoods.segmentLogLikelihood(
                globalParameters, minorFraction, data.getIndexedAllelicCountsInSegment(segment));
        return OptimizationUtils.argmax(objective, 0.0, MAX_MINOR_ALLELE_FRACTION, minorFractions.get(segment));
    }

    /**
     * Re-estimates the minor-allele fraction of every segment independently.
     *
     * @return a new collection holding one estimate per segment, in segment-index order
     */
    private AlleleFractionState.MinorFractions estimateMinorFractions() {
        return new AlleleFractionState.MinorFractions(
                IntStream.range(0, data.getNumSegments())
                        .boxed()
                        .map(this::estimateMinorFraction)
                        .collect(Collectors.toList()));
    }
}

