/*
 * Decompiled with CFR 0.152.
 */
package net.maizegenetics.analysis.imputation;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import net.maizegenetics.analysis.data.FileLoadPlugin;
import net.maizegenetics.analysis.imputation.ImputationUtils;
import net.maizegenetics.dna.snp.GenotypeTable;
import net.maizegenetics.dna.snp.GenotypeTableUtils;
import net.maizegenetics.dna.snp.NucleotideAlignmentConstants;
import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.log4j.Logger;

/**
 * Re-estimates the two haplotypes of each parent from the genotypes, read depths and
 * previously assigned haplotype states of its progeny.
 *
 * <p>Each entry of {@code plotList} is a row of strings in which column 0 is the progeny taxon
 * name, columns 1 and 2 are the names of its two parents and column 3 is the cross type (only
 * the value {@code "outcross"} is tested for).
 *
 * <p>A progeny state is a value 0..3 naming which chromosome was inherited from each parent:
 * {@code state = 2 * (chromosome of parent in column 1) + (chromosome of parent in column 2)}.
 * The value {@code 4} marks a missing state.
 */
public class RephaseParents {
    private static final Logger myLogger = Logger.getLogger(RephaseParents.class);
    /** Diploid genotype value treated as missing. */
    private static final byte NN = -1;
    /** Allele value treated as unknown. */
    private static final byte N = 15;
    private static final byte AA = NucleotideAlignmentConstants.getNucleotideDiploidByte("AA");
    private static final byte CC = NucleotideAlignmentConstants.getNucleotideDiploidByte("CC");
    private static final byte GG = NucleotideAlignmentConstants.getNucleotideDiploidByte("GG");
    private static final byte TT = NucleotideAlignmentConstants.getNucleotideDiploidByte("TT");
    /** Progeny state value meaning "no state assigned". */
    private static final byte missingState = 4;
    /** Number of allele slots tallied per site (length of the per-allele depth/count arrays). */
    private static final int ALLELE_SLOTS = 6;
    /** Chromosome of the parent in plot column 1, indexed by progeny state. */
    private static final int[] FIRST_PARENT_CHR = {0, 0, 1, 1};
    /** Chromosome of the parent in plot column 2, indexed by progeny state. */
    private static final int[] SECOND_PARENT_CHR = {0, 1, 0, 1};
    /** Maps a state to the same state seen from the other parent's side (swaps 1 and 2). */
    private static final int[] SWITCH_STATE = {0, 2, 1, 3};

    GenotypeTable origGeno;
    Map<String, byte[]> progenyStates;
    List<String[]> plotList;
    Map<String, byte[][]> rephasedParents;
    Map<String, byte[][]> startingParents = null;
    Map<String, double[][]> parentHaplotypeProbabilities = null;
    int minFamilySize = 10;
    int minDepth = 7;
    String outputFilename;

    /** Creates an empty instance; the package-private fields must be populated before use. */
    public RephaseParents() {
    }

    /**
     * Creates an instance from in-memory inputs.
     *
     * @param originalGenotypes genotypes (with depth) of the progeny
     * @param phasedProgeny     progeny taxon name to per-site state array
     * @param plotList          plot rows (progeny, parent 1, parent 2, cross type)
     * @param parentHapmap      parent name to its two starting haplotypes ({@code [2][nsites]})
     */
    public RephaseParents(GenotypeTable originalGenotypes, Map<String, byte[]> phasedProgeny, List<String[]> plotList, Map<String, byte[][]> parentHapmap) {
        this.origGeno = originalGenotypes;
        this.progenyStates = phasedProgeny;
        this.plotList = plotList;
        this.startingParents = parentHapmap;
    }

    /**
     * Creates an instance, loading the progeny states, plot list and starting parent
     * haplotypes from files.
     *
     * @param originalGenotypes genotypes (with depth) of the progeny
     * @param phasedProgeny     file loaded through {@link FileLoadPlugin} and converted with
     *                          {@link #progenyStates(GenotypeTable)}
     * @param parentage         tab-delimited file with one header line followed by plot rows
     * @param parentHaps        file read by {@code ImputationUtils.restorePhasedHaplotypes}
     * @throws RuntimeException if the parentage file cannot be opened
     */
    public RephaseParents(GenotypeTable originalGenotypes, String phasedProgeny, String parentage, String parentHaps) {
        this.origGeno = originalGenotypes;
        GenotypeTable progenyStatesTable = (GenotypeTable) FileLoadPlugin.runPlugin(phasedProgeny);
        this.progenyStates = RephaseParents.progenyStates(progenyStatesTable);
        myLogger.info(String.format("progeny states loaded: %s", phasedProgeny));
        // Files.lines holds the file open until the stream is closed, so it must be scoped.
        try (Stream<String> lines = Files.lines(Paths.get(parentage))) {
            this.plotList = lines.skip(1L).map(in -> in.split("\t")).collect(Collectors.toList());
            myLogger.info(String.format("plotList has %d entries", this.plotList.size()));
        } catch (IOException e) {
            throw new RuntimeException("Unable to read " + parentage, e);
        }
        this.startingParents = ImputationUtils.restorePhasedHaplotypes(Paths.get(parentHaps));
        myLogger.info(String.format("Starting parent haplotypes loaded: %s", parentHaps));
    }

    /** Equivalent to {@link #rephaseUsingAlleleDepth(String)} with no file saved. */
    Map<String, double[][]> rephaseUsingAlleleDepth() {
        return this.rephaseUsingAlleleDepth(null);
    }

    /**
     * Computes per-site haplotype probabilities for every parent that has starting haplotypes,
     * using the allele depths of its progeny.
     *
     * <p>A plot contributes to a parent only when the <em>other</em> parent of that plot also
     * has starting haplotypes. The result is stored in {@code parentHaplotypeProbabilities}.
     *
     * @param saveFilename if non-null and longer than one character, the result map is
     *                     Java-serialized to this file; a failure to write is logged and does
     *                     not abort the call
     * @return parent name to the {@code [2][nsites]} array from {@link #rephasePreviouslyPhased}
     */
    Map<String, double[][]> rephaseUsingAlleleDepth(String saveFilename) {
        Map<String, List<String[]>> parentPlotMap = new HashMap<>();
        for (String[] plot : this.plotList) {
            String firstParent = plot[1];
            String secondParent = plot[2];
            if (this.startingParents.get(secondParent) != null) {
                addPlot(parentPlotMap, firstParent, plot);
            }
            // A self (both parents identical) is only registered once.
            if (!secondParent.equals(firstParent) && this.startingParents.get(firstParent) != null) {
                addPlot(parentPlotMap, secondParent, plot);
            }
        }

        Map<String, double[][]> nextHaplotypeProbs = new HashMap<>();
        for (Map.Entry<String, List<String[]>> entry : parentPlotMap.entrySet()) {
            String parent = entry.getKey();
            if (this.startingParents.get(parent) == null) {
                continue;
            }
            nextHaplotypeProbs.put(parent, this.rephasePreviouslyPhased(parent, entry.getValue()));
        }
        this.parentHaplotypeProbabilities = nextHaplotypeProbs;

        if (saveFilename != null && saveFilename.length() > 1) {
            try (FileOutputStream fos = new FileOutputStream(new File(saveFilename));
                    ObjectOutputStream oos = new ObjectOutputStream(fos)) {
                oos.writeObject(this.parentHaplotypeProbabilities);
            } catch (IOException e) {
                // Saving is best-effort: report the failure but still return the result.
                myLogger.error(String.format("Unable to save haplotype probabilities to %s", saveFilename), e);
            }
        }
        return this.parentHaplotypeProbabilities;
    }

    /**
     * For one parent, estimates at every site the probability that each of its two haplotypes
     * carries the site's major allele, from allele depths pooled over progeny by state.
     *
     * @param parent   name of the parent being rephased
     * @param plotList plots in which {@code parent} appears in column 1 or column 2
     * @return a {@code [2][nsites]} array; entry {@code [h][s]} is the estimate for haplotype
     *         {@code h} at site {@code s}, {@code 1.0} where the site has no minor allele, and
     *         {@code NaN} where the site was not evaluated for lack of depth
     */
    double[][] rephasePreviouslyPhased(String parent, List<String[]> plotList) {
        final double err = 0.01;
        int nsites = this.origGeno.numberOfSites();
        double[][] haplotypeProbability = new double[2][nsites];
        for (double[] row : haplotypeProbability) {
            Arrays.fill(row, Double.NaN);
        }

        // Resolve everything that depends only on the plot once, not once per site.
        int nplots = plotList.size();
        int[] taxonIndex = new int[nplots];
        byte[][] statesByPlot = new byte[nplots][];
        boolean[] parentIsFirst = new boolean[nplots];
        for (int p = 0; p < nplots; ++p) {
            String[] plot = plotList.get(p);
            taxonIndex[p] = this.origGeno.taxa().indexOf(plot[0]);
            statesByPlot[p] = this.progenyStates.get(plot[0]);
            parentIsFirst[p] = parent.equals(plot[1]);
        }

        for (int s = 0; s < nsites; ++s) {
            byte major = this.origGeno.majorAllele(s);
            byte minor = this.origGeno.minorAllele(s);
            if (minor == N) {
                // Monomorphic site: both haplotypes are taken to carry the major allele.
                haplotypeProbability[0][s] = 1.0;
                haplotypeProbability[1][s] = 1.0;
                continue;
            }
            double majorFreq = this.origGeno.majorAlleleFrequency(s);
            double minorFreq = this.origGeno.minorAlleleFrequency(s);

            // stateAlleleDepths[2 * myChr + otherChr][allele]: depth pooled by which chromosome
            // came from this parent and which from the other parent.
            int[][] stateAlleleDepths = new int[4][ALLELE_SLOTS];
            // parentStateAlleleDepths[myChr][allele]: depth pooled by this parent's chromosome only.
            int[][] parentStateAlleleDepths = new int[2][ALLELE_SLOTS];
            for (int p = 0; p < nplots; ++p) {
                byte[] states = statesByPlot[p];
                if (states == null) {
                    // No states for this progeny; skipped, as rephaseUsingCrossProgeny does.
                    continue;
                }
                int state = states[s];
                if (state >= missingState) {
                    continue;
                }
                int parentState;
                int myState;
                if (parentIsFirst[p]) {
                    parentState = (state == 0 || state == 1) ? 0 : 1;
                    myState = state;
                } else {
                    parentState = (state == 0 || state == 2) ? 0 : 1;
                    myState = SWITCH_STATE[state];
                }
                int[] depths = this.origGeno.depthForAlleles(taxonIndex[p], s);
                for (int a = 0; a < ALLELE_SLOTS; ++a) {
                    stateAlleleDepths[myState][a] += depths[a];
                    parentStateAlleleDepths[parentState][a] += depths[a];
                }
            }

            // Require at least 5 major-allele reads for each of the parent's haplotypes.
            // NOTE(review): the original tested haplotype 0 twice; haplotype 1 is assumed to
            // have been the intended second operand - confirm against the upstream source.
            if (parentStateAlleleDepths[0][major] < 5 || parentStateAlleleDepths[1][major] < 5) {
                continue;
            }

            haplotypeProbability[0][s] = majorAlleleProbability(
                    stateAlleleDepths[0], stateAlleleDepths[1], major, minor, majorFreq, minorFreq, err);
            haplotypeProbability[1][s] = majorAlleleProbability(
                    stateAlleleDepths[2], stateAlleleDepths[3], major, minor, majorFreq, minorFreq, err);
        }
        return haplotypeProbability;
    }

    /**
     * Posterior-style ratio {@code L(major) / (L(major) + L(minor))} for one parental haplotype.
     *
     * <p>{@code depthsOther0} and {@code depthsOther1} are the allele depths of progeny that
     * inherited this haplotype together with chromosome 0 or 1 of the other parent. Under each
     * hypothesis the observed read counts are scored with a binomial whose success rate is
     * {@code err} or {@code 0.5} (presumably: progeny homozygous vs. heterozygous - the code
     * does not state this), and the unknown alleles of the other parent's chromosomes are
     * summed out weighted by the site's allele frequencies.
     */
    private static double majorAlleleProbability(int[] depthsOther0, int[] depthsOther1, byte major, byte minor,
            double majorFreq, double minorFreq, double err) {
        int major0 = depthsOther0[major];
        int minor0 = depthsOther0[minor];
        int total0 = major0 + minor0;
        int major1 = depthsOther1[major];
        int minor1 = depthsOther1[minor];
        int total1 = major1 + minor1;

        // Hypothesis: this haplotype carries the major allele (score the minor-read counts).
        double likelihoodMajor = frequencyWeightedSum(
                binomial(total0, err, minor0), binomial(total0, 0.5, minor0),
                binomial(total1, err, minor1), binomial(total1, 0.5, minor1),
                majorFreq, minorFreq) * majorFreq;

        // Hypothesis: this haplotype carries the minor allele (score the major-read counts).
        double likelihoodMinor = frequencyWeightedSum(
                binomial(total0, 0.5, major0), binomial(total0, err, major0),
                binomial(total1, 0.5, major1), binomial(total1, err, major1),
                majorFreq, minorFreq) * minorFreq;

        return likelihoodMajor / (likelihoodMajor + likelihoodMinor);
    }

    /** Binomial probability of exactly {@code successes} in {@code trials} at rate {@code p}. */
    private static double binomial(int trials, double p, int successes) {
        return new BinomialDistribution(trials, p).probability(successes);
    }

    /**
     * Sums the product of the two groups' likelihoods over the four major/minor combinations
     * of the other parent's two chromosomes, each weighted by its allele-frequency product.
     */
    private static double frequencyWeightedSum(double group0IfMajor, double group0IfMinor,
            double group1IfMajor, double group1IfMinor, double majorFreq, double minorFreq) {
        return group0IfMajor * group1IfMajor * majorFreq * majorFreq
                + group0IfMajor * group1IfMinor * majorFreq * minorFreq
                + group0IfMinor * group1IfMajor * minorFreq * majorFreq
                + group0IfMinor * group1IfMinor * minorFreq * minorFreq;
    }

    /**
     * Rebuilds each parent's two haplotypes from its outcross progeny: for every progeny the
     * allele contributed by this parent is inferred from the progeny genotype and the other
     * parent's starting haplotype, and a per-site consensus is then taken across progeny.
     *
     * <p>Parents with fewer than {@code minFamilySize} outcross plots are skipped. The result
     * is stored in {@code rephasedParents}.
     *
     * @return parent name to its rephased {@code [2][nsites]} haplotypes (unknown = 15)
     */
    Map<String, byte[][]> rephaseUsingCrossProgeny() {
        this.rephasedParents = new HashMap<String, byte[][]>();
        Map<String, List<String[]>> parentPlotMap = new HashMap<>();
        for (String[] plot : this.plotList) {
            if (!plot[3].equals("outcross")) {
                continue;
            }
            addPlot(parentPlotMap, plot[1], plot);
            if (!plot[2].equals(plot[1])) {
                addPlot(parentPlotMap, plot[2], plot);
            }
        }

        int nsites = this.origGeno.numberOfSites();
        for (Map.Entry<String, List<String[]>> entry : parentPlotMap.entrySet()) {
            String parent = entry.getKey();
            List<String[]> parentPlotList = entry.getValue();
            System.out.printf("Rephasing %s\n", parent);
            if (parentPlotList.size() < this.minFamilySize) {
                System.out.printf("parentPlotList has %d plots for %s\n", parentPlotList.size(), parent);
                continue;
            }
            List<byte[][]> haplotypeList = new ArrayList<>();
            for (String[] plot : parentPlotList) {
                byte[][] parentHap = inferParentHaplotypes(parent, plot, nsites);
                if (parentHap != null) {
                    haplotypeList.add(parentHap);
                }
            }
            this.rephasedParents.put(parent, consensusHaplotypes(haplotypeList, nsites));
        }
        return this.rephasedParents;
    }

    /**
     * Infers, from a single progeny, the alleles carried on each chromosome of {@code parent}.
     *
     * @return a {@code [2][nsites]} array filled with 15 except where an allele could be
     *         inferred, or {@code null} when the progeny has no states or the other parent has
     *         no starting haplotypes
     */
    private byte[][] inferParentHaplotypes(String parent, String[] plot, int nsites) {
        boolean isFirstParent = plot[1].equals(parent);
        String otherParent = isFirstParent ? plot[2] : plot[1];
        byte[] myStates = this.progenyStates.get(plot[0]);
        byte[][] otherParentHap = this.startingParents.get(otherParent);
        if (otherParentHap == null || myStates == null) {
            return null;
        }

        int progenyIndex = this.origGeno.taxa().indexOf(plot[0]);
        byte[][] parentHap = new byte[2][nsites];
        Arrays.fill(parentHap[0], N);
        Arrays.fill(parentHap[1], N);

        for (int s = 0; s < nsites; ++s) {
            byte state = myStates[s];
            if (state == missingState) {
                continue;
            }
            byte myGenotype = this.origGeno.genotype(progenyIndex, s);
            if (myGenotype == NN) {
                continue;
            }
            int myChr = isFirstParent ? FIRST_PARENT_CHR[state] : SECOND_PARENT_CHR[state];
            int otherChr = isFirstParent ? SECOND_PARENT_CHR[state] : FIRST_PARENT_CHR[state];
            byte otherAllele = otherParentHap[otherChr][s];
            if (otherAllele == N) {
                continue;
            }
            int mydepth = this.origGeno.depth().depth(progenyIndex, s);
            byte[] alleles = GenotypeTableUtils.getDiploidValues(myGenotype);

            if (GenotypeTableUtils.isHeterozygous(myGenotype)) {
                // The allele not explained by the other parent must have come from this one.
                if (otherAllele == alleles[0]) {
                    parentHap[myChr][s] = alleles[1];
                } else if (otherAllele == alleles[1]) {
                    parentHap[myChr][s] = alleles[0];
                }
                continue;
            }

            boolean matchesOther = alleles[0] == otherAllele;
            if (!matchesOther && mydepth < 9) {
                // Homozygous call that disagrees with the other parent is accepted only at
                // depth below 9; at higher depth the site is left unknown.
                parentHap[myChr][s] = alleles[0];
            } else if (matchesOther && mydepth >= this.minDepth) {
                // Homozygous call agreeing with the other parent needs at least minDepth reads.
                parentHap[myChr][s] = alleles[0];
            }
        }
        return parentHap;
    }

    /**
     * Takes a per-site, per-chromosome consensus over the haplotypes inferred from individual
     * progeny. An allele is accepted only when its count is more than twice the runner-up's
     * and it is the site's major or minor allele; otherwise the position stays 15.
     */
    private byte[][] consensusHaplotypes(List<byte[][]> haplotypeList, int nsites) {
        byte[][] newhaps = new byte[2][nsites];
        Arrays.fill(newhaps[0], N);
        Arrays.fill(newhaps[1], N);
        for (int s = 0; s < nsites; ++s) {
            byte major = this.origGeno.majorAllele(s);
            byte minor = this.origGeno.minorAllele(s);
            int[][] alleleCount = new int[2][ALLELE_SLOTS];
            for (byte[][] hap : haplotypeList) {
                for (int chr = 0; chr < 2; ++chr) {
                    byte val = hap[chr][s];
                    if (val < ALLELE_SLOTS) {
                        alleleCount[chr][val]++;
                    }
                }
            }
            for (int chr = 0; chr < 2; ++chr) {
                int[] order = RephaseParents.countSortOrder(alleleCount[chr], true);
                byte topAllele = (byte) order[0];
                boolean clearMajority = alleleCount[chr][order[0]] > 2 * alleleCount[chr][order[1]];
                if (clearMajority && (topAllele == major || topAllele == minor)) {
                    newhaps[chr][s] = topAllele;
                }
            }
        }
        return newhaps;
    }

    /** Appends {@code plot} to the list kept for {@code parent}, creating the list if needed. */
    private static void addPlot(Map<String, List<String[]>> parentPlotMap, String parent, String[] plot) {
        parentPlotMap.computeIfAbsent(parent, key -> new ArrayList<String[]>()).add(plot);
    }

    /** Sets the minimum read depth required by {@link #rephaseUsingCrossProgeny()}. */
    public void setMinDepth(int mindepth) {
        this.minDepth = mindepth;
    }

    /**
     * Returns the indices of {@code counts} ordered by their count. The sort is stable, so
     * indices with equal counts keep ascending index order.
     *
     * @param counts     values to rank
     * @param descending {@code true} for largest count first, {@code false} for smallest first
     * @return a permutation of {@code 0..counts.length-1}
     */
    public static int[] countSortOrder(int[] counts, boolean descending) {
        List<Integer> order = IntStream.range(0, counts.length).boxed().collect(Collectors.toList());
        if (descending) {
            Collections.sort(order, (a, b) -> Integer.compare(counts[b], counts[a]));
        } else {
            Collections.sort(order, (a, b) -> Integer.compare(counts[a], counts[b]));
        }
        return order.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * Converts a genotype table that encodes progeny states into per-taxon state arrays, using
     * the first diploid value of each genotype as the state.
     *
     * @param gt table whose genotypes encode states
     * @return taxon name to a state array of length {@code gt.numberOfSites()}
     */
    public static Map<String, byte[]> progenyStates(GenotypeTable gt) {
        Map<String, byte[]> outputMap = new HashMap<String, byte[]>();
        int nsites = gt.numberOfSites();
        int ntaxa = gt.numberOfTaxa();
        for (int t = 0; t < ntaxa; ++t) {
            byte[] states = new byte[nsites];
            for (int s = 0; s < nsites; ++s) {
                byte val = GenotypeTableUtils.getDiploidValues(gt.genotype(t, s))[0];
                // NOTE(review): only 0..2 are accepted here although states 0..3 are used
                // elsewhere in this class; a value of 3 becomes missing. Confirm whether the
                // bound should be "< 4" for the file format being read.
                states[s] = (val > -1 && val < 3) ? val : missingState;
            }
            outputMap.put(gt.taxaName(t), states);
        }
        return outputMap;
    }
}

