/*
 * Decompiled with CFR 0.152.
 */
package net.maizegenetics.analysis.imputation;

import java.util.stream.IntStream;
import net.maizegenetics.analysis.imputation.PopulationData;
import net.maizegenetics.analysis.popgen.LinkageDisequilibrium;
import net.maizegenetics.dna.WHICH_ALLELE;
import net.maizegenetics.dna.snp.FilterGenotypeTable;
import net.maizegenetics.dna.snp.GenotypeTable;
import net.maizegenetics.dna.snp.GenotypeTableBuilder;
import net.maizegenetics.dna.snp.GenotypeTableUtils;
import net.maizegenetics.dna.snp.NucleotideAlignmentConstants;
import net.maizegenetics.stats.statistics.FisherExact;
import net.maizegenetics.taxa.Taxon;
import net.maizegenetics.util.BitSet;

/**
 * Recodes the genotypes of a bi-parental family in terms of parental origin.
 *
 * <p>{@link #assignHaplotypes()} first filters {@code myFamily.original} down to informative sites
 * (minor allele frequency, coverage, heterozygosity and a local linkage-disequilibrium filter), then
 * rewrites every taxon so that a genotype equal to parent 1 becomes "AA", one equal to parent 2
 * becomes "CC", any heterozygote becomes "AC", and everything else is set missing. The recoded
 * table replaces {@code myFamily.imputed}.
 */
public class UseParentHaplotypes {
    /** Missing diploid genotype code (0xFF); NOTE(review): assumed equal to GenotypeTable.UNKNOWN_DIPLOID_ALLELE. */
    private static final byte NN = -1;
    /** Half-width, in sites, of the window searched for a linked neighbouring site. */
    private static final int LD_WINDOW = 50;
    /** A site is kept only if its r2 with at least one neighbour in the window reaches this value. */
    private static final double MIN_LD_R2 = 0.8;

    PopulationData myFamily;
    /** Sites whose minor allele frequency is below this value are dropped. */
    double minMaf = 0.1;
    /** Minimum fraction of taxa that must be non-missing at a site for it to be kept. */
    double minCoverage = 0.5;
    /** Maximum fraction of non-missing taxa that may be heterozygous at a kept site. */
    double maxHet = 0.15;

    public UseParentHaplotypes(PopulationData family) {
        this.myFamily = family;
    }

    /**
     * Filters sites, recodes all taxa relative to the two parents (replacing
     * {@code myFamily.imputed}), and prints the dimensions of the resulting table.
     */
    public void assignHaplotypes() {
        this.prefilterSites();
        this.setAllelesToParents();
        System.out.printf("imputed genotype table has %d sites and %d taxa%n",
                this.myFamily.imputed.numberOfSites(), this.myFamily.imputed.numberOfTaxa());
    }

    /**
     * Recodes {@code myFamily.imputed} as parent-of-origin genotypes (AA / CC / AC / missing).
     *
     * <p>Side effects: sets {@code myFamily.alleleA} / {@code alleleC} to the parents' genotype
     * arrays, with a missing parent genotype filled in when it can be inferred from the other parent,
     * and replaces {@code myFamily.imputed} with the recoded table.
     *
     * @throws IllegalStateException if either parent is not a taxon of {@code myFamily.imputed}
     */
    private void setAllelesToParents() {
        GenotypeTable geno = this.myFamily.imputed;
        byte AA = NucleotideAlignmentConstants.getNucleotideDiploidByte("AA");
        byte CC = NucleotideAlignmentConstants.getNucleotideDiploidByte("CC");
        byte AC = NucleotideAlignmentConstants.getNucleotideDiploidByte("AC");

        int parent1Index = geno.taxa().indexOf(this.myFamily.parent1);
        int parent2Index = geno.taxa().indexOf(this.myFamily.parent2);
        if (parent1Index < 0 || parent2Index < 0) {
            throw new IllegalStateException("parents not found in genotype table: parent1="
                    + this.myFamily.parent1 + " (index " + parent1Index + "), parent2="
                    + this.myFamily.parent2 + " (index " + parent2Index + ")");
        }

        // genoA/genoC alias myFamily.alleleA/alleleC, so the inferred values below are stored there too.
        byte[] genoA = geno.genotypeAllSites(parent1Index);
        byte[] genoC = geno.genotypeAllSites(parent2Index);
        this.myFamily.alleleA = genoA;
        this.myFamily.alleleC = genoC;

        int nsites = geno.numberOfSites();
        int ntaxa = geno.numberOfTaxa();

        // Fill in a parent missing at a site from the other parent, then decide once per site
        // whether the parents can be told apart there (hoisted out of the per-taxon loop).
        boolean[] informative = new boolean[nsites];
        for (int s = 0; s < nsites; ++s) {
            if (genoA[s] == NN && genoC[s] != NN) {
                genoA[s] = inferMissingParent(geno, s, genoC[s]);
            } else if (genoA[s] != NN && genoC[s] == NN) {
                genoC[s] = inferMissingParent(geno, s, genoA[s]);
            }
            informative[s] = isInformativeSite(genoA[s], genoC[s]);
        }

        GenotypeTableBuilder genoBuilder = GenotypeTableBuilder.getTaxaIncremental(geno.positions());
        for (int t = 0; t < ntaxa; ++t) {
            byte[] taxonGeno = geno.genotypeAllSites(t);
            for (int s = 0; s < nsites; ++s) {
                byte g = taxonGeno[s];
                byte recoded;
                if (!informative[s]) {
                    recoded = NN;
                } else if (g == genoA[s]) {
                    recoded = AA;
                } else if (g == genoC[s]) {
                    recoded = CC;
                } else if (GenotypeTableUtils.isHeterozygous(g)) {
                    recoded = AC;
                } else {
                    recoded = NN;
                }
                taxonGeno[s] = recoded;
            }
            genoBuilder.addTaxon((Taxon) geno.taxa().get(t), taxonGeno);
        }
        this.myFamily.imputed = genoBuilder.build();
    }

    /**
     * Infers a parent's genotype at {@code site} from the other parent's genotype, assuming the
     * parents are homozygous for opposite alleles.
     *
     * @param geno        table supplying the site's major and minor alleles
     * @param site        site index in {@code geno}
     * @param knownParent the non-missing parent's diploid genotype at {@code site}
     * @return homozygous minor if {@code knownParent} is homozygous major, homozygous major if it is
     *         homozygous minor, otherwise {@link #NN}
     */
    private static byte inferMissingParent(GenotypeTable geno, int site, byte knownParent) {
        byte major = geno.majorAllele(site);
        byte minor = geno.minorAllele(site);
        byte majorgeno = GenotypeTableUtils.getDiploidValue(major, major);
        byte minorgeno = GenotypeTableUtils.getDiploidValue(minor, minor);
        if (knownParent == majorgeno) {
            return minorgeno;
        }
        if (knownParent == minorgeno) {
            return majorgeno;
        }
        return NN;
    }

    /**
     * Returns whether progeny at a site can be assigned to a parent: both parents must be called,
     * differ from each other, and be non-heterozygous. Requiring both to be called prevents a
     * missing progeny genotype from "matching" a missing parent and being coded as that parent.
     */
    private static boolean isInformativeSite(byte parentA, byte parentC) {
        return parentA != NN
                && parentC != NN
                && parentA != parentC
                && !GenotypeTableUtils.isHeterozygous(parentA)
                && !GenotypeTableUtils.isHeterozygous(parentC);
    }

    /**
     * Stub: fetches both parents' genotypes from {@code myFamily.original}. No validation is
     * performed yet, and nothing in this class calls it.
     */
    private void validateParentGenotypes() {
        int p1 = this.myFamily.original.taxa().indexOf(this.myFamily.parent1);
        byte[] parent1Genotype = this.myFamily.original.genotypeAllSites(p1);
        int p2 = this.myFamily.original.taxa().indexOf(this.myFamily.parent2);
        byte[] parent2Genotype = this.myFamily.original.genotypeAllSites(p2);
        // TODO: compare parent1Genotype and parent2Genotype against expectations.
    }

    /**
     * Sets {@code myFamily.imputed} to the sites of {@code myFamily.original} that pass
     * {@link #passesSiteFilters} and are in LD with a nearby site (see {@link #isSiteInLD}).
     */
    private void prefilterSites() {
        GenotypeTable geno = this.myFamily.original;
        int nsites = geno.numberOfSites();
        // Coverage is a fraction of taxa. The parentheses matter: casting minCoverage alone yields 0.
        // NOTE(review): assumes totalNonMissingForSite counts taxa (it is also the denominator of the
        // per-taxon heterozygosity proportion in passesSiteFilters).
        int minNonMissing = (int) (this.minCoverage * geno.numberOfTaxa());
        int[] sites = IntStream.range(0, nsites)
                .filter(s -> this.passesSiteFilters(geno, s, minNonMissing))
                .toArray();
        GenotypeTable filteredGeno = FilterGenotypeTable.getInstance(geno, sites);
        GenotypeTable copy = GenotypeTableBuilder.getGenotypeCopyInstance(filteredGeno);
        // One FisherExact instance serves every site pair instead of being rebuilt per site.
        FisherExact fisherExact = FisherExact.getInstance(2 * copy.numberOfTaxa() + 10);
        int[] goodSites = IntStream.range(0, copy.numberOfSites())
                .filter(s -> isSiteInLD(copy, s, fisherExact))
                .toArray();
        this.myFamily.imputed = goodSites.length < sites.length
                ? FilterGenotypeTable.getInstance(copy, goodSites)
                : filteredGeno;
    }

    /**
     * Returns whether {@code site} passes the minor-allele-frequency, coverage and heterozygosity
     * filters. Sites with no non-missing taxa are rejected rather than dividing by zero.
     */
    private boolean passesSiteFilters(GenotypeTable geno, int site, int minNonMissing) {
        if (geno.minorAlleleFrequency(site) < this.minMaf) {
            return false;
        }
        int numberNotMissing = geno.totalNonMissingForSite(site);
        if (numberNotMissing == 0 || numberNotMissing < minNonMissing) {
            return false;
        }
        double proportionHet = (double) geno.heterozygousCount(site) / (double) numberNotMissing;
        return proportionHet <= this.maxHet;
    }

    /**
     * Returns whether {@code site} has r2 of at least {@link #MIN_LD_R2} with any other site within
     * {@link #LD_WINDOW} sites on either side.
     *
     * @param geno        table whose sites are tested
     * @param site        site index in {@code geno}
     * @param fisherExact shared instance passed through to the LD calculation
     */
    private static boolean isSiteInLD(GenotypeTable geno, int site, FisherExact fisherExact) {
        int nsites = geno.numberOfSites();
        int startSite = Math.max(0, site - LD_WINDOW);
        int endSite = Math.min(nsites, site + LD_WINDOW + 1);
        BitSet rMj = geno.allelePresenceForAllTaxa(site, WHICH_ALLELE.Major);
        BitSet rMn = geno.allelePresenceForAllTaxa(site, WHICH_ALLELE.Minor);
        for (int s = startSite; s < endSite; ++s) {
            if (s == site) {
                continue;
            }
            BitSet cMj = geno.allelePresenceForAllTaxa(s, WHICH_ALLELE.Major);
            BitSet cMn = geno.allelePresenceForAllTaxa(s, WHICH_ALLELE.Minor);
            // NOTE(review): 2 and 10 are presumably minimum minor-allele count and minimum taxa for an
            // estimate; confirm against LinkageDisequilibrium.getLDForSitePair.
            float r2 = LinkageDisequilibrium.getLDForSitePair(rMj, rMn, cMj, cMn, 2, 10, -1.0f, fisherExact, site, s).r2();
            // Compare each pair directly: a NaN r2 fails >= and so never counts as linkage, whereas
            // accumulating with Math.max would let one NaN make the whole site look linked.
            if (r2 >= MIN_LD_R2) {
                return true;
            }
        }
        return false;
    }
}

