/*
 * Decompiled with CFR 0.152.
 */
package net.maizegenetics.analysis.imputation;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import net.maizegenetics.analysis.imputation.SelfedHaplotypeFinder;
import net.maizegenetics.dna.snp.GenotypeTable;
import net.maizegenetics.dna.snp.GenotypeTableUtils;
import net.maizegenetics.dna.snp.NucleotideAlignmentConstants;
import net.maizegenetics.stats.PCA.ClassicMds;
import net.maizegenetics.taxa.TaxaListBuilder;
import net.maizegenetics.taxa.Taxon;
import net.maizegenetics.taxa.distance.DistanceMatrix;
import net.maizegenetics.taxa.tree.TreeClusters;
import net.maizegenetics.taxa.tree.UPGMATree;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.stat.inference.ChiSquareTest;

/**
 * Phases the two haplotypes of each parent listed in a parentage file, using
 * the genotypes of that parent's progeny.
 *
 * <p>The workflow implemented here is:
 * <ol>
 *   <li>{@link #loadPlotInfo()} reads a tab-delimited parentage file (one
 *       header line, then rows whose first three columns are used as progeny
 *       name, parent 1 and parent 2);</li>
 *   <li>{@link #phaseParentUsingOneProgeny} derives, per progeny, a pair of
 *       partial haplotypes for the parent;</li>
 *   <li>{@link #phaseParentUsingSelfAndCrossProgeny} clusters those partial
 *       haplotypes window by window into two groups and calls a phased allele
 *       per site;</li>
 *   <li>{@link #phaseParentsUsingAllAvailableProgeny(double)} runs the above
 *       for every parent and serializes the result.</li>
 * </ol>
 *
 * <p>Progress and diagnostics are written to {@code System.out}.
 */
public class PhaseHighCoverage {
    /** Allele code used throughout this class for "missing / not phased". */
    private static final byte N = 15;
    /** Diploid genotype code that {@link #phaseParentUsingOneProgeny} treats as a missing parent genotype. */
    private static final byte NN = -1;

    /** Parentage file read by {@link #loadPlotInfo()}; set through {@link #setParentage(String)}. */
    private Path parentagePath;
    // The following five paths are declared but never read or written by the code in this class.
    private Path genopath;
    private Path outhapsSelfCross;
    private Path outhapsCross;
    private Path monomorphs;
    private Path hapsSelf;
    /** Destination passed to the haplotype serializer by {@link #phaseParentsUsingAllAvailableProgeny(double)}. */
    private Path outhapsCrossHighCover;
    /** Set by {@link #setOuthapsAllProgeny(String)}; not otherwise used in this class. */
    private Path outhapsAllProgeny;
    /** Genotypes of parents and progeny. */
    private GenotypeTable myGenotypeTable;
    /** Shared chi-square test used to check allele separation between the two haplotype clusters. */
    private static final ChiSquareTest chisqTest = new ChiSquareTest();
    /** Not used by the code in this class. */
    private int monoMultiplier = 100;
    /** Not read or written by the code in this class. */
    Map<String, byte[][]> selfHaps;

    /**
     * Creates a phaser for the given genotypes.
     *
     * @param genotype genotype table holding the parents and their progeny
     */
    public PhaseHighCoverage(GenotypeTable genotype) {
        this.myGenotypeTable = genotype;
    }

    /**
     * Reads the parentage file and returns the rows whose progeny is present
     * in the genotype table.
     *
     * <p>The first line of the file is skipped as a header. Each remaining line
     * is split on tabs; a row is kept only if it has at least four columns and
     * its first column is the name of a taxon in the genotype table.
     *
     * @return the retained rows, in file order, as split column arrays
     * @throws IllegalStateException if no parentage file has been set
     * @throws UncheckedIOException if the parentage file cannot be read
     */
    public List<String[]> loadPlotInfo() {
        if (this.parentagePath == null) {
            throw new IllegalStateException("Parentage file has not been set; call setParentage() first.");
        }
        // A set gives constant-time membership tests for every line of the file.
        Set<String> taxaNames = this.myGenotypeTable.taxa().stream()
                .map(t -> t.getName())
                .collect(Collectors.toSet());
        List<String[]> plotList = new ArrayList<>();
        try (BufferedReader br = Files.newBufferedReader(this.parentagePath)) {
            br.readLine(); // skip the header line
            String input;
            while ((input = br.readLine()) != null) {
                String[] data = input.split("\t");
                if (data.length > 3 && taxaNames.contains(data[0])) {
                    plotList.add(data);
                }
            }
        } catch (IOException e) {
            // Propagate with the cause instead of terminating the whole JVM from library code.
            throw new UncheckedIOException("Unable to read parentage file " + this.parentagePath, e);
        }
        return plotList;
    }

    /**
     * Phases all parents and stores the result at {@code savepath}.
     *
     * @param minEigenRatio minimum ratio of the first to the second MDS eigenvalue for a window to be used
     * @param savepath destination handed to the haplotype serializer
     */
    public void phaseParentsUsingAllAvailableProgeny(double minEigenRatio, Path savepath) {
        this.outhapsCrossHighCover = savepath;
        this.phaseParentsUsingAllAvailableProgeny(minEigenRatio);
    }

    /**
     * Phases every parent named in column 2 or 3 of the parentage rows and
     * serializes the phased haplotypes to the previously configured output path.
     *
     * <p>A parent is dropped when {@link #phaseParentUsingSelfAndCrossProgeny}
     * returns {@code null} or when fewer than 1000 sites of its first
     * haplotype are phased.
     *
     * @param minEigenRatio minimum ratio of the first to the second MDS eigenvalue for a window to be used
     */
    public void phaseParentsUsingAllAvailableProgeny(double minEigenRatio) {
        System.out.println("Phasing parents using method phaseParentsUsingAllAvailableProgeny().");
        System.out.println("That in turn uses phaseParentUsingSelfAndCrossProgeny()");
        System.out.println("-------------------------------------------------------");
        final int minNumberPhasedSites = 1000;
        HashMap<String, byte[][]> phasedParents = new HashMap<>();
        List<String[]> plotInfo = this.loadPlotInfo();

        // Sorted, de-duplicated list of every parent mentioned in the parentage rows.
        TreeSet<String> parentSet = new TreeSet<>();
        for (String[] plot : plotInfo) {
            parentSet.add(plot[1]);
            parentSet.add(plot[2]);
        }

        for (String parent : parentSet) {
            System.out.printf("Phasing %s\n", parent);
            byte[][] phase = this.phaseParentUsingSelfAndCrossProgeny(parent, this.myGenotypeTable, minEigenRatio, plotInfo);
            if (phase == null) {
                System.out.println("Too few phased haplotypes, skipping.");
                System.out.println();
                continue;
            }
            int phasedSiteCount = 0;
            for (byte allele : phase[0]) {
                if (allele != N) {
                    ++phasedSiteCount;
                }
            }
            System.out.printf("%d sites phased for %s\n", phasedSiteCount, parent);
            if (phasedSiteCount < minNumberPhasedSites) {
                System.out.println("Too few sites phased, skipping.");
                continue;
            }
            phasedParents.put(parent, phase);
        }
        SelfedHaplotypeFinder.serializePhasedHaplotypes(phasedParents, this.outhapsCrossHighCover);
        System.out.println("Finished phasing and storing haplotypes.");
    }

    /**
     * Phases one parent from the partial haplotypes inferred from each of its
     * progeny.
     *
     * <p>For every parentage row naming {@code parent}, two partial haplotypes
     * are obtained from {@link #phaseParentUsingOneProgeny}. Each chromosome is
     * then walked in windows of up to 40 informative sites (sites where the
     * second most common allele is seen more than once and at least a tenth as
     * often as the most common one). Within a window the haplotype segments are
     * compared pairwise, the window is rejected if the ratio of the first two
     * MDS eigenvalues is below {@code minEigenRatio}, and otherwise the segments
     * are split into two UPGMA clusters. Cluster membership is matched to the
     * previous window to keep haplotype labels consistent, and a site is called
     * when the two clusters differ in allele ratio and pass a chi-square test.
     * Sites where only one allele is observed more than 20 times are filled in
     * as monomorphic at the end.
     *
     * @param parent name of the parent to phase
     * @param myGeno genotype table containing the parent and progeny
     * @param minEigenRatio minimum ratio of the first to the second MDS eigenvalue for a window to be used
     * @param plotInfo parentage rows (progeny, parent 1, parent 2, ...)
     * @return a {@code [2][numberOfSites]} array of phased alleles ({@code 15}
     *         where unphased), or {@code null} if fewer than 10 partial
     *         haplotypes were available or fewer than 1500 sites were phased
     *         in both haplotypes
     * @throws IllegalArgumentException if {@code parent} is not a taxon of {@code myGeno}
     */
    public byte[][] phaseParentUsingSelfAndCrossProgeny(String parent, GenotypeTable myGeno, double minEigenRatio, List<String[]> plotInfo) {
        final double alpha = 1.0E-4;
        final int minHaplotypes = 10;
        final int window = 40;
        final int minWindow = 20;
        final int minPresent = 4;
        final int minorAlleleMultiplier = 10;
        final int minMonomorphicCount = 20;
        final int minPhasedSites = 1500;
        double chisqLimit = new ChiSquaredDistribution(1.0).inverseCumulativeProbability(1.0 - alpha);

        // Fail fast with a readable message if the parent was never genotyped.
        requireTaxonIndex(myGeno, parent, "Parent");

        int nsites = myGeno.numberOfSites();
        byte[][] phasedParent = new byte[2][nsites];
        Arrays.fill(phasedParent[0], N);
        Arrays.fill(phasedParent[1], N);

        // Two partial haplotypes per progeny of this parent.
        List<byte[]> phasedHaplotypes = new ArrayList<>();
        for (String[] plot : plotInfo) {
            boolean isFirstParent = plot[1].equals(parent);
            if (!isFirstParent && !plot[2].equals(parent)) {
                continue;
            }
            String otherParent = isFirstParent ? plot[2] : plot[1];
            System.out.println("phasing " + plot[0]);
            byte[][] haps = this.phaseParentUsingOneProgeny(parent, otherParent, plot[0], myGeno);
            phasedHaplotypes.add(haps[0]);
            phasedHaplotypes.add(haps[1]);
        }
        if (phasedHaplotypes.size() < minHaplotypes) {
            return null;
        }

        // NOTE(review): assumes chromosomesOffsets() holds the first site index of each
        // chromosome in ascending order, as the original fixed-size code did - confirm.
        int[] chrstart = myGeno.chromosomesOffsets();
        List<int[]> monomorphicSites = new ArrayList<>();
        for (int c = 0; c < chrstart.length; ++c) {
            int chrEnd = c + 1 < chrstart.length ? chrstart[c + 1] : nsites;
            int s = chrstart[c];
            List<Integer> prevHapIndices1 = null;
            List<Integer> prevHapIndices2 = null;
            int[] previousSiteIndex = null;
            boolean isPreviousHapValid = false;
            while (s < chrEnd) {
                // Scan forward collecting informative sites until the window is full.
                int[] siteIndex = new int[window];
                int indexCount = 0;
                int startS = s;
                while (s < chrEnd && indexCount < window) {
                    int[] alleleCounts = this.countAllelesAtSite(phasedHaplotypes, s);
                    int npresent = Arrays.stream(alleleCounts).sum();
                    if (npresent > minPresent) {
                        int[] topTwo = topTwoAlleles(alleleCounts);
                        int majorCount = alleleCounts[topTwo[0]];
                        int minorCount = alleleCounts[topTwo[1]];
                        if (minorCount > 1 && minorCount * minorAlleleMultiplier > majorCount) {
                            siteIndex[indexCount++] = s;
                        }
                        if (majorCount > minMonomorphicCount && minorCount == 0) {
                            monomorphicSites.add(new int[]{s, topTwo[0]});
                        }
                    }
                    ++s;
                }
                int nSitesScanned = s - startS;

                if (indexCount < minWindow) {
                    if (!isPreviousHapValid) {
                        System.out.printf("No windows phased in chromosome %d%n", c + 1);
                        continue;
                    }
                    // Too few sites on their own: append them to the previous window and redo it.
                    int[] combinedIndex = Arrays.copyOf(previousSiteIndex, previousSiteIndex.length + indexCount);
                    System.arraycopy(siteIndex, 0, combinedIndex, previousSiteIndex.length, indexCount);
                    siteIndex = combinedIndex;
                    indexCount = combinedIndex.length;
                }

                List<byte[]> seglist = extractSegments(phasedHaplotypes, siteIndex, indexCount);
                double[][] dist = this.mismatchDistanceMatrix(seglist);
                // Taxa are named by their position in seglist so cluster members can be mapped back.
                List<Taxon> dummyTaxa = IntStream.range(0, dist.length).mapToObj(i -> new Taxon(Integer.toString(i))).collect(Collectors.toList());
                DistanceMatrix dm = new DistanceMatrix(dist, new TaxaListBuilder().addAll((Collection<Taxon>)dummyTaxa).build());
                ClassicMds mds = new ClassicMds(dm);
                System.out.printf("Eigenvalue ratio = %1.3f, Eigenvalues: %1.3f, %1.3f, %1.3f, %1.3f\n", mds.getEigenvalue(0) / mds.getEigenvalue(1), mds.getEigenvalue(0), mds.getEigenvalue(1), mds.getEigenvalue(2), mds.getEigenvalue(3));
                System.out.printf("%d sites scanned to generate this interval\n", nSitesScanned);
                if (mds.getEigenvalue(0) / mds.getEigenvalue(1) < minEigenRatio) {
                    continue;
                }

                UPGMATree myTree = new UPGMATree(dm);
                TreeClusters myClusters = new TreeClusters(myTree);
                int[] myGroups = myClusters.getGroups(2);
                List<Integer> hapIndices1 = new ArrayList<>();
                List<Integer> hapIndices2 = new ArrayList<>();
                for (int i = 0; i < myGroups.length; ++i) {
                    String name = myTree.getExternalNode(i).getIdentifier().getName();
                    if (myGroups[i] == 0) {
                        hapIndices1.add(Integer.parseInt(name));
                    } else {
                        hapIndices2.add(Integer.parseInt(name));
                    }
                }

                boolean reverseHaps;
                if (prevHapIndices1 == null) {
                    reverseHaps = false;
                } else {
                    // Compare cluster membership with the previous window to decide
                    // whether the two cluster labels are swapped in this window.
                    int[][] matches = new int[2][2];
                    matches[0][0] = this.countSharedMembers(hapIndices1, prevHapIndices1);
                    matches[0][1] = this.countSharedMembers(hapIndices1, prevHapIndices2);
                    matches[1][0] = this.countSharedMembers(hapIndices2, prevHapIndices1);
                    matches[1][1] = this.countSharedMembers(hapIndices2, prevHapIndices2);
                    int mainDiagSum = matches[0][0] + matches[1][1];
                    int offDiagSum = matches[0][1] + matches[1][0];
                    if (mainDiagSum > 2 * offDiagSum) {
                        reverseHaps = false;
                        isPreviousHapValid = true;
                    } else if (offDiagSum > 2 * mainDiagSum) {
                        reverseHaps = true;
                        isPreviousHapValid = true;
                    } else {
                        // Ambiguous match: keep an already-confirmed previous window and skip this one;
                        // otherwise discard the previous window's calls and start over from this window.
                        if (isPreviousHapValid) {
                            continue;
                        }
                        for (int prevSite : previousSiteIndex) {
                            phasedParent[0][prevSite] = N;
                            phasedParent[1][prevSite] = N;
                        }
                        reverseHaps = false;
                    }
                    System.out.printf("haplotype matches at chr %d, site %d: %d, %d, %d, %d, reverse = %b\n", c + 1, s, matches[0][0], matches[0][1], matches[1][0], matches[1][1], reverseHaps);
                }
                System.out.printf("hap list 1 has %d members, 2 has %d members (chr %d, site %d)\n", hapIndices1.size(), hapIndices2.size(), c + 1, s);

                List<Integer> firstIndices = reverseHaps ? hapIndices2 : hapIndices1;
                List<Integer> secondIndices = reverseHaps ? hapIndices1 : hapIndices2;
                List<byte[]> hapList1 = selectSegments(seglist, firstIndices);
                List<byte[]> hapList2 = selectSegments(seglist, secondIndices);
                System.out.println(this.haplotypeAsString(this.consensusHaplotype(hapList1)));
                System.out.println(this.haplotypeAsString(this.consensusHaplotype(hapList2)));

                this.callPhasedSites(hapList1, hapList2, siteIndex, indexCount, chisqLimit, phasedParent);

                prevHapIndices1 = firstIndices;
                prevHapIndices2 = secondIndices;
                // Keep only the filled part so no unused trailing slot is ever read as site 0.
                previousSiteIndex = Arrays.copyOf(siteIndex, indexCount);
            }
        }

        int siteCount = 0;
        for (int s = 0; s < nsites; ++s) {
            if (phasedParent[0][s] != N && phasedParent[1][s] != N) {
                ++siteCount;
            }
        }
        System.out.printf("There were %d polymorphic sites phased for %s\n", siteCount, parent);
        if (siteCount < minPhasedSites) {
            return null;
        }
        for (int[] site : monomorphicSites) {
            phasedParent[0][site[0]] = (byte)site[1];
            phasedParent[1][site[0]] = (byte)site[1];
        }
        return phasedParent;
    }

    /**
     * Looks up a taxon by name and rejects names absent from the table.
     *
     * @return the taxon's index in {@code gt.taxa()}
     * @throws IllegalArgumentException if the name is not found
     */
    private static int requireTaxonIndex(GenotypeTable gt, String name, String role) {
        int index = gt.taxa().indexOf(name);
        if (index < 0) {
            throw new IllegalArgumentException(role + " " + name + " is not in the genotype table.");
        }
        return index;
    }

    /**
     * Returns the indices of the most and second most frequent alleles.
     *
     * @param alleleCounts per-allele counts; must have at least two entries
     * @return {@code {major, minor}}; on ties the lower index is ranked first
     */
    private static int[] topTwoAlleles(int[] alleleCounts) {
        int major = 0;
        int minor = 1;
        if (alleleCounts[minor] > alleleCounts[major]) {
            major = 1;
            minor = 0;
        }
        for (int a = 2; a < alleleCounts.length; ++a) {
            if (alleleCounts[a] > alleleCounts[major]) {
                minor = major;
                major = a;
            } else if (alleleCounts[a] > alleleCounts[minor]) {
                minor = a;
            }
        }
        return new int[]{major, minor};
    }

    /** Copies the alleles at the first {@code indexCount} sites of {@code siteIndex} out of every haplotype. */
    private static List<byte[]> extractSegments(List<byte[]> haplotypes, int[] siteIndex, int indexCount) {
        List<byte[]> seglist = new ArrayList<>(haplotypes.size());
        for (byte[] hap : haplotypes) {
            byte[] seg = new byte[indexCount];
            for (int i = 0; i < indexCount; ++i) {
                seg[i] = hap[siteIndex[i]];
            }
            seglist.add(seg);
        }
        return seglist;
    }

    /** Returns the segments at the given positions of {@code seglist}, in the order of {@code indices}. */
    private static List<byte[]> selectSegments(List<byte[]> seglist, List<Integer> indices) {
        List<byte[]> selected = new ArrayList<>(indices.size());
        for (int index : indices) {
            selected.add(seglist.get(index));
        }
        return selected;
    }

    /**
     * Calls the phased allele of each window site from the two haplotype clusters.
     *
     * <p>For each site a 2x2 table (cluster x allele) is built from the first
     * two of the allele codes 0-3 that are observed. The site is written to
     * {@code phasedParent} only when one cluster has an allele ratio of at
     * least 2 and the other at most 0.5, and the chi-square statistic of the
     * table reaches {@code chisqLimit}.
     */
    private void callPhasedSites(List<byte[]> hapList1, List<byte[]> hapList2, int[] siteIndex, int indexCount, double chisqLimit, byte[][] phasedParent) {
        for (int i = 0; i < indexCount; ++i) {
            int[] alleleCount1 = this.countAllelesAtSite(hapList1, i);
            int[] alleleCount2 = this.countAllelesAtSite(hapList2, i);
            long[][] counts = new long[2][2];
            int which = 0;
            for (int j = 0; j < 4; ++j) {
                if (alleleCount1[j] <= 0 && alleleCount2[j] <= 0) {
                    continue;
                }
                if (which == 2) {
                    System.out.printf("Site %d has more than 2 alleles\n", siteIndex[i]);
                    break;
                }
                counts[0][which] = alleleCount1[j];
                counts[1][which] = alleleCount2[j];
                ++which;
            }
            // Floating-point division: a zero denominator yields Infinity or NaN, never an exception.
            double ratio1 = (double)counts[0][0] / (double)counts[0][1];
            double ratio2 = (double)counts[1][0] / (double)counts[1][1];
            boolean separated = (ratio1 >= 2.0 && ratio2 <= 0.5) || (ratio2 >= 2.0 && ratio1 <= 0.5);
            if (!separated) {
                continue;
            }
            // Written as !(x >= limit) so that a NaN statistic is rejected.
            if (!(chisqTest.chiSquare(counts) >= chisqLimit)) {
                continue;
            }
            phasedParent[0][siteIndex[i]] = this.maxAllele(alleleCount1);
            phasedParent[1][siteIndex[i]] = this.maxAllele(alleleCount2);
        }
    }

    /**
     * Returns the index of the largest count; the lowest index wins ties, so
     * an all-zero array yields 0.
     */
    private byte maxAllele(int[] alleleCounts) {
        int max = 0;
        for (int i = 1; i < alleleCounts.length; ++i) {
            if (alleleCounts[i] > alleleCounts[max]) {
                max = i;
            }
        }
        return (byte)max;
    }

    /** Counts the elements of {@code list2} that also occur in {@code list1}. */
    private int countSharedMembers(List<Integer> list1, List<Integer> list2) {
        Set<Integer> members = new HashSet<>(list1);
        int count = 0;
        for (Integer candidate : list2) {
            if (members.contains(candidate)) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Builds the per-site majority allele of a set of equally long haplotypes.
     * A site with no allele codes in 0-5 is reported as allele 0, because
     * {@link #maxAllele} returns 0 for an all-zero count array.
     *
     * @return the consensus, or an empty array if {@code haplotypes} is empty
     */
    private byte[] consensusHaplotype(List<byte[]> haplotypes) {
        if (haplotypes.isEmpty()) {
            return new byte[0];
        }
        int nsites = haplotypes.get(0).length;
        byte[] consensus = new byte[nsites];
        for (int i = 0; i < nsites; ++i) {
            consensus[i] = this.maxAllele(this.countAllelesAtSite(haplotypes, i));
        }
        return consensus;
    }

    /** Renders a haplotype as text, one nucleotide string per allele code. */
    private String haplotypeAsString(byte[] hap) {
        StringBuilder sb = new StringBuilder();
        for (byte b : hap) {
            sb.append(NucleotideAlignmentConstants.getHaplotypeNucleotide(b));
        }
        return sb.toString();
    }

    /**
     * Counts, at one site, how many haplotypes carry each allele code 0-5.
     * Any other code (including the missing code 15) is ignored.
     *
     * @return an array of six counts indexed by allele code
     */
    private int[] countAllelesAtSite(List<byte[]> haps, int site) {
        int[] alleleCounts = new int[6];
        for (byte[] hap : haps) {
            byte allele = hap[site];
            if (allele >= 0 && allele < 6) {
                ++alleleCounts[allele];
            }
        }
        return alleleCounts;
    }

    /**
     * Derives two partial haplotypes of {@code parent} from a single progeny.
     *
     * <p>Row 0 receives the allele the parent passed to this progeny, row 1
     * the parent's other allele, where they can be determined:
     * <ul>
     *   <li>parent not heterozygous with depth of at least 7: both rows get the parent's allele;</li>
     *   <li>otherwise, progeny not heterozygous with depth of at least 7: row 0 gets the progeny allele;</li>
     *   <li>otherwise, other parent not heterozygous with depth of at least 7: row 0 gets the
     *       progeny allele that differs from the other parent's allele.</li>
     * </ul>
     * In the last two cases row 1 gets the first parental allele that differs
     * from row 0, if any. Everything else stays {@code 15}.
     *
     * @param parent name of the parent being phased
     * @param otherParent name of the progeny's other parent (may equal {@code parent})
     * @param progeny name of the progeny
     * @param gt genotype table containing all three taxa; {@code gt.depth()} is dereferenced
     *           for every site and is presumably required to be non-null
     * @return a {@code [2][numberOfSites]} array of allele codes
     * @throws IllegalArgumentException if any of the three names is not a taxon of {@code gt}
     */
    public byte[][] phaseParentUsingOneProgeny(String parent, String otherParent, String progeny, GenotypeTable gt) {
        final int minDepth = 7;
        int nsites = gt.numberOfSites();
        byte[][] phasedGenotype = new byte[2][nsites];
        Arrays.fill(phasedGenotype[0], N);
        Arrays.fill(phasedGenotype[1], N);
        int parentIndex = requireTaxonIndex(gt, parent, "Parent");
        int otherParentIndex = requireTaxonIndex(gt, otherParent, "Parent");
        int progenyIndex = requireTaxonIndex(gt, progeny, "Progeny");

        for (int s = 0; s < nsites; ++s) {
            byte parentGenotype = gt.genotype(parentIndex, s);
            byte otherParentGenotype = gt.genotype(otherParentIndex, s);
            byte progenyGenotype = gt.genotype(progenyIndex, s);
            boolean parentHomozygous = !GenotypeTableUtils.isHeterozygous(parentGenotype) && gt.depth().depth(parentIndex, s) >= minDepth;
            boolean otherParentHomozygous = !GenotypeTableUtils.isHeterozygous(otherParentGenotype) && gt.depth().depth(otherParentIndex, s) >= minDepth;
            boolean progenyHomozygous = !GenotypeTableUtils.isHeterozygous(progenyGenotype) && gt.depth().depth(progenyIndex, s) >= minDepth;

            if (parentHomozygous) {
                byte allele = GenotypeTableUtils.getDiploidValues(parentGenotype)[0];
                phasedGenotype[0][s] = allele;
                phasedGenotype[1][s] = allele;
                continue;
            }

            byte transmitted;
            if (progenyHomozygous) {
                transmitted = GenotypeTableUtils.getDiploidValues(progenyGenotype)[0];
                phasedGenotype[0][s] = transmitted;
                if (parentGenotype == NN) {
                    continue;
                }
            } else if (otherParentHomozygous) {
                byte[] progenyAlleles = GenotypeTableUtils.getDiploidValues(progenyGenotype);
                byte otherAllele = GenotypeTableUtils.getDiploidValues(otherParentGenotype)[0];
                if (progenyAlleles[0] == otherAllele && progenyAlleles[1] == otherAllele) {
                    continue;
                }
                transmitted = progenyAlleles[0] != otherAllele ? progenyAlleles[0] : progenyAlleles[1];
                phasedGenotype[0][s] = transmitted;
            } else {
                continue;
            }

            // The parent's second haplotype carries whichever parental allele was not transmitted.
            byte[] parentAlleles = GenotypeTableUtils.getDiploidValues(parentGenotype);
            if (parentAlleles[0] != transmitted) {
                phasedGenotype[1][s] = parentAlleles[0];
            } else if (parentAlleles[1] != transmitted) {
                phasedGenotype[1][s] = parentAlleles[1];
            }
        }
        return phasedGenotype;
    }

    /**
     * Computes a symmetric matrix of pairwise mismatch proportions.
     *
     * <p>For each pair, only positions where neither segment is {@code 15}
     * are compared. Pairs with no comparable position keep a distance of 0.
     * Segments are indexed by the length of the first one of each pair.
     */
    private double[][] mismatchDistanceMatrix(List<byte[]> segments) {
        int n = segments.size();
        double[][] dist = new double[n][n];
        for (int i = 0; i < n - 1; ++i) {
            byte[] seg1 = segments.get(i);
            for (int j = i + 1; j < n; ++j) {
                byte[] seg2 = segments.get(j);
                int compared = 0;
                int mismatched = 0;
                for (int k = 0; k < seg1.length; ++k) {
                    if (seg1[k] == N || seg2[k] == N) {
                        continue;
                    }
                    ++compared;
                    if (seg1[k] != seg2[k]) {
                        ++mismatched;
                    }
                }
                if (compared > 0) {
                    double d = (double)mismatched / (double)compared;
                    dist[i][j] = d;
                    dist[j][i] = d;
                }
            }
        }
        return dist;
    }

    /**
     * Mean of {@code distance[member][ndx]} over all members of the cluster.
     * Returns NaN for an empty cluster. Not called from within this class.
     */
    private double averageDistanceToCluster(double[][] distance, List<Integer> cluster2, int ndx) {
        double total = 0.0;
        for (int member : cluster2) {
            total += distance[member][ndx];
        }
        return total / (double)cluster2.size();
    }

    /**
     * Debug helper: writes each segment as nucleotide text, one segment per
     * line. I/O errors are reported on stderr and otherwise ignored.
     * Not called from within this class.
     */
    private void saveSeglist(List<byte[]> seglist, String filename) {
        try (BufferedWriter bw = Files.newBufferedWriter(Paths.get(filename))) {
            for (byte[] seg : seglist) {
                for (byte b : seg) {
                    bw.write(NucleotideAlignmentConstants.getHaplotypeNucleotide(b));
                }
                // Terminate every segment so the rows stay distinguishable in the file.
                bw.write("\n");
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Debug helper: writes the distance matrix as space-separated values, one
     * matrix row per line. I/O errors are reported on stderr and otherwise
     * ignored. Not called from within this class.
     */
    private void saveDistanceMatrix(DistanceMatrix dm, String filename) {
        int n = dm.numberOfTaxa();
        try (BufferedWriter bw = Files.newBufferedWriter(Paths.get(filename))) {
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < n; ++j) {
                    bw.write(String.format("%1.5f ", Float.valueOf(dm.getDistance(i, j))));
                }
                // One line per matrix row.
                bw.write("\n");
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Sets the all-progeny output path.
     *
     * @param filename path of the output file
     */
    public void setOuthapsAllProgeny(String filename) {
        this.outhapsAllProgeny = Paths.get(filename);
    }

    /**
     * Sets the parentage file read by {@link #loadPlotInfo()}.
     *
     * @param filename path of the tab-delimited parentage file
     */
    public void setParentage(String filename) {
        this.parentagePath = Paths.get(filename);
    }

    /**
     * Replaces the genotype table used for phasing.
     *
     * @param genotype genotype table holding the parents and their progeny
     */
    public void setGenotypeTable(GenotypeTable genotype) {
        this.myGenotypeTable = genotype;
    }
}

