/*
 * Decompiled with CFR 0.152.
 */
package net.maizegenetics.analysis.association;

import java.awt.Frame;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.IntStream;
import javax.swing.ImageIcon;
import net.maizegenetics.analysis.association.AssociationUtils;
import net.maizegenetics.matrixalgebra.Matrix.DoubleMatrix;
import net.maizegenetics.matrixalgebra.Matrix.DoubleMatrixFactory;
import net.maizegenetics.phenotype.CategoricalAttribute;
import net.maizegenetics.phenotype.NumericAttribute;
import net.maizegenetics.phenotype.Phenotype;
import net.maizegenetics.phenotype.PhenotypeAttribute;
import net.maizegenetics.phenotype.PhenotypeBuilder;
import net.maizegenetics.phenotype.TaxaAttribute;
import net.maizegenetics.plugindef.AbstractPlugin;
import net.maizegenetics.plugindef.DataSet;
import net.maizegenetics.plugindef.Datum;
import net.maizegenetics.plugindef.Plugin;
import net.maizegenetics.plugindef.PluginParameter;
import net.maizegenetics.stats.EMMA.EMMAforDoubleMatrix;
import net.maizegenetics.stats.linearmodels.BasicShuffler;
import net.maizegenetics.stats.linearmodels.FactorModelEffect;
import net.maizegenetics.taxa.TaxaList;
import net.maizegenetics.taxa.TaxaListBuilder;
import net.maizegenetics.taxa.TaxaListUtils;
import net.maizegenetics.taxa.Taxon;
import net.maizegenetics.taxa.distance.DistanceMatrix;
import net.maizegenetics.taxa.distance.DistanceMatrixUtils;
import net.maizegenetics.util.TableReportBuilder;
import org.apache.commons.math3.stat.correlation.PearsonsCorrelation;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.apache.log4j.Logger;

public class GenomicSelectionPlugin
extends AbstractPlugin {
    private static final Logger myLogger = Logger.getLogger(GenomicSelectionPlugin.class);
    private PluginParameter<Boolean> performCrossValidation = new PluginParameter.Builder<Boolean>("doCV", true, Boolean.class).description("Perform cross-validation: True or False").guiName("Perform cross-validation").build();
    private PluginParameter<Integer> kFolds = new PluginParameter.Builder<Integer>("kFolds", 5, Integer.class).description("Number of folds to use for k-fold cross-validation (default = 5)").guiName("Number of folds").dependentOnParameter(this.performCrossValidation).build();
    private PluginParameter<Integer> nIterations = new PluginParameter.Builder<Integer>("nIter", 20, Integer.class).description("Number of iterations when running k-fold cross-validation (default = 20)").guiName("Number of iterations").dependentOnParameter(this.performCrossValidation).build();

    public GenomicSelectionPlugin(Frame parentFrame, boolean isInteractive) {
        super(parentFrame, isInteractive);
    }

    @Override
    public DataSet processData(DataSet input) {
        List<Datum> myDataList = input.getDataOfType(Phenotype.class);
        if (myDataList.size() == 0) {
            throw new IllegalArgumentException("No phenotype selected.");
        }
        if (myDataList.size() > 1) {
            throw new IllegalArgumentException("Too many phenotypes selected.");
        }
        Phenotype myPhenotype = (Phenotype)myDataList.get(0).getData();
        String inputPhenotypeName = myDataList.get(0).getName();
        myDataList = input.getDataOfType(DistanceMatrix.class);
        if (myDataList.size() == 0) {
            throw new IllegalArgumentException("No kinship matrix selected.");
        }
        if (myDataList.size() > 1) {
            throw new IllegalArgumentException("Too many kinship matrices selected.");
        }
        DistanceMatrix kinship = (DistanceMatrix)myDataList.get(0).getData();
        TaxaList phenoTaxa = myPhenotype.taxa();
        TaxaList kinTaxa = kinship.getTaxaList();
        TaxaList jointTaxa = TaxaListUtils.getCommonTaxa(phenoTaxa, kinTaxa);
        Phenotype reducedPheno = new PhenotypeBuilder().fromPhenotype(myPhenotype).keepTaxa(jointTaxa).build().get(0);
        if (this.performCrossValidation.value().booleanValue()) {
            return this.processDataforCrossValidation(reducedPheno, kinship, inputPhenotypeName);
        }
        return this.processDataforPrediction(reducedPheno, kinship, inputPhenotypeName);
    }

    public DataSet processDataforPrediction(Phenotype myPhenotype, DistanceMatrix kinshipMatrix, String inputPhenotypeName) {
        Object[] columnHeaders = new String[]{"Trait", "Taxon", "Observed", "Predicted", "PEV"};
        String tableName = "Genomic Prediction Results";
        TableReportBuilder myReportBuilder = TableReportBuilder.getInstance(tableName, columnHeaders);
        List<Taxon> phenoTaxa = myPhenotype.taxaAttribute().allTaxaAsList();
        TaxaList phenoTaxaList = new TaxaListBuilder().addAll((Collection<Taxon>)phenoTaxa).build();
        DistanceMatrix myKinship = DistanceMatrixUtils.keepTaxa(kinshipMatrix, phenoTaxaList);
        for (PhenotypeAttribute attr : myPhenotype.attributeListOfType(Phenotype.ATTRIBUTE_TYPE.data)) {
            NumericAttribute dataAttribute = (NumericAttribute)attr;
            String traitname = dataAttribute.name();
            double[] phenotypeData = AssociationUtils.convertFloatArrayToDouble(dataAttribute.floatValues());
            int nObs = phenotypeData.length;
            DoubleMatrix phenotype = DoubleMatrixFactory.DEFAULT.make(nObs, 1, phenotypeData);
            DoubleMatrix fixedEffects = this.fixedEffectMatrix(myPhenotype);
            DoubleMatrix kinship = DoubleMatrixFactory.DEFAULT.make(myKinship.getClonedDistances());
            EMMAforDoubleMatrix runEMMA = new EMMAforDoubleMatrix(phenotype, fixedEffects, kinship);
            runEMMA.setCalculatePEV(true);
            runEMMA.solve();
            runEMMA.calculateBlupsPredicted();
            TaxaAttribute myTaxa = myPhenotype.taxaAttribute();
            DoubleMatrix predictedValues = runEMMA.getPred();
            DoubleMatrix pevs = runEMMA.getPev();
            for (int obs = 0; obs < nObs; ++obs) {
                Object[] reportRow = new Object[]{traitname, myTaxa.taxon(obs).getName(), phenotype.get(obs, 0), predictedValues.get(obs, 0), pevs.get(obs, 0)};
                myReportBuilder.add(reportRow);
            }
        }
        String datumName = "Prediction_" + inputPhenotypeName;
        String comment = "Genomic Prediction for " + inputPhenotypeName;
        DataSet myReportSet = new DataSet(new Datum(datumName, myReportBuilder.build(), comment), (Plugin)this);
        return myReportSet;
    }

    public DataSet processDataforCrossValidation(Phenotype reducedPheno, DistanceMatrix kinshipOriginal, String inputPhenotypeName) {
        Object[] columnHeaders = new String[]{"Trait", "Iteration", "Fold", "Accuracy"};
        String tableName = "Genomic Prediction Accuracy";
        TableReportBuilder myReportBuilder = TableReportBuilder.getInstance(tableName, columnHeaders);
        ArrayList<String> commentList = new ArrayList<String>();
        int[] taxaAttrIndex = reducedPheno.attributeIndicesOfType(Phenotype.ATTRIBUTE_TYPE.taxa);
        int[] dataAttrIndex = reducedPheno.attributeIndicesOfType(Phenotype.ATTRIBUTE_TYPE.data);
        int[] factorAttrIndex = reducedPheno.attributeIndicesOfType(Phenotype.ATTRIBUTE_TYPE.factor);
        int[] covariateAttrIndex = reducedPheno.attributeIndicesOfType(Phenotype.ATTRIBUTE_TYPE.covariate);
        int nAttributes = taxaAttrIndex.length + factorAttrIndex.length + covariateAttrIndex.length + 1;
        int[] singlePhenotypeIndex = new int[nAttributes];
        int counter = 0;
        for (int addIndex : taxaAttrIndex) {
            singlePhenotypeIndex[counter++] = addIndex;
        }
        for (int addIndex : factorAttrIndex) {
            singlePhenotypeIndex[counter++] = addIndex;
        }
        for (int addIndex : covariateAttrIndex) {
            singlePhenotypeIndex[counter++] = addIndex;
        }
        int numberOfIterations = this.nIterations.value();
        int numberOfFolds = this.kFolds.value();
        int numberOfTraits = dataAttrIndex.length;
        int numberOfComputes = numberOfIterations * numberOfFolds * numberOfTraits;
        int updateProgressValue = Math.max(1, numberOfComputes / 100);
        int computeCount = 0;
        int[] nArray = dataAttrIndex;
        int n = nArray.length;
        for (int i = 0; i < n; ++i) {
            int singleDataIndex;
            singlePhenotypeIndex[nAttributes - 1] = singleDataIndex = nArray[i];
            Phenotype singlePhenotype = new PhenotypeBuilder().fromPhenotype(reducedPheno).keepAttributes(singlePhenotypeIndex).removeMissingObservations().build().get(0);
            NumericAttribute dataAttribute = (NumericAttribute)singlePhenotype.attributeListOfType(Phenotype.ATTRIBUTE_TYPE.data).get(0);
            String traitname = dataAttribute.name();
            TaxaList phenoTaxa = singlePhenotype.taxa();
            TaxaList kinTaxa = kinshipOriginal.getTaxaList();
            TaxaList jointTaxa = TaxaListUtils.getCommonTaxa(phenoTaxa, kinTaxa);
            DistanceMatrix myKinship = DistanceMatrixUtils.keepTaxa(kinshipOriginal, jointTaxa);
            double[] phenotypeData = AssociationUtils.convertFloatArrayToDouble(dataAttribute.floatValues());
            int nObs = phenotypeData.length;
            DoubleMatrix phenotype = DoubleMatrixFactory.DEFAULT.make(nObs, 1, phenotypeData);
            DoubleMatrix fixedEffects = this.fixedEffectMatrix(singlePhenotype);
            DoubleMatrix kinship = DoubleMatrixFactory.DEFAULT.make(myKinship.getClonedDistances());
            int foldSize = nObs / this.kFolds.value();
            int[] seq = IntStream.range(0, nObs).toArray();
            BasicShuffler.reset();
            double[] rValues = new double[numberOfIterations * numberOfFolds];
            int rValueIndex = 0;
            for (int iter = 0; iter < numberOfIterations; ++iter) {
                BasicShuffler.shuffle(seq);
                int startFold = 0;
                for (int fold = 0; fold < numberOfFolds; ++fold) {
                    DoubleMatrix phenoTraining = phenotype.copy();
                    int endFold = startFold + foldSize;
                    if (fold == numberOfFolds - 1) {
                        endFold = nObs;
                    }
                    for (int ndx = startFold; ndx < endFold; ++ndx) {
                        phenoTraining.set(seq[ndx], 0, Double.NaN);
                    }
                    EMMAforDoubleMatrix runEMMA = new EMMAforDoubleMatrix(phenoTraining, fixedEffects, kinship);
                    runEMMA.solve();
                    runEMMA.calculateBlupsPredicted();
                    double[] predictions = runEMMA.getPred().to1DArray();
                    int testSize = endFold - startFold;
                    double[] testPredictions = new double[testSize];
                    double[] testObserved = new double[testSize];
                    for (int ndx = 0; ndx < testSize; ++ndx) {
                        int seqIndex = seq[ndx + startFold];
                        testPredictions[ndx] = predictions[seqIndex];
                        testObserved[ndx] = phenotype.get(seqIndex, 0);
                    }
                    PearsonsCorrelation Pearsons = new PearsonsCorrelation();
                    double rval = Pearsons.correlation(testPredictions, testObserved);
                    rValues[rValueIndex++] = rval;
                    myReportBuilder.add(new Object[]{traitname, new Integer(iter), new Integer(fold), new Double(rval)});
                    startFold = endFold;
                    if (++computeCount % updateProgressValue != 0) continue;
                    this.fireProgress(computeCount / updateProgressValue);
                }
            }
            DescriptiveStatistics stats = new DescriptiveStatistics(rValues);
            double meanR = stats.getMean();
            double varR = stats.getVariance();
            double sdR = Math.sqrt(varR / (double)rValues.length);
            System.out.printf("For phenotype %s\n", dataAttribute.name());
            System.out.printf("Mean from genomic prediction = %1.4f\n", meanR);
            System.out.printf("Standard deviation of mean from genomic prediction = %1.8f\n", sdR);
            commentList.add(" ");
            commentList.add(String.format("For phenotype %s", dataAttribute.name()));
            commentList.add(String.format("Mean from genomic prediction = %1.4f", meanR));
            commentList.add(String.format("Standard deviation of mean from genomic prediction = %1.8f", sdR));
        }
        String comment = "Genomic Prediction Accuracy Summary:\n";
        for (String commentLine : commentList) {
            comment = comment + commentLine + "\n";
        }
        DataSet returnData = new DataSet(new Datum("Accuracy_" + inputPhenotypeName, myReportBuilder.build(), comment), (Plugin)this);
        this.fireProgress(100);
        return returnData;
    }

    private DoubleMatrix fixedEffectMatrix(Phenotype aPhenotype) {
        DoubleMatrix fixedEffects;
        List<PhenotypeAttribute> factorAttributeList = aPhenotype.attributeListOfType(Phenotype.ATTRIBUTE_TYPE.factor);
        List<PhenotypeAttribute> covariateAttributeList = aPhenotype.attributeListOfType(Phenotype.ATTRIBUTE_TYPE.covariate);
        int numberOfFactors = factorAttributeList.size();
        int numberOfCovariates = covariateAttributeList.size();
        int numberOfEffects = numberOfFactors + numberOfCovariates + 1;
        int nObs = aPhenotype.numberOfObservations();
        if (numberOfEffects > 1) {
            int i;
            DoubleMatrix[][] effects = new DoubleMatrix[1][numberOfEffects];
            effects[0][0] = DoubleMatrixFactory.DEFAULT.make(nObs, 1, 1.0);
            for (i = 0; i < numberOfFactors; ++i) {
                CategoricalAttribute fa = (CategoricalAttribute)factorAttributeList.get(i);
                FactorModelEffect fme = new FactorModelEffect(fa.allIntValues(), true);
                effects[0][i + 1] = fme.getX();
            }
            for (i = 0; i < numberOfCovariates; ++i) {
                NumericAttribute na = (NumericAttribute)covariateAttributeList.get(i);
                double[] values = AssociationUtils.convertFloatArrayToDouble(na.floatValues());
                effects[0][i + numberOfFactors + 1] = DoubleMatrixFactory.DEFAULT.make(nObs, 1, values);
            }
            fixedEffects = DoubleMatrixFactory.DEFAULT.compose(effects);
        } else {
            fixedEffects = DoubleMatrixFactory.DEFAULT.make(nObs, 1, 1.0);
        }
        return fixedEffects;
    }

    @Override
    public ImageIcon getIcon() {
        URL imageURL = GenomicSelectionPlugin.class.getResource("/net/maizegenetics/analysis/images/LinearAssociation.gif");
        if (imageURL == null) {
            return null;
        }
        return new ImageIcon(imageURL);
    }

    @Override
    public String getButtonName() {
        return "Genomic Selection";
    }

    @Override
    public String getToolTipText() {
        return "Predict Phenotypes using G-BLUP for Genomic Selection";
    }

    @Override
    public String pluginDescription() {
        return "Predicts phenotypes using G-BLUP for genomic selection using a user-inputted kinship matrix and phenotype(s).";
    }

    @Override
    public String getCitation() {
        return "C Diepenbrock, P Bradbury (2015) First Annual Tassel Hackathon";
    }

    public DataSet runPlugin(DataSet input) {
        return (DataSet)this.performFunction(input).getData(0).getData();
    }

    public Boolean performCrossValidation() {
        return this.performCrossValidation.value();
    }

    public GenomicSelectionPlugin performCrossValidation(Boolean value) {
        this.performCrossValidation = new PluginParameter<Boolean>(this.performCrossValidation, value);
        return this;
    }

    public Integer kFolds() {
        return this.kFolds.value();
    }

    public GenomicSelectionPlugin kFolds(Integer value) {
        this.kFolds = new PluginParameter<Integer>(this.kFolds, value);
        return this;
    }

    public Integer nIterations() {
        return this.nIterations.value();
    }

    public GenomicSelectionPlugin nIterations(Integer value) {
        this.nIterations = new PluginParameter<Integer>(this.nIterations, value);
        return this;
    }
}

