/*
 * Decompiled with CFR 0.152.
 */
package net.maizegenetics.analysis.association;

import java.awt.Frame;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.swing.ImageIcon;
import net.maizegenetics.analysis.association.AssociationUtils;
import net.maizegenetics.dna.map.Position;
import net.maizegenetics.dna.snp.GenotypeTable;
import net.maizegenetics.dna.snp.GenotypeTableUtils;
import net.maizegenetics.phenotype.CategoricalAttribute;
import net.maizegenetics.phenotype.GenotypePhenotype;
import net.maizegenetics.phenotype.NumericAttribute;
import net.maizegenetics.phenotype.Phenotype;
import net.maizegenetics.phenotype.PhenotypeAttribute;
import net.maizegenetics.plugindef.AbstractPlugin;
import net.maizegenetics.plugindef.DataSet;
import net.maizegenetics.plugindef.Datum;
import net.maizegenetics.plugindef.Plugin;
import net.maizegenetics.plugindef.PluginParameter;
import net.maizegenetics.prefs.TasselPrefs;
import net.maizegenetics.stats.linearmodels.CovariateModelEffect;
import net.maizegenetics.stats.linearmodels.FactorModelEffect;
import net.maizegenetics.stats.linearmodels.LinearModelUtils;
import net.maizegenetics.stats.linearmodels.ModelEffect;
import net.maizegenetics.stats.linearmodels.SolveByOrthogonalizing;
import net.maizegenetics.util.TableReport;
import net.maizegenetics.util.TableReportBuilder;
import org.apache.commons.math3.distribution.FDistribution;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.log4j.Logger;

public class FastMultithreadedAssociationPlugin
extends AbstractPlugin {
    private static Logger myLogger = Logger.getLogger(FastMultithreadedAssociationPlugin.class);
    private GenotypeTable.GENOTYPE_TABLE_COMPONENT[] GENOTYPE_COMP = new GenotypeTable.GENOTYPE_TABLE_COMPONENT[]{GenotypeTable.GENOTYPE_TABLE_COMPONENT.Genotype, GenotypeTable.GENOTYPE_TABLE_COMPONENT.ReferenceProbability, GenotypeTable.GENOTYPE_TABLE_COMPONENT.AlleleProbability};
    private final byte NN = (byte)-1;
    private Phenotype myPhenotype;
    private GenotypeTable myGenotype;
    List<String> phenotypeNames;
    double minR2;
    private FDistribution Fdist;
    GenotypePhenotype myGenoPheno;
    private PluginParameter<Double> maxp = new PluginParameter.Builder<Double>("MaxPValue", 0.001, Double.class).guiName("MaxPValue").description("The maximum p-value that will be output by the analysis.").build();
    private PluginParameter<GenotypeTable.GENOTYPE_TABLE_COMPONENT> myGenotypeTable = new PluginParameter.Builder<GenotypeTable.GENOTYPE_TABLE_COMPONENT>("genotypeComponent", GenotypeTable.GENOTYPE_TABLE_COMPONENT.Genotype, GenotypeTable.GENOTYPE_TABLE_COMPONENT.class).genotypeTable().range((Comparable<T>[])this.GENOTYPE_COMP).description("If the genotype table contains more than one type of genotype data, choose the type to use for the analysis.").build();
    private PluginParameter<Boolean> saveAsFile = new PluginParameter.Builder<Boolean>("writeToFile", false, Boolean.class).description("Should the results be saved to a file rather than stored in memory? It true, the results will be written to a file as each SNP is analyzed in order to reduce memory requirementsand the results will NOT be saved to the data tree. Default = false.").guiName("Write to file").build();
    private PluginParameter<String> reportFilename = new PluginParameter.Builder<String>("outputFile", null, String.class).outFile().dependentOnParameter(this.saveAsFile).description("The name of the file to which these results will be saved.").guiName("Output File").build();
    private PluginParameter<Integer> maxThreads = new PluginParameter.Builder<Integer>("maxThreads", TasselPrefs.getMaxThreads(), Integer.class).description("the maximum number of threads to be used by this plugin.").guiName("Max Threads").build();

    public FastMultithreadedAssociationPlugin() {
        this(null, false);
    }

    public FastMultithreadedAssociationPlugin(Frame parentFrame, boolean isInteractive) {
        super(parentFrame, isInteractive);
    }

    @Override
    protected void preProcessParameters(DataSet input) {
        List<Datum> inData = input.getDataOfType(GenotypePhenotype.class);
        if (inData.size() != 1) {
            throw new IllegalArgumentException("Fast Association requires exactly one joined genotype-phenotype data set.");
        }
    }

    @Override
    public DataSet processData(DataSet input) {
        long start = System.currentTimeMillis();
        int maxSitesInQueue = 2000;
        int maxObjectsInQueue = 1000;
        Datum inDatum = input.getDataOfType(GenotypePhenotype.class).get(0);
        this.myGenoPheno = (GenotypePhenotype)inDatum.getData();
        this.myGenotype = this.myGenoPheno.genotypeTable();
        this.myPhenotype = this.myGenoPheno.phenotype();
        int numberOfObservations = this.myPhenotype.numberOfObservations();
        this.testMissingDataInTheBaseModel();
        SolveByOrthogonalizing sbo = this.initializeOrthogonalizer();
        double errdf = numberOfObservations - sbo.baseDf() - 1;
        this.Fdist = new FDistribution(1.0, errdf);
        this.calculateR2Fromp(errdf);
        TableReportBuilder myReport = this.initializeOutput(inDatum);
        int nthreads = this.maxThreads.value();
        nthreads = Math.max(nthreads, 2);
        int siteTesterThreads = nthreads - 1;
        ExecutorService myExecutor = Executors.newFixedThreadPool(nthreads);
        LinkedBlockingQueue<Object[]> reportQueue = new LinkedBlockingQueue<Object[]>();
        LinkedBlockingQueue<Marker> siteQueue = new LinkedBlockingQueue<Marker>(maxSitesInQueue);
        List<double[]> dataList = sbo.getOrthogonalizedData();
        List<double[]> uList = sbo.getUColumns();
        for (int i = 0; i < siteTesterThreads; ++i) {
            myExecutor.execute(new SiteTester(dataList, this.phenotypeNames, uList, siteQueue, reportQueue, this.minR2, errdf, numberOfObservations));
        }
        myExecutor.execute(new ReportWriter(myReport, reportQueue, siteTesterThreads));
        System.out.printf("Time to set up threads = %d ms.\n", System.currentTimeMillis() - start);
        start = System.currentTimeMillis();
        int nsites = this.myGenotype.numberOfSites();
        System.out.printf("myGenotype has %d sites\n", nsites);
        for (int s = 0; s < nsites; ++s) {
            if (s % 1000000 == 0) {
                myLogger.info((Object)("Adding site " + s + " to the site queue."));
            }
            try {
                byte major = this.myGenotype.majorAllele(s);
                double freq = this.myGenotype.majorAlleleFrequency(s);
                byte[] geno = this.myGenoPheno.genotypeAllTaxa(s);
                siteQueue.put(new Marker(geno, major, freq, (Position)this.myGenotype.positions().get(s)));
                continue;
            }
            catch (Exception e) {
                throw new RuntimeException("Site thread interrupted at site " + s, e);
            }
        }
        for (int i = 0; i < nthreads; ++i) {
            byte zerobyte = 0;
            try {
                siteQueue.put(new Marker(new byte[0], zerobyte, 0.0, null));
                continue;
            }
            catch (InterruptedException e) {
                throw new RuntimeException("siteQueue interrupted", e);
            }
        }
        myExecutor.shutdown();
        try {
            myExecutor.awaitTermination(1L, TimeUnit.HOURS);
        }
        catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.printf("Time to process sites = %d ms.\n", System.currentTimeMillis() - start);
        if (this.saveAsFile.value().booleanValue()) {
            myReport.build();
            return null;
        }
        String name = String.format("Fast Association_%s", inDatum.getName());
        String comment = String.format("Fast Association Test Results\n Source = %s", inDatum.getName());
        return new DataSet(new Datum(name, myReport.build(), comment), (Plugin)this);
    }

    private void testMissingDataInTheBaseModel() {
        for (PhenotypeAttribute attr : this.myPhenotype.attributeListOfType(Phenotype.ATTRIBUTE_TYPE.factor)) {
            if (attr.missing().cardinality() <= 0L) continue;
            String msg = "There is missing data in the factor " + attr.name();
            throw new IllegalArgumentException(msg);
        }
        for (PhenotypeAttribute attr : this.myPhenotype.attributeListOfType(Phenotype.ATTRIBUTE_TYPE.covariate)) {
            if (attr.missing().cardinality() <= 0L) continue;
            String msg = "There is missing data in the covariate " + attr.name();
            throw new IllegalArgumentException(msg);
        }
        for (PhenotypeAttribute attr : this.myPhenotype.attributeListOfType(Phenotype.ATTRIBUTE_TYPE.data)) {
            if (attr.missing().cardinality() <= 0L) continue;
            String msg = "There is missing data in the phenotype " + attr.name();
            throw new IllegalArgumentException(msg);
        }
    }

    private SolveByOrthogonalizing initializeOrthogonalizer() {
        List<PhenotypeAttribute> phenotypeList = this.myPhenotype.attributeListOfType(Phenotype.ATTRIBUTE_TYPE.data);
        List<PhenotypeAttribute> covariateList = this.myPhenotype.attributeListOfType(Phenotype.ATTRIBUTE_TYPE.covariate);
        List<PhenotypeAttribute> factorList = this.myPhenotype.attributeListOfType(Phenotype.ATTRIBUTE_TYPE.factor);
        ArrayList<ModelEffect> baseModel = new ArrayList<ModelEffect>();
        for (PhenotypeAttribute pa2 : factorList) {
            CategoricalAttribute ca = (CategoricalAttribute)pa2;
            baseModel.add(new FactorModelEffect(ca.allIntValues(), true, ca.name()));
        }
        for (PhenotypeAttribute pa2 : covariateList) {
            NumericAttribute na = (NumericAttribute)pa2;
            CovariateModelEffect cme = new CovariateModelEffect(AssociationUtils.convertFloatArrayToDouble(na.floatValues()), na.name());
            baseModel.add(cme);
        }
        List<double[]> dataList = phenotypeList.stream().map(pa -> (float[])pa.allValues()).map(a -> AssociationUtils.convertFloatArrayToDouble(a)).collect(Collectors.toList());
        this.phenotypeNames = phenotypeList.stream().map(PhenotypeAttribute::name).collect(Collectors.toList());
        return SolveByOrthogonalizing.getInstanceFromModel(baseModel, dataList);
    }

    private TableReportBuilder initializeOutput(Datum myDatum) {
        Object[] columnNames = new String[]{"Trait", "Marker", "Chr", "Pos", "df", "r2", "p"};
        String name = "EqtlReport_" + myDatum.getName();
        if (this.saveAsFile.value().booleanValue()) {
            return TableReportBuilder.getInstance(name, columnNames, this.reportFilename.value());
        }
        return TableReportBuilder.getInstance(name, columnNames);
    }

    private void calculateR2Fromp(double errdf) {
        double p = 1.0 - this.maxp.value();
        try {
            double F = this.Fdist.inverseCumulativeProbability(p);
            this.minR2 = F / (errdf + F);
        }
        catch (OutOfRangeException e) {
            e.printStackTrace();
            this.minR2 = Double.NaN;
        }
    }

    @Override
    public ImageIcon getIcon() {
        return null;
    }

    @Override
    public String getButtonName() {
        return "Fast-MT Association";
    }

    @Override
    public String getToolTipText() {
        return "Multi-threaded version of Fast Association";
    }

    public TableReport runPlugin(DataSet input) {
        return (TableReport)this.performFunction(input).getData(0).getData();
    }

    public Double maxp() {
        return this.maxp.value();
    }

    public FastMultithreadedAssociationPlugin maxp(Double value) {
        this.maxp = new PluginParameter<Double>(this.maxp, value);
        return this;
    }

    public GenotypeTable.GENOTYPE_TABLE_COMPONENT genotypeTable() {
        return this.myGenotypeTable.value();
    }

    public FastMultithreadedAssociationPlugin genotypeTable(GenotypeTable.GENOTYPE_TABLE_COMPONENT value) {
        this.myGenotypeTable = new PluginParameter<GenotypeTable.GENOTYPE_TABLE_COMPONENT>(this.myGenotypeTable, value);
        return this;
    }

    public Boolean saveAsFile() {
        return this.saveAsFile.value();
    }

    public FastMultithreadedAssociationPlugin saveAsFile(Boolean value) {
        this.saveAsFile = new PluginParameter<Boolean>(this.saveAsFile, value);
        return this;
    }

    public String reportFilename() {
        return this.reportFilename.value();
    }

    public FastMultithreadedAssociationPlugin reportFilename(String value) {
        this.reportFilename = new PluginParameter<String>(this.reportFilename, value);
        return this;
    }

    public Integer maxThreads() {
        return this.maxThreads.value();
    }

    public FastMultithreadedAssociationPlugin maxThreads(Integer value) {
        this.maxThreads = new PluginParameter<Integer>(this.maxThreads, value);
        return this;
    }

    class Marker {
        byte[] geno;
        byte major;
        double majorFrequency;
        Position myPosition;

        Marker(byte[] geno, byte major, double majorFreq, Position pos) {
            this.geno = geno;
            this.major = major;
            this.majorFrequency = majorFreq;
            this.myPosition = pos;
        }
    }

    class ReportWriter
    extends Thread {
        TableReportBuilder myReportBuilder;
        BlockingQueue<Object[]> myReportQueue;
        int numberOfSources;

        ReportWriter(TableReportBuilder reportBuilder, BlockingQueue<Object[]> reportQueue, int numberOfSourceThreads) {
            this.myReportBuilder = reportBuilder;
            this.myReportQueue = reportQueue;
            this.numberOfSources = numberOfSourceThreads;
        }

        @Override
        public void run() {
            int numberOfFinishedThreads = 0;
            try {
                do {
                    Object[] reportRow;
                    if ((reportRow = this.myReportQueue.poll(30L, TimeUnit.MINUTES)) == null) {
                        throw new IllegalStateException("ERROR: report queue timed out.");
                    }
                    if (reportRow.length > 0) {
                        this.myReportBuilder.add(reportRow);
                        continue;
                    }
                    System.out.printf("number of threads finished = %d\n", ++numberOfFinishedThreads);
                } while (numberOfFinishedThreads < this.numberOfSources);
            }
            catch (InterruptedException e) {
                throw new RuntimeException("Report thread was interrupted.", e);
            }
            System.out.println("report thread finished");
        }
    }

    class SiteTester
    extends Thread {
        final List<double[]> orthogonalPhenotypes;
        final List<String> phenotypeNames;
        final List<double[]> Ucolumns;
        final BlockingQueue<Marker> siteQueue;
        final BlockingQueue<Object[]> outQueue;
        final double minR2;
        final int nphenotypes;
        final int numberOfObservations;
        final double errdf;
        private final FDistribution Fdist;

        SiteTester(List<double[]> orthogonalPhenotypes, List<String> phenotypeNames, List<double[]> Ucol, BlockingQueue<Marker> siteQueue, BlockingQueue<Object[]> outQueue, double minRSquare, double errdf, int ntaxa) {
            this.orthogonalPhenotypes = orthogonalPhenotypes;
            this.phenotypeNames = phenotypeNames;
            this.Ucolumns = Ucol;
            this.siteQueue = siteQueue;
            this.outQueue = outQueue;
            this.minR2 = minRSquare;
            this.numberOfObservations = ntaxa;
            this.errdf = errdf;
            this.nphenotypes = orthogonalPhenotypes.size();
            this.Fdist = new FDistribution(1.0, errdf);
        }

        @Override
        public void run() {
            try {
                Marker thisMarker = this.siteQueue.poll(30L, TimeUnit.SECONDS);
                if (thisMarker == null) {
                    this.outQueue.put(new Object[0]);
                    throw new IllegalStateException("ERROR: The site tester timeout was exceeded.");
                }
                byte[] geno = thisMarker.geno;
                while (geno.length > 0) {
                    byte major = thisMarker.major;
                    double genoMean = thisMarker.majorFrequency;
                    double[] siteValues = new double[this.numberOfObservations];
                    for (int t = 0; t < this.numberOfObservations; ++t) {
                        if (geno[t] == -1) {
                            siteValues[t] = 0.0;
                            continue;
                        }
                        siteValues[t] = -genoMean;
                        byte[] alleles = GenotypeTableUtils.getDiploidValues(geno[t]);
                        if (alleles[0] == major) {
                            int n = t;
                            siteValues[n] = siteValues[n] + 0.5;
                        }
                        if (alleles[1] != major) continue;
                        int n = t;
                        siteValues[n] = siteValues[n] + 0.5;
                    }
                    double sum = 0.0;
                    for (double d : siteValues) {
                        sum += d;
                    }
                    sum /= (double)this.numberOfObservations;
                    siteValues = this.orthogonalizeByBase(siteValues);
                    if ((siteValues = SolveByOrthogonalizing.centerAndScale(siteValues)) == null) {
                        System.err.printf("siteValues null at position %d, probably invariant\n", thisMarker.myPosition.getPosition());
                    } else {
                        double[] r2values = new double[this.nphenotypes];
                        for (int p = 0; p < this.nphenotypes; ++p) {
                            double sumprod = 0.0;
                            double[] pheno = this.orthogonalPhenotypes.get(p);
                            for (int t = 0; t < this.numberOfObservations; ++t) {
                                sumprod += siteValues[t] * pheno[t];
                            }
                            r2values[p] = sumprod * sumprod;
                        }
                        this.outputResult(r2values, thisMarker.myPosition);
                    }
                    thisMarker = this.siteQueue.poll(2L, TimeUnit.SECONDS);
                    if (thisMarker == null) {
                        this.outQueue.put(new Object[0]);
                        throw new IllegalStateException("Error: The site tester timeout was exceeded.");
                    }
                    geno = thisMarker.geno;
                }
                this.outQueue.put(new Object[0]);
            }
            catch (InterruptedException e) {
                throw new RuntimeException("InterruptedException occurred in SiteTester thread", e);
            }
        }

        private double[] orthogonalizeByBase(double[] vector) {
            if (this.Ucolumns == null || this.Ucolumns.size() == 0) {
                return vector;
            }
            int nrows = vector.length;
            double[] result = Arrays.copyOf(vector, nrows);
            for (double[] u : this.Ucolumns) {
                double ip = SolveByOrthogonalizing.innerProduct(vector, u);
                for (int j = 0; j < nrows; ++j) {
                    int n = j;
                    result[n] = result[n] - ip * u[j];
                }
            }
            return result;
        }

        private void outputResult(double[] rvalues, Position pos) throws InterruptedException {
            for (int p = 0; p < this.nphenotypes; ++p) {
                if (!(rvalues[p] >= this.minR2)) continue;
                Object[] result = new Object[]{this.phenotypeNames.get(p), pos.getSNPID(), pos.getChromosome().getName(), pos.getPosition(), 1, rvalues[p], this.pvalue(rvalues[p])};
                this.outQueue.put(result);
            }
        }

        private double pvalue(double rvalue) {
            double p;
            double F = rvalue / (1.0 - rvalue) * this.errdf;
            try {
                p = LinearModelUtils.Ftest(F, 1.0, this.errdf);
            }
            catch (Exception e) {
                p = Double.NaN;
            }
            return p;
        }
    }
}

