/*
 * Decompiled with CFR 0.152.
 */
package net.maizegenetics.analysis.imputation;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.MinMaxPriorityQueue;
import com.google.common.collect.Multimap;
import com.google.common.collect.Range;
import java.awt.Frame;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.swing.ImageIcon;
import net.maizegenetics.analysis.popgen.LDResult;
import net.maizegenetics.analysis.popgen.LinkageDisequilibrium;
import net.maizegenetics.dna.map.Position;
import net.maizegenetics.dna.map.PositionList;
import net.maizegenetics.dna.map.PositionListBuilder;
import net.maizegenetics.dna.snp.FilterGenotypeTable;
import net.maizegenetics.dna.snp.GenotypeTable;
import net.maizegenetics.dna.snp.GenotypeTableBuilder;
import net.maizegenetics.dna.snp.GenotypeTableUtils;
import net.maizegenetics.plugindef.AbstractPlugin;
import net.maizegenetics.plugindef.DataSet;
import net.maizegenetics.plugindef.Datum;
import net.maizegenetics.plugindef.GeneratePluginCode;
import net.maizegenetics.plugindef.Plugin;
import net.maizegenetics.plugindef.PluginParameter;
import net.maizegenetics.taxa.Taxon;
import net.maizegenetics.util.Tuple;
import org.apache.log4j.Logger;

/**
 * LD-KNNi genotype imputation (LinkImpute, see {@link #getCitation()}).
 *
 * <p>Works in two phases. First, for every site the {@code highLDSSites} other sites with the highest
 * LD r<sup>2</sup> are selected ({@link #getHighLDMap}). Second, every missing genotype at a polymorphic
 * site is imputed from up to {@code knnTaxa} nearest taxa that are non-missing at that site, where the
 * distance between two taxa is measured only over the site's high-LD sites ({@link #dist}); the
 * neighbours' genotypes are combined by a distance-weighted vote ({@link #impute}). Missing genotypes
 * at monomorphic sites are set to the homozygous major allele.
 */
public class LDKNNiImputationPlugin
extends AbstractPlugin {
    /** Number of high-LD sites used to measure the distance between taxa for each target site. */
    private PluginParameter<Integer> highLDSSites = new PluginParameter.Builder<Integer>("highLDSSites", 30, Integer.class).range((Range<Comparable<Integer>>)Range.closed((Comparable)Integer.valueOf(2), (Comparable)Integer.valueOf(2000))).guiName("High LD Sites").description("Number of sites in high LD to use in imputation").build();
    /** Maximum number of nearest-neighbour taxa that vote on each imputed genotype. */
    private PluginParameter<Integer> knnTaxa = new PluginParameter.Builder<Integer>("knnTaxa", 10, Integer.class).range((Range<Comparable<Integer>>)Range.closed((Comparable)Integer.valueOf(2), (Comparable)Integer.valueOf(200))).guiName("Number of nearest neighbors").description("Number of neighbors to use in imputation").build();
    /** Maximum physical distance for the LD search; -1 disables the cut-off (and the same-chromosome restriction). */
    private PluginParameter<Integer> maxDistance = new PluginParameter.Builder<Integer>("maxLDDistance", 10000000, Integer.class).guiName("Max distance between site to find LD").description("Maximum physical distance between sites to search for LD (-1 for no distance cutoff - unlinked chromosomes will be tested)").build();
    private static final Logger myLogger = Logger.getLogger(LDKNNiImputationPlugin.class);

    /** Minimum number of jointly non-missing sites for {@link #dist} to report a (non-NaN) distance. */
    private static final int MIN_SHARED_SITES = 10;
    /** A candidate neighbour is skipped unless coverage(a) * coverage(b) * #LD sites exceeds this value. */
    private static final double MIN_EXPECTED_SHARED_SITES = 10.0;

    /** Creates a non-interactive plugin with no parent frame. */
    public LDKNNiImputationPlugin() {
        super(null, false);
    }

    /**
     * @param parentFrame   GUI parent frame (may be {@code null})
     * @param isInteractive whether the plugin runs interactively
     */
    public LDKNNiImputationPlugin(Frame parentFrame, boolean isInteractive) {
        super(parentFrame, isInteractive);
    }

    /**
     * Validates the input.
     *
     * @throws IllegalArgumentException unless exactly one genotype table is supplied
     */
    @Override
    protected void preProcessParameters(DataSet input) {
        List<Datum> alignInList = input.getDataOfType(GenotypeTable.class);
        if (alignInList.size() != 1) {
            throw new IllegalArgumentException("LDKNNiImputationPlugin: preProcessParameters: Please select one Genotype Table.");
        }
    }

    /**
     * Imputes all missing genotypes of the single input genotype table.
     *
     * @param input data set holding one {@link GenotypeTable}
     * @return a data set holding the imputed table, named {@code <input name>_KNNimp}
     */
    @Override
    public DataSet processData(DataSet input) {
        Datum genoDatum = input.getDataOfType(GenotypeTable.class).get(0);
        GenotypeTable genotypeTable = (GenotypeTable)genoDatum.getData();
        int numberOfSites = genotypeTable.numberOfSites();
        int ldSites = this.highLDSSites();
        int neighbours = this.knnTaxa();
        Multimap<Position, Position> highLDMap = this.getHighLDMap(genotypeTable, ldSites);
        myLogger.info("LD calculated");
        GenotypeTableBuilder incSiteBuilder = GenotypeTableBuilder.getSiteIncremental(genotypeTable.taxa());
        long startTime = System.nanoTime();
        LongAdder sitesDone = new LongAdder();
        // NOTE(review): sites are added to the builder from parallel worker threads in arbitrary order;
        // this relies on GenotypeTableBuilder.addSite tolerating concurrent calls and on build()
        // restoring site order - confirm against GenotypeTableBuilder.
        IntStream.range(0, numberOfSites).parallel().forEach(posIndex -> {
            Position position = (Position)genotypeTable.positions().get(posIndex);
            byte[] newGenos = this.imputeSite(genotypeTable, posIndex, position, highLDMap, ldSites, neighbours);
            incSiteBuilder.addSite(position, newGenos);
            if ((posIndex + 1) % 100 == 0) {
                sitesDone.add(100L);
                long done = sitesDone.longValue();
                // Imputation occupies the 33..99% band of overall progress; divide in long arithmetic
                // so large tables cannot overflow int.
                this.fireProgress(33 + (int)(66L * done / numberOfSites));
                myLogger.info(done + ":" + (System.nanoTime() - startTime) / 1000000L / done);
            }
        });
        GenotypeTable impGenotypeTable = incSiteBuilder.build();
        return new DataSet(new Datum(genoDatum.getName() + "_KNNimp", impGenotypeTable, "Imputed genotypes by KNN imputation"), (Plugin)this);
    }

    /**
     * Computes the output genotypes (one per taxon) for a single site.
     *
     * <p>Non-missing genotypes are copied unchanged. At a monomorphic site every missing (-1) genotype
     * becomes the homozygous major allele; at a polymorphic site every missing genotype is imputed from
     * its nearest neighbours over the site's high-LD sites, staying -1 when no usable neighbour exists.
     */
    private byte[] imputeSite(GenotypeTable genotypeTable, int posIndex, Position position, Multimap<Position, Position> highLDMap, int ldSites, int neighbours) {
        byte[] newGenos = genotypeTable.genotypeAllTaxa(posIndex).clone();
        if (!genotypeTable.isPolymorphic(posIndex)) {
            byte majorAllele = genotypeTable.majorAllele(posIndex);
            byte monomorphicGenotype = GenotypeTableUtils.getDiploidValue(majorAllele, majorAllele);
            for (int taxon = 0; taxon < newGenos.length; ++taxon) {
                if (newGenos[taxon] == -1) {
                    newGenos[taxon] = monomorphicGenotype;
                }
            }
            return newGenos;
        }
        // Restrict the table to this site's high-LD sites; distances between taxa are measured there only.
        PositionList positionList = PositionListBuilder.getInstance(new ArrayList<Position>(highLDMap.get(position)));
        GenotypeTable ldGenoTable = GenotypeTableBuilder.getGenotypeCopyInstance(FilterGenotypeTable.getInstance(genotypeTable, positionList));
        int numberLdSites = ldGenoTable.numberOfSites();
        // Fraction of high-LD sites each taxon has called (NaN when there are no high-LD sites).
        double[] taxaCoverage = IntStream.range(0, ldGenoTable.numberOfTaxa())
                .mapToDouble(t -> (double)ldGenoTable.totalNonMissingForTaxon(t) / (double)numberLdSites)
                .toArray();
        for (int taxon = 0; taxon < newGenos.length; ++taxon) {
            if (newGenos[taxon] != -1) {
                continue;
            }
            Multimap<Double, Byte> closeGenotypes = this.getClosestNonMissingTaxa(taxon, posIndex, genotypeTable, ldGenoTable, taxaCoverage, neighbours);
            // impute(...) yields -1 (missing) when no neighbour qualified.
            newGenos[taxon] = impute(closeGenotypes, ldSites);
        }
        return newGenos;
    }

    /**
     * Finds up to {@code numberOfTaxa} taxa closest to the input taxon that are non-missing at the
     * target site.
     *
     * <p>A candidate is skipped when it is the input taxon itself, when the product of the two taxa's
     * coverage and the number of LD sites does not exceed {@link #MIN_EXPECTED_SHARED_SITES}, when it is
     * missing (-1) at the target site, or when {@link #dist} returns NaN (too few shared sites).
     *
     * @return map from distance to the neighbour's genotype at the target site (one entry per neighbour)
     */
    private Multimap<Double, Byte> getClosestNonMissingTaxa(int inputTaxonIdx, int targetPosIdx, GenotypeTable genotypeTable, GenotypeTable ldGenoTable, double[] inputCoverage, int numberOfTaxa) {
        byte[] inputTaxonGenotypes = ldGenoTable.genotypeAllSites(inputTaxonIdx);
        int numberLdSites = ldGenoTable.numberOfSites();
        // Bounded queue: once full, adding an element evicts the greatest, i.e. the most distant
        // neighbour (relies on Tuple's natural ordering).
        MinMaxPriorityQueue<Tuple<Double, Byte>> topTaxa = MinMaxPriorityQueue.maximumSize(numberOfTaxa).create();
        for (int closeTaxonIdx = 0; closeTaxonIdx < genotypeTable.numberOfTaxa(); ++closeTaxonIdx) {
            if (closeTaxonIdx == inputTaxonIdx) {
                continue;
            }
            double expectedShared = inputCoverage[closeTaxonIdx] * inputCoverage[inputTaxonIdx] * (double)numberLdSites;
            // Written as !(x > t) so a NaN coverage also rejects the candidate.
            if (!(expectedShared > MIN_EXPECTED_SHARED_SITES)) {
                continue;
            }
            byte closeGenotype = genotypeTable.genotype(closeTaxonIdx, targetPosIdx);
            if (closeGenotype == -1) {
                continue;
            }
            double distance = dist(inputTaxonGenotypes, ldGenoTable.genotypeAllSites(closeTaxonIdx), MIN_SHARED_SITES)[0];
            if (Double.isNaN(distance)) {
                continue;
            }
            topTaxa.add(new Tuple<Double, Byte>(distance, closeGenotype));
        }
        Multimap<Double, Byte> distGenoMap = ArrayListMultimap.create();
        for (Tuple<Double, Byte> distGeno : topTaxa) {
            distGenoMap.put(distGeno.x, distGeno.y);
        }
        return distGenoMap;
    }

    /**
     * For every site, finds up to {@code numberOfSNPs} other sites with the highest LD r<sup>2</sup>.
     *
     * <p>Sites are processed in parallel, but each worker only returns its own result list; the
     * (non-thread-safe) multimap is filled afterwards on the calling thread, in site order.
     *
     * @return map from each site's position to the positions of its high-LD partner sites
     */
    private Multimap<Position, Position> getHighLDMap(GenotypeTable genotypeTable, int numberOfSNPs) {
        int numberOfSites = genotypeTable.numberOfSites();
        int maxDist = this.maxDistance();
        LongAdder sitesDone = new LongAdder();
        List<List<Position>> ldPositionsBySite = IntStream.range(0, numberOfSites).parallel().mapToObj(posIndex -> {
            List<Position> ldPositions = this.highestLDPositions(genotypeTable, posIndex, numberOfSNPs, maxDist);
            if ((posIndex + 1) % 1000 == 0) {
                sitesDone.add(1000L);
                long done = sitesDone.longValue();
                // LD search occupies the 0..33% band of overall progress.
                this.fireProgress((int)(33L * done / numberOfSites));
                myLogger.info(String.valueOf(done));
            }
            return ldPositions;
        }).collect(Collectors.toList());
        Multimap<Position, Position> highLDMap = ArrayListMultimap.create();
        for (int posIndex = 0; posIndex < numberOfSites; ++posIndex) {
            highLDMap.putAll((Position)genotypeTable.positions().get(posIndex), ldPositionsBySite.get(posIndex));
        }
        return highLDMap;
    }

    /**
     * Returns the positions of up to {@code numberOfSNPs} sites with the highest r<sup>2</sup> to site
     * {@code posIndex}, skipping the site itself and any pair whose r<sup>2</sup> is NaN.
     *
     * <p>When {@code maxDist > -1}, only sites on the same chromosome within {@code maxDist} of the
     * target are tested; otherwise all sites, including unlinked chromosomes, are tested.
     */
    private List<Position> highestLDPositions(GenotypeTable genotypeTable, int posIndex, int numberOfSNPs, int maxDist) {
        int numberOfSites = genotypeTable.numberOfSites();
        Position focal = (Position)genotypeTable.positions().get(posIndex);
        int focalPosition = genotypeTable.chromosomalPosition(posIndex);
        // Ordering is reversed so that, once full, the queue evicts the lowest-r2 result.
        MinMaxPriorityQueue<LDResult> highestLD = MinMaxPriorityQueue.orderedBy(LDResult.byR2Ordering.reverse()).maximumSize(numberOfSNPs).create();
        for (int site2 = 0; site2 < numberOfSites; ++site2) {
            if (site2 == posIndex) {
                continue;
            }
            if (maxDist > -1) {
                // A physical distance cut-off is only meaningful within one chromosome.
                Position other = (Position)genotypeTable.positions().get(site2);
                if (!focal.getChromosome().equals(other.getChromosome())) {
                    continue;
                }
                if (Math.abs(focalPosition - genotypeTable.chromosomalPosition(site2)) > maxDist) {
                    continue;
                }
            }
            // NOTE(review): the meaning of the literal 20 is defined by LinkageDisequilibrium - confirm there.
            LDResult ld = LinkageDisequilibrium.calculateBitLDForHaplotype(false, 20, genotypeTable, posIndex, site2);
            if (Double.isNaN(ld.r2())) {
                continue;
            }
            highestLD.add(ld);
        }
        List<Position> positions = new ArrayList<Position>(highestLD.size());
        for (LDResult result : highestLD) {
            positions.add((Position)genotypeTable.positions().get(result.site2()));
        }
        return positions;
    }

    /**
     * Distance-weighted vote over neighbour genotypes.
     *
     * <p>Each (distance, genotype) entry adds {@code 1 / (1 + useLDSites * distance)} to its genotype's
     * score; the genotype with the highest score wins, ties going to the numerically smallest byte value.
     *
     * @param distGeno   neighbour distances mapped to their genotypes
     * @param useLDSites scale factor applied to the distance in the weight
     * @return the winning genotype, or -1 (missing) when {@code distGeno} is empty
     */
    static byte impute(Multimap<Double, Byte> distGeno, int useLDSites) {
        if (distGeno.isEmpty()) {
            return -1;
        }
        double[] weightedCount = new double[256];
        distGeno.entries().forEach(entry -> {
            // Shift the signed byte genotype into the 0..255 index range.
            int n = (Byte)entry.getValue() + 128;
            weightedCount[n] = weightedCount[n] + 1.0 / (1.0 + (double)useLDSites * (Double)entry.getKey());
        });
        int bestGeno = 0;
        double bestWeightedCount = weightedCount[0];
        for (int i = 1; i < 256; ++i) {
            if (!(weightedCount[i] > bestWeightedCount)) continue;
            bestWeightedCount = weightedCount[i];
            bestGeno = i;
        }
        return (byte)(bestGeno - 128);
    }

    @Override
    public String getCitation() {
        return "Daniel Money, Kyle Gardner, Heidi Schwaninger, Gan-Yuan Zhong, Sean Myles. (In Review)  LinkImpute: fast and accurate genotype imputation for non-model organisms";
    }

    @Override
    public ImageIcon getIcon() {
        return null;
    }

    @Override
    public String getButtonName() {
        return "LD KNNi Imputation";
    }

    @Override
    public String getToolTipText() {
        return "LD KNNi Imputation";
    }

    /** Regenerates the plugin's parameter accessor code. */
    public static void main(String[] args) {
        GeneratePluginCode.generate(LDKNNiImputationPlugin.class);
    }

    /** Runs the plugin and returns the imputed genotype table. */
    public GenotypeTable runPlugin(DataSet input) {
        return (GenotypeTable)this.performFunction(input).getData(0).getData();
    }

    public Integer highLDSSites() {
        return this.highLDSSites.value();
    }

    public LDKNNiImputationPlugin highLDSSites(Integer value) {
        this.highLDSSites = new PluginParameter<Integer>(this.highLDSSites, value);
        return this;
    }

    public Integer knnTaxa() {
        return this.knnTaxa.value();
    }

    public LDKNNiImputationPlugin knnTaxa(Integer value) {
        this.knnTaxa = new PluginParameter<Integer>(this.knnTaxa, value);
        return this;
    }

    public Integer maxDistance() {
        return this.maxDistance.value();
    }

    public LDKNNiImputationPlugin maxDistance(Integer value) {
        this.maxDistance = new PluginParameter<Integer>(this.maxDistance, value);
        return this;
    }

    /**
     * Genotype distance between two taxa over aligned site arrays.
     *
     * <p>Each genotype is first passed through
     * {@code GenotypeTableUtils.getUnphasedSortedDiploidValue}. Sites where either value is -1 (missing)
     * are skipped. Each remaining site contributes 0 if equal, 1 if different and either is heterozygous,
     * and 2 otherwise. The total is divided by {@code 2 * count}.
     *
     * @param b1  genotypes of the first taxon
     * @param b2  genotypes of the second taxon (indexed like {@code b1})
     * @param min minimum number of compared sites; below it the distance is NaN
     * @return {@code {distance, count}} where {@code count} is the number of sites compared
     */
    public static double[] dist(byte[] b1, byte[] b2, int min) {
        int distance = 0;
        int count = 0;
        for (int i = 0; i < b1.length; ++i) {
            byte p1 = GenotypeTableUtils.getUnphasedSortedDiploidValue(b1[i]);
            byte p2 = GenotypeTableUtils.getUnphasedSortedDiploidValue(b2[i]);
            if (p1 == -1 || p2 == -1) continue;
            ++count;
            if (p1 == p2) continue;
            if (GenotypeTableUtils.isHeterozygous(p1) || GenotypeTableUtils.isHeterozygous(p2)) {
                ++distance;
                continue;
            }
            distance += 2;
        }
        if (count < min) {
            return new double[]{Double.NaN, count};
        }
        return new double[]{(double)distance / (double)(2 * count), count};
    }
}

