/*
 * Decompiled with CFR 0.152.
 */
package org.monarchinitiative.phenol.analysis.mgsa;

import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import org.monarchinitiative.phenol.analysis.DirectAndIndirectTermAnnotations;
import org.monarchinitiative.phenol.analysis.GoAssociationContainer;
import org.monarchinitiative.phenol.analysis.StudySet;
import org.monarchinitiative.phenol.analysis.mgsa.DoubleParam;
import org.monarchinitiative.phenol.analysis.mgsa.FixedAlphaBetaScore;
import org.monarchinitiative.phenol.analysis.mgsa.IntegerParam;
import org.monarchinitiative.phenol.analysis.mgsa.MgsaGOTermResult;
import org.monarchinitiative.phenol.analysis.mgsa.MgsaGOTermsResultContainer;
import org.monarchinitiative.phenol.analysis.mgsa.MgsaParam;
import org.monarchinitiative.phenol.analysis.mgsa.TermToItemMatrix;
import org.monarchinitiative.phenol.ontology.data.Ontology;
import org.monarchinitiative.phenol.ontology.data.TermId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs an MGSA analysis: an MCMC walk over sets of active terms that yields one
 * marginal probability per term for a given study set.
 *
 * <p>A calculation is created for an ontology and an association container; the
 * population set is derived from all annotated genes of that container. Each call to
 * {@link #calculateStudySet(StudySet)} seeds a fresh {@link Random} with the configured
 * seed, so repeated calls on the same instance start from the same random sequence.
 */
public class MgsaCalculation {
    private static final Logger logger = LoggerFactory.getLogger(MgsaCalculation.class);

    /** Smallest value accepted by {@link #setAlpha(double)} and {@link #setBeta(double)}. */
    private static final double MIN_PROBABILITY = 1.0E-6;
    /** Largest value accepted by {@link #setAlpha(double)} and {@link #setBeta(double)}. */
    private static final double MAX_PROBABILITY = 0.999999;
    private static final int DEFAULT_MCMCSTEPS = 250000;
    /** Steps {@code 0..BURNIN} (inclusive) of the MCMC loop are not recorded. */
    private static final int BURNIN = 20000;

    private final long seed;
    private boolean usePrior = true;
    private boolean integrateParams = false;
    private final DoubleParam alpha = new DoubleParam(MgsaParam.Type.MCMC);
    private final DoubleParam beta = new DoubleParam(MgsaParam.Type.MCMC);
    private final IntegerParam expectedNumberOfTerms = new IntegerParam(MgsaParam.Type.MCMC);
    private boolean takePopulationAsReference = false;
    private boolean randomStart = false;
    private final int mcmcSteps;
    /** Minimum wall-clock interval, in milliseconds, between two progress log lines. */
    private int updateReportTime = 1000;
    private final GoAssociationContainer goAssociations;
    // NOTE(review): built in the constructor but never read; calculateByMCMC creates its
    // own matrix on every call. Kept so that construction-time behaviour is unchanged.
    private final TermToItemMatrix termToItemMatrix;
    private final Ontology ontology;
    private final StudySet populationSet;

    /**
     * Creates a calculation with an explicit random seed.
     *
     * @param ontology       ontology handed to the result container
     * @param goAssociations associations used to build the term/item matrix and the
     *                       population set; must not be {@code null}
     * @param mcmcSteps      number of MCMC steps per study set
     * @param seed           seed of the random generator used by every calculation
     * @throws NullPointerException if {@code goAssociations} is {@code null}
     */
    public MgsaCalculation(Ontology ontology, GoAssociationContainer goAssociations, int mcmcSteps, long seed) {
        this.goAssociations = Objects.requireNonNull(goAssociations, "goAssociations");
        this.ontology = ontology;
        this.mcmcSteps = mcmcSteps;
        this.seed = seed;
        logger.info("MgsaCalculation, mcmc steps {}", mcmcSteps);
        this.termToItemMatrix = new TermToItemMatrix(goAssociations);
        Set<TermId> allAnnotatedGenes = goAssociations.getAllAnnotatedGenes();
        Map<TermId, DirectAndIndirectTermAnnotations> assocs = goAssociations.getAssociationMap(allAnnotatedGenes);
        this.populationSet = StudySet.populationSet(assocs);
    }

    /**
     * Creates a calculation whose seed is drawn from a new {@link Random}.
     *
     * @param ontology       ontology handed to the result container
     * @param goAssociations associations used to build the term/item matrix and the
     *                       population set; must not be {@code null}
     * @param mcmcSteps      number of MCMC steps per study set
     */
    public MgsaCalculation(Ontology ontology, GoAssociationContainer goAssociations, int mcmcSteps) {
        this(ontology, goAssociations, mcmcSteps, new Random().nextLong());
    }

    /**
     * @return the annotated item count of the population set
     */
    public int getPopulationSetCount() {
        return this.populationSet.getAnnotatedItemCount();
    }

    /**
     * Sets the alpha value, clamped to the range [1e-6, 0.999999].
     *
     * @param alpha requested alpha
     */
    public void setAlpha(double alpha) {
        this.alpha.setValue(clampProbability(alpha));
    }

    /**
     * Sets the beta value, clamped to the range [1e-6, 0.999999].
     *
     * @param beta requested beta
     */
    public void setBeta(double beta) {
        this.beta.setValue(clampProbability(beta));
    }

    /**
     * Sets the parameter type of alpha.
     *
     * @param alpha the new type
     */
    public void setAlpha(MgsaParam.Type alpha) {
        this.alpha.setType(alpha);
    }

    /**
     * Sets the minimum and maximum of the alpha parameter.
     *
     * @param min lower bound
     * @param max upper bound
     */
    public void setAlphaBounds(double min, double max) {
        this.alpha.setMin(min);
        this.alpha.setMax(max);
    }

    /**
     * Sets the parameter type of beta.
     *
     * @param beta the new type
     */
    public void setBeta(MgsaParam.Type beta) {
        this.beta.setType(beta);
    }

    /**
     * Sets the minimum and maximum of the beta parameter.
     *
     * @param min lower bound
     * @param max upper bound
     */
    public void setBetaBounds(double min, double max) {
        this.beta.setMin(min);
        this.beta.setMax(max);
    }

    /**
     * Sets the value of the expected-number-of-terms parameter.
     *
     * @param expectedNumber the new value
     */
    public void setExpectedNumber(int expectedNumber) {
        this.expectedNumberOfTerms.setValue(expectedNumber);
    }

    /**
     * Sets the parameter type of the expected-number-of-terms parameter.
     *
     * @param type the new type
     */
    public void setExpectedNumber(MgsaParam.Type type) {
        this.expectedNumberOfTerms.setType(type);
    }

    /**
     * @param integrateParams flag forwarded to the scorer of each calculation
     */
    public void setIntegrateParams(boolean integrateParams) {
        this.integrateParams = integrateParams;
    }

    /**
     * Stores the flag; nothing in this class reads it.
     *
     * @param takePopulationAsReference the new flag value
     */
    public void setTakePopulationAsReference(boolean takePopulationAsReference) {
        this.takePopulationAsReference = takePopulationAsReference;
    }

    /**
     * @param randomStart if {@code true}, each calculation starts from a randomly
     *                    chosen set of active terms instead of the empty set
     */
    public void useRandomStart(boolean randomStart) {
        this.randomStart = randomStart;
    }

    /**
     * @param updateReportTime minimum number of milliseconds between two progress
     *                         log lines of the MCMC loop
     */
    public void setUpdateReportTime(int updateReportTime) {
        this.updateReportTime = updateReportTime;
    }

    /**
     * Runs the MCMC calculation for one study set.
     *
     * @param studySet the study set to analyse
     * @return a result container holding one entry per term; the container is returned
     *         without entries when the study set has no annotated items
     */
    public MgsaGOTermsResultContainer calculateStudySet(StudySet studySet) {
        MgsaGOTermsResultContainer result = new MgsaGOTermsResultContainer(this.ontology, this.goAssociations, studySet, this.getPopulationSetCount());
        if (studySet.getAnnotatedItemCount() == 0) {
            logger.warn("Study set empty! Returning specious result");
            return result;
        }
        logger.info("Starting calculation: expectedNumberOfTerms={} alpha={} beta={} numberOfPop={} numberOfStudy={}",
                this.expectedNumberOfTerms, this.alpha, this.beta,
                this.getPopulationSetCount(), studySet.getAnnotatedItemCount());
        // NOTE(review): this overwrites whatever was configured through setExpectedNumber(...)
        // on every call. Kept as in the original; confirm whether it is intentional.
        this.expectedNumberOfTerms.setValue(10);
        long start = System.currentTimeMillis();
        this.calculateByMCMC(result, studySet);
        long end = System.currentTimeMillis();
        logger.info("{}ms", end - start);
        return result;
    }

    /**
     * @param usePrior flag forwarded to the scorer of each calculation
     */
    public void setUsePrior(boolean usePrior) {
        this.usePrior = usePrior;
    }

    /**
     * Builds the term/item matrix, runs the sampler and adds one
     * {@link MgsaGOTermResult} per term to {@code result}.
     */
    private void calculateByMCMC(MgsaGOTermsResultContainer result, StudySet studySet) {
        TermToItemMatrix calcUtils = new TermToItemMatrix(this.goAssociations);
        int[][] termLinks = calcUtils.getTermLinks();
        boolean[] observedItems = calcUtils.getBooleanArrayobservedItems(studySet.getGeneSet());
        double[] marginalProbabilities = this.calculate(termLinks, observedItems);
        int populationCount = this.populationSet.getAnnotatedItemCount();
        for (int i = 0; i < marginalProbabilities.length; ++i) {
            TermId tid = calcUtils.getGoTermAtIndex(i);
            MgsaGOTermResult prop = new MgsaGOTermResult(tid, calcUtils.getAnnotatedGeneCount(tid), populationCount, marginalProbabilities[i]);
            result.addGOTermProperties(prop);
        }
    }

    /**
     * @return the constant {@code "MGSA"}
     */
    public String getName() {
        return "MGSA";
    }

    /**
     * @return always {@code false}
     */
    public boolean supportsTestCorrection() {
        return false;
    }

    /**
     * Runs the Metropolis-Hastings style sampler and returns, per term, the fraction of
     * recorded states in which the term was active.
     *
     * @param term2Items    one row per term (row contents are consumed by the scorer)
     * @param observedItems observation flags handed to the scorer
     * @return array of length {@code term2Items.length} with the per-term ratio
     *         {@code termActivationCounts / numRecords}
     */
    private double[] calculate(int[][] term2Items, boolean[] observedItems) {
        final int numTerms = term2Items.length;
        final double[] res = new double[numTerms];
        final Random rnd = new Random(this.seed);
        logger.info("Using random seed of: {}", this.seed);

        // NOTE(review): these are always NaN, so values configured through setAlpha(double),
        // setBeta(double) and setExpectedNumber(int) are not forwarded to the scorer (only
        // the max bounds below are). Presumably NaN asks the scorer to sample the
        // parameter - confirm against FixedAlphaBetaScore.
        final double alpha = Double.NaN;
        final double beta = Double.NaN;
        final double expectedNumberOfTerms = Double.NaN;

        final FixedAlphaBetaScore scorer = new FixedAlphaBetaScore(rnd, term2Items, observedItems);
        scorer.setIntegrateParams(this.integrateParams);
        logger.info("MCMC only: {}  {}  {}", alpha, beta, expectedNumberOfTerms);
        scorer.setAlpha(alpha);
        if (this.alpha.hasMax()) {
            scorer.setMaxAlpha(this.alpha.getMax());
        }
        scorer.setBeta(beta);
        if (this.beta.hasMax()) {
            scorer.setMaxBeta(this.beta.getMax());
        }
        scorer.setExpectedNumberOfTerms(expectedNumberOfTerms);
        scorer.setUsePrior(this.usePrior);
        logger.info("Score of empty set: {}", scorer.getScore());

        if (this.randomStart) {
            applyRandomStart(scorer, rnd, numTerms);
        }

        double score = scorer.getScore();
        logger.info("Score of initial set: {}", score);

        final int maxSteps = this.mcmcSteps;
        if (maxSteps <= BURNIN + 1) {
            // record() is only reached for steps t with BURNIN < t < maxSteps.
            logger.warn("mcmcSteps={} leaves no step after the burn-in of {} steps; no state will be recorded",
                    maxSteps, BURNIN);
        }

        int numAccepts = 0;
        int numRejects = 0;
        double maxScore = score;
        int[] maxScoredTerms = scorer.getActiveTerms();
        double maxScoredAlpha = Double.NaN;
        double maxScoredBeta = Double.NaN;
        double maxScoredP = Double.NaN;
        int maxWhenSeen = -1;
        long lastReport = System.currentTimeMillis();

        for (int t = 0; t < maxSteps; ++t) {
            // Remember the best state seen so far (checked before proposing the next one).
            if (score > maxScore) {
                maxScore = score;
                maxScoredTerms = scorer.getActiveTerms();
                maxScoredAlpha = scorer.getAlpha();
                maxScoredBeta = scorer.getBeta();
                maxScoredP = scorer.getP();
                maxWhenSeen = t;
            }

            long now = System.currentTimeMillis();
            if (now - lastReport > (long) this.updateReportTime) {
                // long arithmetic: t * 100 would overflow int for very large step counts
                logger.info("{}% (score={} maxScore={} #terms={} accept/reject={} accept/steps={} exp={} usePrior={})",
                        (long) t * 100 / maxSteps, score, maxScore, scorer.getActiveTerms().length,
                        (double) numAccepts / (double) numRejects, (double) numAccepts / (double) t,
                        expectedNumberOfTerms, this.usePrior);
                lastReport = now;
            }

            long oldPossibilities = scorer.getNeighborhoodSize();
            long r = rnd.nextLong();
            scorer.proposeNewState(r);
            double newScore = scorer.getScore();
            long newPossibilities = scorer.getNeighborhoodSize();
            // Acceptance ratio: score difference is exponentiated, corrected by the
            // ratio of neighbourhood sizes before and after the proposal.
            double acceptProb = Math.exp(newScore - score) * (double) oldPossibilities / (double) newPossibilities;
            double u = rnd.nextDouble();
            if (u >= acceptProb) {
                scorer.undoProposal();
                ++numRejects;
            } else {
                score = newScore;
                ++numAccepts;
            }

            if (t > BURNIN) {
                scorer.record();
            }
        }

        for (int t = 0; t < numTerms; ++t) {
            res[t] = (double) scorer.termActivationCounts[t] / (double) scorer.numRecords;
        }

        logParameterDistributions(scorer, alpha, beta, expectedNumberOfTerms);
        logger.info("numAccepts={}  numRejects = {}", numAccepts, numRejects);
        if (logger.isInfoEnabled()) {
            logger.info("Term combination that reaches score of {} when alpha={}, beta={}, p={} at step {}",
                    maxScore, maxScoredAlpha, maxScoredBeta, maxScoredP, maxWhenSeen);
            StringBuilder b = new StringBuilder("Indices: ");
            for (int termIndex : maxScoredTerms) {
                b.append(termIndex).append(", ");
            }
            logger.info(b.toString());
        }
        return res;
    }

    /** Restricts {@code value} to [MIN_PROBABILITY, MAX_PROBABILITY]; NaN is passed through. */
    private static double clampProbability(double value) {
        return Math.min(MAX_PROBABILITY, Math.max(MIN_PROBABILITY, value));
    }

    /** Switches each term on with a probability derived from a randomly picked expected term count. */
    private static void applyRandomStart(FixedAlphaBetaScore scorer, Random rnd, int numTerms) {
        int numberOfTerms = scorer.EXPECTED_NUMBER_OF_TERMS[rnd.nextInt(scorer.EXPECTED_NUMBER_OF_TERMS.length)];
        double pForStart = (double) numberOfTerms / (double) numTerms;
        for (int j = 0; j < numTerms; ++j) {
            if (rnd.nextDouble() < pForStart) {
                scorer.switchState(j);
            }
        }
        logger.info("Starting with {} terms (p={})", scorer.getActiveTerms().length, pForStart);
    }

    /** Logs the per-candidate totals divided by the record count for every parameter passed as NaN. */
    private static void logParameterDistributions(FixedAlphaBetaScore scorer, double alpha, double beta, double expectedNumberOfTerms) {
        if (Double.isNaN(alpha)) {
            for (int j = 0; j < scorer.totalAlpha.length; ++j) {
                logger.info("alpha({})={}", scorer.ALPHA[j], (double) scorer.totalAlpha[j] / (double) scorer.numRecords);
            }
        }
        if (Double.isNaN(beta)) {
            for (int j = 0; j < scorer.totalBeta.length; ++j) {
                logger.info("beta({})={}", scorer.BETA[j], (double) scorer.totalBeta[j] / (double) scorer.numRecords);
            }
        }
        if (Double.isNaN(expectedNumberOfTerms)) {
            for (int j = 0; j < scorer.totalExp.length; ++j) {
                logger.info("exp({})={}", scorer.EXPECTED_NUMBER_OF_TERMS[j], (double) scorer.totalExp[j] / (double) scorer.numRecords);
            }
        }
    }
}

