/*
 * Decompiled with CFR 0.152.
 */
package net.maizegenetics.analysis.imputation;

import java.io.BufferedWriter;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;
import net.maizegenetics.analysis.imputation.EmissionProbability;
import net.maizegenetics.analysis.imputation.TransitionProbability;
import org.apache.log4j.Logger;

/**
 * Forward-backward computation over a sequence of observations, using the state transitions
 * supplied by a {@link TransitionProbability} and the observation likelihoods supplied by an
 * {@link EmissionProbability}.
 *
 * <p>Typical use: configure the instance with the fluent setters ({@link #transition},
 * {@link #emission}, {@link #observations}, {@link #initialStateProbability} and, when writing
 * output, {@link #positions}). Then call {@link #calculateAlpha()} and {@link #calculateBeta()}
 * before asking for {@link #gamma()} or calling {@link #writeGamma}.
 *
 * <p>Long sequences can drive the forward and backward values toward underflow. To limit that,
 * any alpha or beta vector (except the first alpha and the last beta) is multiplied by
 * {@code 1e25} when its smallest entry is below {@code 1e-50} and its largest entry is below
 * {@code 1e-25}. The values returned by {@link #alpha()} and {@link #beta()} are therefore scaled
 * by a positive per-site factor. That factor cancels in the per-site normalization done by
 * {@link #gamma()}.
 */
public class BackwardForwardAlgorithm {
    private static final Logger myLogger = Logger.getLogger(BackwardForwardAlgorithm.class);

    /** A vector is rescaled only if its smallest entry is below this value... */
    private static final double RESCALE_MIN_THRESHOLD = 1.0E-50;
    /** ...and its largest entry is also below this value. */
    private static final double RESCALE_MAX_THRESHOLD = 1.0E-25;
    /** Factor applied to every entry of a vector that meets both thresholds. */
    private static final double RESCALE_FACTOR = 1.0E25;

    private int[] myObservations;
    private int[] myPositions;
    private TransitionProbability myTransitions;
    private EmissionProbability myEmissions;
    private double[] initialStateProbability;
    /** Forward values, one vector per observation; {@code null} until {@link #calculateAlpha()}. */
    private List<double[]> alpha;
    /** Backward values, one vector per observation; {@code null} until {@link #calculateBeta()}. */
    private List<double[]> beta;

    /**
     * Runs the forward recursion and stores one alpha vector per observation.
     *
     * <p>The first vector is {@code initialStateProbability[s] * P(obs[0] | s)}. Each later vector
     * is {@code sum_i alpha[t-1][i] * trans(i, j) * P(obs[t] | j)}, possibly rescaled (see class
     * comment). {@code setNode(t)} is called on the transition model before site {@code t} is
     * processed. An empty observation array yields an empty alpha list.
     *
     * @return this instance, for chaining
     */
    public BackwardForwardAlgorithm calculateAlpha() {
        int nStates = this.myTransitions.getNumberOfStates();
        int nObs = this.myObservations.length;
        this.alpha = new ArrayList<>(nObs);
        if (nObs == 0) {
            // Nothing to compute; myObservations[0] would be out of bounds.
            return this;
        }
        double[] aPrior = new double[nStates];
        for (int s = 0; s < nStates; ++s) {
            aPrior[s] = this.initialStateProbability[s] * this.myEmissions.getProbObsGivenState(s, this.myObservations[0], 0);
        }
        this.alpha.add(aPrior);
        for (int t = 1; t < nObs; ++t) {
            double[] aT = new double[nStates];
            this.myTransitions.setNode(t);
            for (int j = 0; j < nStates; ++j) {
                double sumTrans = 0.0;
                for (int i = 0; i < nStates; ++i) {
                    sumTrans += aPrior[i] * this.myTransitions.getTransitionProbability(i, j);
                }
                aT[j] = sumTrans * this.myEmissions.getProbObsGivenState(j, this.myObservations[t], t);
            }
            aT = multiplyArrayByConstantIfSmall(aT);
            this.alpha.add(aT);
            aPrior = aT;
        }
        return this;
    }

    /**
     * Multiplies every entry by {@link #RESCALE_FACTOR} when the smallest entry is below
     * {@link #RESCALE_MIN_THRESHOLD} and the largest is below {@link #RESCALE_MAX_THRESHOLD}.
     *
     * @param dblArray the vector to check; not modified
     * @return a new scaled array if rescaling applies, otherwise {@code dblArray} itself
     *     (including when it is empty)
     */
    private static double[] multiplyArrayByConstantIfSmall(double[] dblArray) {
        if (dblArray.length == 0) {
            return dblArray;
        }
        // Single pass for min and max. Math.min/Math.max propagate NaN like DoubleStream.min/max.
        double minval = dblArray[0];
        double maxval = dblArray[0];
        for (double d : dblArray) {
            minval = Math.min(minval, d);
            maxval = Math.max(maxval, d);
        }
        if (minval < RESCALE_MIN_THRESHOLD && maxval < RESCALE_MAX_THRESHOLD) {
            double[] scaled = new double[dblArray.length];
            for (int i = 0; i < dblArray.length; ++i) {
                scaled[i] = dblArray[i] * RESCALE_FACTOR;
            }
            return scaled;
        }
        return dblArray;
    }

    /**
     * Runs the backward recursion and stores one beta vector per observation, in site order.
     *
     * <p>The last vector is all ones. Each earlier vector is
     * {@code beta[t][i] = sum_j trans(i, j) * P(obs[t+1] | j) * beta[t+1][j]}, possibly rescaled
     * (see class comment). {@code setNode(t + 1)} is called on the transition model before site
     * {@code t} is processed. An empty observation array yields an empty beta list, matching
     * {@link #calculateAlpha()}.
     *
     * <p>NOTE(review): the emission term is evaluated once per state and site rather than once per
     * (i, j) pair. This assumes {@code getProbObsGivenState} has no side effects that depend on
     * the call count.
     *
     * @return this instance, for chaining
     */
    public BackwardForwardAlgorithm calculateBeta() {
        int nStates = this.myTransitions.getNumberOfStates();
        int nObs = this.myObservations.length;
        if (nObs == 0) {
            this.beta = new ArrayList<>();
            return this;
        }
        double[][] betaBySite = new double[nObs][];
        double[] bNext = new double[nStates];
        Arrays.fill(bNext, 1.0);
        betaBySite[nObs - 1] = bNext;
        double[] emissionNext = new double[nStates];
        for (int t = nObs - 2; t >= 0; --t) {
            this.myTransitions.setNode(t + 1);
            // The emission term depends only on j, so compute it once per site instead of per (i, j).
            for (int j = 0; j < nStates; ++j) {
                emissionNext[j] = this.myEmissions.getProbObsGivenState(j, this.myObservations[t + 1], t + 1);
            }
            double[] bT = new double[nStates];
            for (int i = 0; i < nStates; ++i) {
                double sumStates = 0.0;
                for (int j = 0; j < nStates; ++j) {
                    // Same left-to-right multiplication order as before, so results are bit-identical.
                    sumStates += this.myTransitions.getTransitionProbability(i, j) * emissionNext[j] * bNext[j];
                }
                bT[i] = sumStates;
            }
            bT = multiplyArrayByConstantIfSmall(bT);
            betaBySite[t] = bT;
            bNext = bT;
        }
        this.beta = new ArrayList<>(Arrays.asList(betaBySite));
        return this;
    }

    /**
     * Debugging aid: prints {@code pos} followed by each value formatted as {@code %1.4f}.
     *
     * @param pos a label printed before the values
     * @param values the values to print
     */
    private void printSite(int pos, double[] values) {
        System.out.print(pos + ": ");
        Arrays.stream(values).mapToObj(d -> String.format("%1.4f ", d)).forEach(System.out::print);
        System.out.println();
    }

    /**
     * Combines alpha and beta into per-site normalized state weights.
     *
     * <p>For each site, {@code gamma[i] = alpha[i] * beta[i] / sum_k(alpha[k] * beta[k])}. If that
     * sum is zero, every entry at that site is NaN.
     *
     * @return a new list with one gamma vector per site
     * @throws IllegalStateException if alpha or beta has not been computed, or their lengths differ
     */
    public List<double[]> gamma() {
        checkAlphaBetaComputed();
        List<double[]> gamma = new ArrayList<>(this.alpha.size());
        Iterator<double[]> itAlpha = this.alpha.iterator();
        Iterator<double[]> itBeta = this.beta.iterator();
        while (itAlpha.hasNext()) {
            gamma.add(normalizedProduct(itAlpha.next(), itBeta.next()));
        }
        return gamma;
    }

    /** Returns {@code a[i] * b[i]} divided by the sum of those products, as a new array. */
    private static double[] normalizedProduct(double[] alphaArray, double[] betaArray) {
        int n = alphaArray.length;
        double[] gammaArray = new double[n];
        for (int i = 0; i < n; ++i) {
            gammaArray[i] = alphaArray[i] * betaArray[i];
        }
        double divisor = Arrays.stream(gammaArray).sum();
        for (int i = 0; i < n; ++i) {
            gammaArray[i] = gammaArray[i] / divisor;
        }
        return gammaArray;
    }

    /** Throws IllegalStateException unless alpha and beta both exist and have the same length. */
    private void checkAlphaBetaComputed() {
        if (this.alpha == null || this.beta == null) {
            throw new IllegalStateException("calculateAlpha() and calculateBeta() must be called before computing gamma");
        }
        if (this.alpha.size() != this.beta.size()) {
            throw new IllegalStateException("alpha has " + this.alpha.size() + " sites but beta has " + this.beta.size());
        }
    }

    /**
     * Writes one line per site: the site's position, a tab, then the tab-separated gamma values
     * formatted with {@code formatString}.
     *
     * @param outputFile path of the file to create or overwrite
     * @param formatString a {@link String#format} pattern applied to each gamma value
     * @throws IllegalStateException if alpha/beta are missing or inconsistent, or if fewer
     *     positions than sites were supplied (checked before the file is opened)
     * @throws UncheckedIOException if the file cannot be written
     */
    public void writeGamma(String outputFile, String formatString) {
        List<double[]> gammaBySite = gamma();
        int nSites = gammaBySite.size();
        if (nSites > 0 && (this.myPositions == null || this.myPositions.length < nSites)) {
            int nPositions = this.myPositions == null ? 0 : this.myPositions.length;
            throw new IllegalStateException("Need " + nSites + " positions to write gamma but have " + nPositions);
        }
        try (BufferedWriter bw = Files.newBufferedWriter(Paths.get(outputFile))) {
            for (int site = 0; site < nSites; ++site) {
                bw.write(this.myPositions[site] + "\t");
                bw.write(Arrays.stream(gammaBySite.get(site))
                        .mapToObj(dbl -> String.format(formatString, dbl))
                        .collect(Collectors.joining("\t", "", "\n")));
            }
        } catch (IOException ioe) {
            throw new UncheckedIOException("Unable to write " + outputFile, ioe);
        }
    }

    /**
     * Writes gamma using the {@code %1.2e} format for each value.
     *
     * @param outputFile path of the file to create or overwrite
     */
    public void writeGamma(String outputFile) {
        this.writeGamma(outputFile, "%1.2e");
    }

    /**
     * Sets the emission model.
     *
     * @param emission the emission model
     * @return this instance, for chaining
     */
    public BackwardForwardAlgorithm emission(EmissionProbability emission) {
        this.myEmissions = emission;
        return this;
    }

    /**
     * Sets the transition model.
     *
     * @param transition the transition model
     * @return this instance, for chaining
     */
    public BackwardForwardAlgorithm transition(TransitionProbability transition) {
        this.myTransitions = transition;
        return this;
    }

    /**
     * Sets the observed value at each site.
     *
     * @param observations one observation per site; stored by reference
     * @return this instance, for chaining
     */
    public BackwardForwardAlgorithm observations(int[] observations) {
        this.myObservations = observations;
        return this;
    }

    /**
     * Sets the position label written for each site by {@link #writeGamma}.
     *
     * @param positions one position per site; stored by reference
     * @return this instance, for chaining
     */
    public BackwardForwardAlgorithm positions(int[] positions) {
        this.myPositions = positions;
        return this;
    }

    /**
     * Sets the starting weight of each state, used for the first alpha vector.
     *
     * @param probs one value per state; stored by reference
     * @return this instance, for chaining
     */
    public BackwardForwardAlgorithm initialStateProbability(double[] probs) {
        this.initialStateProbability = probs;
        return this;
    }

    /**
     * Returns the forward values (scaled per site; see class comment).
     *
     * @return the internal alpha list, or {@code null} before {@link #calculateAlpha()}
     */
    public List<double[]> alpha() {
        return this.alpha;
    }

    /**
     * Returns the backward values (scaled per site; see class comment).
     *
     * @return the internal beta list, or {@code null} before {@link #calculateBeta()}
     */
    public List<double[]> beta() {
        return this.beta;
    }
}

