/*
 * Decompiled with CFR 0.152.
 */
package net.maizegenetics.analysis.distance;

import java.util.Arrays;
import java.util.Optional;
import java.util.Spliterator;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import net.maizegenetics.analysis.distance.KinshipPlugin;
import net.maizegenetics.dna.snp.GenotypeTable;
import net.maizegenetics.dna.snp.genotypecall.AlleleFreqCache;
import net.maizegenetics.prefs.TasselPrefs;
import net.maizegenetics.taxa.TaxaList;
import net.maizegenetics.taxa.Taxon;
import net.maizegenetics.taxa.distance.DistanceMatrix;
import net.maizegenetics.taxa.distance.DistanceMatrixBuilder;
import net.maizegenetics.taxa.distance.DistanceMatrixWithCounts;
import net.maizegenetics.util.GeneralAnnotationStorage;
import net.maizegenetics.util.ProgressListener;
import net.maizegenetics.util.Tuple;
import org.apache.log4j.Logger;

/**
 * Computes GCTA-style relationship matrices from a {@link GenotypeTable}, and adds or subtracts
 * previously computed matrices of that kind.
 *
 * <p>For every pair of taxa (each taxon with itself included) the stored value is the mean, over
 * the polymorphic sites at which neither taxon is coded as missing, of
 * {@code (x_i - 2p)(x_j - 2p) / (2p(1 - p))}. Here {@code x} is the number of copies (0, 1 or 2)
 * of the site's major allele, and {@code p} is the major-allele frequency reported by
 * {@link AlleleFreqCache}.
 *
 * <p>The number of contributing sites is stored as the entry's count. Matrices built on disjoint
 * site sets can therefore be combined exactly by {@link #addGCTADistance} and
 * {@link #subtractGCTADistance}. Results are annotated with {@code Matrix_Type = Normalized_IBS}.
 */
public class GCTADistanceMatrix {
    private static final Logger myLogger = Logger.getLogger(GCTADistanceMatrix.class);

    /** Annotation key under which the matrix type is recorded. */
    private static final String MATRIX_TYPE_KEY = "Matrix_Type";

    /**
     * Genotype-code lookup indexed by {@code major << 6 | allele1 << 3 | allele2} (3 bits each).
     *
     * <p>Codes are: 1 = no copies of the major allele, 2 = one copy, 4 = two copies, and
     * 7 = missing (major code 7, or both alleles code 7). Because the non-missing codes are
     * one-hot, OR-ing two taxa's codes gives a distinct value for every pair of dosages. Any OR
     * with 7 stays 7.
     */
    private static final byte[] PRECALCULATED_COUNTS = buildPrecalculatedCounts();

    /**
     * Indexed by the OR of two taxa's packed 5-site codes (five 3-bit fields). The value is the
     * number of fields that are not 7, i.e. the number of those sites at which both taxa are
     * non-missing.
     */
    private static final byte[] INCREMENT = buildIncrement();

    /** Parallelism hint used to size splits; clamped so it can never be used as a zero divisor. */
    private static final int NUM_CORES_TO_USE = Math.max(1, TasselPrefs.getMaxThreads());

    /**
     * Sites processed so far by the current computation, used only for progress reporting.
     * Parallel splits update it concurrently, hence the atomic counter.
     */
    private static final AtomicInteger mySitesProcessed = new AtomicInteger(0);

    private GCTADistanceMatrix() {
    }

    /**
     * Computes the matrix for all taxa and sites of {@code genotype} without progress reporting.
     *
     * @param genotype genotypes to compute from
     * @return the relationship matrix, or {@code null} if the reduction produced no result
     */
    public static DistanceMatrix getInstance(GenotypeTable genotype) {
        return getInstance(genotype, null);
    }

    /**
     * Computes the matrix for all taxa and sites of {@code genotype}.
     *
     * @param genotype genotypes to compute from
     * @param listener optional progress listener (may be {@code null})
     * @return the relationship matrix, or {@code null} if the reduction produced no result
     */
    public static DistanceMatrix getInstance(GenotypeTable genotype, ProgressListener listener) {
        return computeGCTADistances(genotype, listener);
    }

    /**
     * Runs the parallel site scan, sums the per-split partial results, and converts the
     * summed distances and counts into a matrix. Pairs with no contributing site get
     * {@code 0 / 0 = NaN}.
     */
    private static DistanceMatrix computeGCTADistances(GenotypeTable genotype, ProgressListener listener) {
        int numTaxa = genotype.numberOfTaxa();
        long startTime = System.currentTimeMillis();
        Optional<CountersDistances> reduced = stream(genotype, listener).reduce((left, right) -> {
            left.addAll(right);
            return left;
        });
        if (!reduced.isPresent()) {
            return null;
        }
        CountersDistances totals = reduced.get();
        int[] counts = totals.myCounters;
        float[] distances = totals.myDistances;

        GeneralAnnotationStorage.Builder annotations = GeneralAnnotationStorage.getBuilder();
        annotations.addAnnotation(MATRIX_TYPE_KEY, KinshipPlugin.KINSHIP_METHOD.Normalized_IBS.toString());
        DistanceMatrixBuilder builder = DistanceMatrixBuilder.getInstance(genotype.taxa());
        builder.annotation(annotations.build());

        // Accumulators are laid out row-major over the upper triangle, diagonal included.
        int index = 0;
        for (int row = 0; row < numTaxa; ++row) {
            for (int col = row; col < numTaxa; ++col, ++index) {
                builder.set(row, col, (double) distances[index] / (double) counts[index]);
                builder.setCount(row, col, counts[index]);
            }
        }
        myLogger.info("GCTADistanceMatrix: computeGCTADistances time: " + (System.currentTimeMillis() - startTime) / 1000L + " seconds");
        return builder.build();
    }

    /**
     * Removes the contribution of one or more subset matrices from a superset matrix.
     *
     * <p>For each entry the result is {@code (D_super * n_super - sum(D_i * n_i)) / (n_super - sum(n_i))},
     * with count {@code n_super - sum(n_i)}. All matrices must carry the Normalized_IBS matrix-type
     * annotation and have identical taxa in identical order.
     *
     * @param matrices   subset matrices to subtract
     * @param superMatrix matrix to subtract from
     * @param listener   currently unused
     * @return the resulting matrix
     * @throws IllegalArgumentException if annotations, taxa counts or taxa order do not match
     */
    public static DistanceMatrix subtractGCTADistance(DistanceMatrixWithCounts[] matrices, DistanceMatrixWithCounts superMatrix, ProgressListener listener) {
        String normalizedIbs = KinshipPlugin.KINSHIP_METHOD.Normalized_IBS.toString();
        int numTaxa = superMatrix.numberOfTaxa();
        String[] superMatrixType = superMatrix.annotations().getTextAnnotation(MATRIX_TYPE_KEY);
        if (superMatrixType.length == 0) {
            throw new IllegalArgumentException("subtractGCTADistance: superset matrix must be created with a more recent build of Tassel that adds necessary annotations to the matrix");
        }
        String matrixType = superMatrixType[0];
        if (!matrixType.equals(normalizedIbs)) {
            throw new IllegalArgumentException("subtractGCTADistance: superset matrix must be matrix type: " + normalizedIbs);
        }
        for (DistanceMatrixWithCounts current : matrices) {
            if (current.numberOfTaxa() != numTaxa) {
                throw new IllegalArgumentException("subtractGCTADistance: subset and superset must have same number of taxa.");
            }
            String[] currentMatrixType = current.annotations().getTextAnnotation(MATRIX_TYPE_KEY);
            if (currentMatrixType.length == 0) {
                throw new IllegalArgumentException("subtractGCTADistance: subset matrix must be created with a more recent build of Tassel that adds neccessary annotations to the matrix");
            }
            if (!matrixType.equals(currentMatrixType[0])) {
                throw new IllegalArgumentException("subtractGCTADistance: subset matrix must be matrix type: " + normalizedIbs);
            }
        }
        TaxaList superTaxaList = superMatrix.getTaxaList();
        for (DistanceMatrixWithCounts current : matrices) {
            TaxaList subsetTaxaList = current.getTaxaList();
            for (int t = 0; t < numTaxa; ++t) {
                Taxon superTaxon = (Taxon) superTaxaList.get(t);
                if (!superTaxon.equals(subsetTaxaList.get(t))) {
                    throw new IllegalArgumentException("subtractGCTADistance: superset taxon: " + superTaxon.getName() + " doesn't match subset taxon: " + subsetTaxaList.taxaName(t));
                }
            }
        }

        DistanceMatrixBuilder builder = DistanceMatrixBuilder.getInstance(superTaxaList);
        GeneralAnnotationStorage.Builder resultAnnotations = GeneralAnnotationStorage.getBuilder();
        resultAnnotations.addAnnotation(MATRIX_TYPE_KEY, normalizedIbs);
        builder.annotation(resultAnnotations.build());
        for (int row = 0; row < numTaxa; ++row) {
            for (int col = row; col < numTaxa; ++col) {
                int resultCount = superMatrix.getCount(row, col);
                double resultValue = (double) superMatrix.getDistance(row, col) * (double) resultCount;
                for (DistanceMatrixWithCounts subset : matrices) {
                    int count = subset.getCount(row, col);
                    // Multiply in double: a float product loses precision that the
                    // subtraction then amplifies.
                    resultValue -= (double) subset.getDistance(row, col) * (double) count;
                    resultCount -= count;
                }
                builder.set(row, col, resultValue / (double) resultCount);
                builder.setCount(row, col, resultCount);
            }
        }
        return builder.build();
    }

    /**
     * Combines matrices computed on disjoint site sets into one count-weighted matrix.
     *
     * <p>For each entry the result is {@code sum(D_i * n_i) / sum(n_i)}, with count
     * {@code sum(n_i)}. All matrices must carry the Normalized_IBS matrix-type annotation and have
     * identical taxa in identical order.
     *
     * @param matrices matrices to combine; must contain at least one matrix
     * @param listener currently unused
     * @return the combined matrix
     * @throws IllegalArgumentException if {@code matrices} is empty, or if annotations, taxa
     *                                  counts or taxa order do not match
     */
    public static DistanceMatrix addGCTADistance(DistanceMatrixWithCounts[] matrices, ProgressListener listener) {
        if (matrices.length == 0) {
            throw new IllegalArgumentException("addGCTADistance: at least one matrix is required");
        }
        String normalizedIbs = KinshipPlugin.KINSHIP_METHOD.Normalized_IBS.toString();
        DistanceMatrixWithCounts reference = matrices[0];
        int numTaxa = reference.numberOfTaxa();
        String[] referenceMatrixType = reference.annotations().getTextAnnotation(MATRIX_TYPE_KEY);
        if (referenceMatrixType.length == 0) {
            throw new IllegalArgumentException("addGCTADistance: matrix must be created with a more recent build of Tassel that adds necessary annotations to the matrix");
        }
        String matrixType = referenceMatrixType[0];
        if (!matrixType.equals(normalizedIbs)) {
            throw new IllegalArgumentException("addGCTADistance: superset matrix must be matrix type: " + normalizedIbs);
        }
        for (int m = 1; m < matrices.length; ++m) {
            DistanceMatrixWithCounts current = matrices[m];
            if (current.numberOfTaxa() != numTaxa) {
                throw new IllegalArgumentException("addGCTADistance: all matrices must have same number of taxa.");
            }
            String[] currentMatrixType = current.annotations().getTextAnnotation(MATRIX_TYPE_KEY);
            if (currentMatrixType.length == 0) {
                throw new IllegalArgumentException("addGCTADistance: matrix must be created with a more recent build of Tassel that adds neccessary annotations to the matrix");
            }
            if (!matrixType.equals(currentMatrixType[0])) {
                throw new IllegalArgumentException("addGCTADistance: matrix must be matrix type: " + normalizedIbs);
            }
        }
        TaxaList superTaxaList = reference.getTaxaList();
        for (int m = 1; m < matrices.length; ++m) {
            TaxaList subsetTaxaList = matrices[m].getTaxaList();
            for (int t = 0; t < numTaxa; ++t) {
                Taxon superTaxon = (Taxon) superTaxaList.get(t);
                if (!superTaxon.equals(subsetTaxaList.get(t))) {
                    throw new IllegalArgumentException("addGCTADistance: superset taxon: " + superTaxon.getName() + " doesn't match subset taxon: " + subsetTaxaList.taxaName(t));
                }
            }
        }

        DistanceMatrixBuilder builder = DistanceMatrixBuilder.getInstance(superTaxaList);
        GeneralAnnotationStorage.Builder resultAnnotations = GeneralAnnotationStorage.getBuilder();
        resultAnnotations.addAnnotation(MATRIX_TYPE_KEY, normalizedIbs);
        builder.annotation(resultAnnotations.build());
        for (int row = 0; row < numTaxa; ++row) {
            for (int col = row; col < numTaxa; ++col) {
                int resultCount = 0;
                double resultValue = 0.0;
                for (DistanceMatrixWithCounts matrix : matrices) {
                    int count = matrix.getCount(row, col);
                    resultValue += (double) matrix.getDistance(row, col) * (double) count;
                    resultCount += count;
                }
                builder.set(row, col, resultValue / (double) resultCount);
                builder.setCount(row, col, resultCount);
            }
        }
        return builder.build();
    }

    /**
     * Reports {@code percent} (capped at 100) to {@code listener}; does nothing when the listener
     * is {@code null}.
     */
    protected static void fireProgress(int percent, ProgressListener listener) {
        if (listener != null) {
            if (percent > 100) {
                percent = 100;
            }
            listener.progress(percent, null);
        }
    }

    /**
     * Resets the shared progress counter and returns a parallel stream over all sites. Each
     * element is the partial result of one split.
     */
    private static Stream<CountersDistances> stream(GenotypeTable genotypes, ProgressListener listener) {
        mySitesProcessed.set(0);
        return StreamSupport.stream(new GCTASiteSpliterator(genotypes, 0, genotypes.numberOfSites(), listener), true);
    }

    /** Builds {@link #PRECALCULATED_COUNTS}. */
    private static byte[] buildPrecalculatedCounts() {
        byte[] table = new byte[512];
        for (int major = 0; major < 8; ++major) {
            for (int a = 0; a < 8; ++a) {
                for (int b = 0; b < 8; ++b) {
                    byte code;
                    if (major == 7 || (a == 7 && b == 7)) {
                        code = 7;
                    } else {
                        // 0, 1 or 2 copies of the major allele -> one-hot code 1, 2 or 4.
                        int copies = (a == major ? 1 : 0) + (b == major ? 1 : 0);
                        code = (byte) (1 << copies);
                    }
                    table[major << 6 | a << 3 | b] = code;
                }
            }
        }
        return table;
    }

    /**
     * Builds {@link #INCREMENT}. Only indexes whose five fields are all in 1..7 are filled; OR-ed
     * codes never produce a 0 field, so the other entries stay 0.
     */
    private static byte[] buildIncrement() {
        byte[] table = new byte[1 << 15];
        for (int a = 1; a < 8; ++a) {
            for (int b = 1; b < 8; ++b) {
                for (int c = 1; c < 8; ++c) {
                    for (int d = 1; d < 8; ++d) {
                        for (int e = 1; e < 8; ++e) {
                            int called = (a != 7 ? 1 : 0) + (b != 7 ? 1 : 0) + (c != 7 ? 1 : 0)
                                    + (d != 7 ? 1 : 0) + (e != 7 ? 1 : 0);
                            table[a << 12 | b << 9 | c << 6 | d << 3 | e] = (byte) called;
                        }
                    }
                }
            }
        }
        return table;
    }

    /**
     * Spliterator over a site range [{@code currentIndex}, {@code fence}).
     *
     * <p>It yields a single {@link CountersDistances} holding the summed contributions of every
     * site in its range. Splits hand off ranges of at least {@code max(numSites / cores, 1000)}
     * sites.
     */
    static class GCTASiteSpliterator
    implements Spliterator<CountersDistances> {
        private static final int NUM_SITES_PER_BLOCK = 5;
        /** Packed code with all five 3-bit fields set to 7 (missing). */
        private static final short ALL_MISSING = Short.MAX_VALUE;
        /** Number of distinct packed 5-site codes (15 bits). */
        private static final int LOOKUP_SIZE = 1 << 15;

        private int myCurrentSite;
        private final int myFence;
        private final GenotypeTable myGenotypes;
        private final int myNumTaxa;
        private final int myNumSites;
        private final ProgressListener myProgressListener;
        private final int myMinSitesToProcess;
        private final int myNumSitesPerBlockForProgressReporting;

        GCTASiteSpliterator(GenotypeTable genotypes, int currentIndex, int fence, ProgressListener listener) {
            this.myGenotypes = genotypes;
            this.myNumTaxa = genotypes.numberOfTaxa();
            this.myNumSites = genotypes.numberOfSites();
            this.myCurrentSite = currentIndex;
            this.myFence = fence;
            this.myProgressListener = listener;
            this.myMinSitesToProcess = Math.max(this.myNumSites / NUM_CORES_TO_USE, 1000);
            // Must be at least 1: a zero-sized progress block would never advance
            // myCurrentSite and forEachRemaining would loop forever (ranges < 10 sites).
            this.myNumSitesPerBlockForProgressReporting = Math.max(1, (fence - currentIndex) / 10);
        }

        /**
         * Processes every remaining site, three 5-site blocks at a time. The single accumulated
         * result is handed to {@code action}, even when the range is empty.
         */
        @Override
        public void forEachRemaining(Consumer<? super CountersDistances> action) {
            CountersDistances result = new CountersDistances(myNumTaxa);
            int[] counts = result.myCounters;
            float[] distances = result.myDistances;
            float[] answer1 = new float[LOOKUP_SIZE];
            float[] answer2 = new float[LOOKUP_SIZE];
            float[] answer3 = new float[LOOKUP_SIZE];
            while (myCurrentSite < myFence) {
                int blockStart = myCurrentSite;
                int currentBlockFence = Math.min(myCurrentSite + myNumSitesPerBlockForProgressReporting, myFence);
                while (myCurrentSite < currentBlockFence) {
                    // Each block scans forward until it holds 5 informative sites or reaches the
                    // fence, so each block starts where the previous one stopped.
                    SiteBlock block1 = getBlockOfSites(myCurrentSite);
                    SiteBlock block2 = getBlockOfSites(myCurrentSite + block1.numSitesConsumed);
                    SiteBlock block3 = getBlockOfSites(myCurrentSite + block1.numSitesConsumed + block2.numSitesConsumed);
                    myCurrentSite += block1.numSitesConsumed + block2.numSitesConsumed + block3.numSitesConsumed;

                    fillAnswers(block1.possibleTerms, answer1);
                    fillAnswers(block2.possibleTerms, answer2);
                    fillAnswers(block3.possibleTerms, answer3);

                    short[] codes1 = block1.codes;
                    short[] codes2 = block2.codes;
                    short[] codes3 = block3.codes;
                    int index = 0;
                    for (int firstTaxon = 0; firstTaxon < myNumTaxa; ++firstTaxon) {
                        short first1 = codes1[firstTaxon];
                        short first2 = codes2[firstTaxon];
                        short first3 = codes3[firstTaxon];
                        if (first1 == ALL_MISSING && first2 == ALL_MISSING && first3 == ALL_MISSING) {
                            // Nothing to add for this whole row of the upper triangle.
                            index += myNumTaxa - firstTaxon;
                            continue;
                        }
                        for (int secondTaxon = firstTaxon; secondTaxon < myNumTaxa; ++secondTaxon, ++index) {
                            int key1 = first1 | codes1[secondTaxon];
                            int key2 = first2 | codes2[secondTaxon];
                            int key3 = first3 | codes3[secondTaxon];
                            distances[index] += answer1[key1] + answer2[key2] + answer3[key3];
                            counts[index] += INCREMENT[key1] + INCREMENT[key2] + INCREMENT[key3];
                        }
                    }
                }
                // Count what was actually consumed; the inner loop may overshoot the block fence.
                int totalProcessed = mySitesProcessed.addAndGet(myCurrentSite - blockStart);
                fireProgress((int) ((double) totalProcessed / (double) myNumSites * 100.0), myProgressListener);
            }
            action.accept(result);
        }

        /**
         * Fills {@code answer[i]} with the summed product terms for packed code {@code i}.
         * Site slot {@code s} occupies bits {@code 12 - 3s .. 14 - 3s} of {@code i} and its
         * 8 terms live at {@code terms[8s .. 8s + 7]}.
         */
        private static void fillAnswers(float[] terms, float[] answer) {
            for (int i = 0; i < answer.length; ++i) {
                answer[i] = terms[(i & 0x7000) >>> 12]
                        + terms[(i & 0xE00) >>> 9 | 8]
                        + terms[(i & 0x1C0) >>> 6 | 0x10]
                        + terms[(i & 0x38) >>> 3 | 0x18]
                        + terms[i & 7 | 0x20];
            }
        }

        /**
         * Scans forward from {@code startSite} (stopping at {@code myFence}) until 5 informative
         * sites are collected. Sites whose major allele is 15 or whose {@code 2p(1-p)} is 0 are
         * skipped but still consumed.
         *
         * <p>For informative slot {@code s}, {@code possibleTerms[8s + k]} holds the product term
         * for OR-ed pair code {@code k} (indexes 0 and 7 stay 0). Each taxon's 3-bit genotype
         * code is packed into its short at bit offset {@code 3 * (4 - s)}. Unused slots remain 7.
         */
        private SiteBlock getBlockOfSites(int startSite) {
            int site = startSite;
            int usedSlots = 0;
            float[] possibleTerms = new float[NUM_SITES_PER_BLOCK * 8];
            short[] codes = new short[myNumTaxa];
            Arrays.fill(codes, ALL_MISSING);
            while (usedSlots < NUM_SITES_PER_BLOCK && site < myFence) {
                byte[] genotypes = myGenotypes.genotypeAllTaxa(site);
                int[][] alleleCounts = AlleleFreqCache.allelesSortedByFrequencyNucleotide(genotypes);
                byte major = AlleleFreqCache.majorAllele(alleleCounts);
                float majorFreq = (float) AlleleFreqCache.majorAlleleFrequency(alleleCounts);
                float majorFreqTimes2 = majorFreq * 2.0f;
                float denominatorTerm = majorFreqTimes2 * (1.0f - majorFreq);
                if (major != 15 && denominatorTerm != 0.0f) {
                    // Centered dosages for 0, 1 and 2 copies of the major allele.
                    float term0 = 0.0f - majorFreqTimes2;
                    float term1 = 1.0f - majorFreqTimes2;
                    float term2 = 2.0f - majorFreqTimes2;
                    int offset = usedSlots * 8;
                    possibleTerms[offset + 1] = term0 * term0 / denominatorTerm; // 1|1
                    possibleTerms[offset + 3] = term0 * term1 / denominatorTerm; // 1|2
                    possibleTerms[offset + 5] = term0 * term2 / denominatorTerm; // 1|4
                    possibleTerms[offset + 2] = term1 * term1 / denominatorTerm; // 2|2
                    possibleTerms[offset + 6] = term1 * term2 / denominatorTerm; // 2|4
                    possibleTerms[offset + 4] = term2 * term2 / denominatorTerm; // 4|4
                    int majorBits = (major & 7) << 6;
                    int shift = (NUM_SITES_PER_BLOCK - 1 - usedSlots) * 3;
                    int mask = ~(7 << shift) & Short.MAX_VALUE;
                    for (int i = 0; i < myNumTaxa; ++i) {
                        int genotype = genotypes[i];
                        int code = PRECALCULATED_COUNTS[majorBits | (genotype & 0x70) >>> 1 | genotype & 7];
                        // The slot starts as 111, so AND-ing writes exactly the new code.
                        codes[i] = (short) (codes[i] & (mask | code << shift));
                    }
                    ++usedSlots;
                }
                ++site;
            }
            return new SiteBlock(codes, possibleTerms, site - startSite);
        }

        @Override
        public boolean tryAdvance(Consumer<? super CountersDistances> action) {
            if (myCurrentSite < myFence) {
                forEachRemaining(action);
                return true;
            }
            return false;
        }

        @Override
        public Spliterator<CountersDistances> trySplit() {
            int lo = myCurrentSite;
            int mid = lo + myMinSitesToProcess;
            if (mid < myFence) {
                myCurrentSite = mid;
                return new GCTASiteSpliterator(myGenotypes, lo, mid, myProgressListener);
            }
            return null;
        }

        @Override
        public long estimateSize() {
            return myFence - myCurrentSite;
        }

        @Override
        public int characteristics() {
            return Spliterator.IMMUTABLE;
        }

        /** Result of {@link #getBlockOfSites}: packed codes, term table, and sites scanned. */
        private static final class SiteBlock {
            final short[] codes;
            final float[] possibleTerms;
            final int numSitesConsumed;

            SiteBlock(short[] codes, float[] possibleTerms, int numSitesConsumed) {
                this.codes = codes;
                this.possibleTerms = possibleTerms;
                this.numSitesConsumed = numSitesConsumed;
            }
        }
    }

    /**
     * Upper-triangular (diagonal included, row-major) accumulators of summed product terms and
     * contributing-site counts for every taxa pair.
     */
    private static class CountersDistances {
        private final int[] myCounters;
        private final float[] myDistances;
        private final int myNumTaxa;

        public CountersDistances(int numTaxa) {
            this.myNumTaxa = numTaxa;
            // Compute in long: numTaxa * (numTaxa + 1) overflows int for >= 46341 taxa even
            // when the halved size still fits.
            int size = Math.toIntExact((long) numTaxa * (numTaxa + 1) / 2);
            this.myCounters = new int[size];
            this.myDistances = new float[size];
        }

        /** Adds {@code other}'s distances and counts element-wise into this instance. */
        public void addAll(CountersDistances other) {
            float[] otherDistances = other.myDistances;
            int[] otherCounters = other.myCounters;
            for (int i = 0; i < myCounters.length; ++i) {
                myDistances[i] += otherDistances[i];
            }
            for (int i = 0; i < myCounters.length; ++i) {
                myCounters[i] += otherCounters[i];
            }
        }
    }
}

