/*
 * Decompiled with CFR 0.152.
 */
package smile.feature;

import java.util.Arrays;
import java.util.Objects;
import smile.classification.Classifier;
import smile.classification.ClassifierTrainer;
import smile.gap.BitString;
import smile.gap.Chromosome;
import smile.gap.FitnessMeasure;
import smile.gap.GeneticAlgorithm;
import smile.regression.Regression;
import smile.regression.RegressionTrainer;
import smile.validation.ClassificationMeasure;
import smile.validation.RegressionMeasure;
import smile.validation.Validation;

/**
 * Feature selection driven by a genetic algorithm.
 *
 * <p>A candidate feature subset is a {@link BitString} with one bit per column of the
 * training data; column {@code j} is kept when bit {@code j} is 1. The fitness of a subset
 * is the score of a model trained on the kept columns only, estimated either by k-fold
 * cross validation or on a separate test set. Regression errors are negated, so for both
 * classification and regression a larger fitness is intended to be better.
 */
public class GAFeatureSelection {
    /** Selection strategy handed to the genetic algorithm. */
    private final GeneticAlgorithm.Selection selection;
    /** Mutation rate handed to every chromosome, in [0, 1]. */
    private final double mutationRate;
    /** Crossover strategy handed to every chromosome. */
    private final BitString.Crossover crossover;
    /** Crossover rate handed to every chromosome, in [0, 1]. */
    private final double crossoverRate;

    /**
     * Creates a selector with tournament selection, uniform crossover, crossover rate 1.0
     * and mutation rate 0.01.
     */
    public GAFeatureSelection() {
        this(GeneticAlgorithm.Selection.TOURNAMENT, BitString.Crossover.UNIFORM, 1.0, 0.01);
    }

    /**
     * Creates a selector with explicit genetic-algorithm settings.
     *
     * @param selection the selection strategy of the genetic algorithm
     * @param crossover the crossover strategy of each chromosome
     * @param crossoverRate the crossover rate, in [0, 1]
     * @param mutationRate the mutation rate, in [0, 1]
     * @throws IllegalArgumentException if a rate is outside [0, 1] or is NaN
     * @throws NullPointerException if {@code selection} or {@code crossover} is null
     */
    public GAFeatureSelection(GeneticAlgorithm.Selection selection, BitString.Crossover crossover, double crossoverRate, double mutationRate) {
        // Written as !(in range) rather than (below || above) so that NaN is rejected too.
        if (!(crossoverRate >= 0.0 && crossoverRate <= 1.0)) {
            throw new IllegalArgumentException("Invalid crossover rate: " + crossoverRate);
        }
        if (!(mutationRate >= 0.0 && mutationRate <= 1.0)) {
            throw new IllegalArgumentException("Invalid mutation rate: " + mutationRate);
        }
        this.selection = Objects.requireNonNull(selection, "selection");
        this.crossover = Objects.requireNonNull(crossover, "crossover");
        this.crossoverRate = crossoverRate;
        this.mutationRate = mutationRate;
    }

    /**
     * Selects features for a classifier, scoring each subset by k-fold cross validation.
     *
     * @param size the population size
     * @param generation the number of generations to evolve
     * @param trainer builds a classifier from the selected columns
     * @param measure the score to maximize
     * @param x the training samples, one row per sample
     * @param y the class labels of {@code x}
     * @param k the number of cross-validation folds, in [2, x.length]
     * @return the population array after {@code evolve(generation)} returns
     *     (assumes the genetic algorithm updates this array in place — TODO confirm)
     * @throws IllegalArgumentException on invalid sizes, counts or data shapes
     */
    public BitString[] learn(int size, int generation, ClassifierTrainer<double[]> trainer, ClassificationMeasure measure, double[][] x, int[] y, int k) {
        checkEvolution(size, generation);
        int p = checkTrainingData(x, y.length);
        checkFolds(k, x.length);
        return evolve(size, generation, p, new ClassificationFitness(trainer, measure, x, y, k));
    }

    /**
     * Selects features for a classifier, scoring each subset on a separate test set.
     *
     * @param size the population size
     * @param generation the number of generations to evolve
     * @param trainer builds a classifier from the selected columns
     * @param measure the score to maximize
     * @param x the training samples, one row per sample
     * @param y the class labels of {@code x}
     * @param testx the test samples; same columns as {@code x}
     * @param testy the class labels of {@code testx}
     * @return the population array after {@code evolve(generation)} returns
     *     (assumes the genetic algorithm updates this array in place — TODO confirm)
     * @throws IllegalArgumentException on invalid sizes, counts or data shapes
     */
    public BitString[] learn(int size, int generation, ClassifierTrainer<double[]> trainer, ClassificationMeasure measure, double[][] x, int[] y, double[][] testx, int[] testy) {
        checkEvolution(size, generation);
        int p = checkTrainingData(x, y.length);
        checkTestData(testx, testy.length, p);
        return evolve(size, generation, p, new ClassificationFitness(trainer, measure, x, y, testx, testy));
    }

    /**
     * Selects features for a regression model, scoring each subset by k-fold cross
     * validation. The fitness is the negated {@code measure}.
     *
     * @param size the population size
     * @param generation the number of generations to evolve
     * @param trainer builds a regression model from the selected columns
     * @param measure the error to minimize
     * @param x the training samples, one row per sample
     * @param y the response values of {@code x}
     * @param k the number of cross-validation folds, in [2, x.length]
     * @return the population array after {@code evolve(generation)} returns
     *     (assumes the genetic algorithm updates this array in place — TODO confirm)
     * @throws IllegalArgumentException on invalid sizes, counts or data shapes
     */
    public BitString[] learn(int size, int generation, RegressionTrainer<double[]> trainer, RegressionMeasure measure, double[][] x, double[] y, int k) {
        checkEvolution(size, generation);
        int p = checkTrainingData(x, y.length);
        checkFolds(k, x.length);
        return evolve(size, generation, p, new RegressionFitness(trainer, measure, x, y, k));
    }

    /**
     * Selects features for a regression model, scoring each subset on a separate test set.
     * The fitness is the negated {@code measure}.
     *
     * @param size the population size
     * @param generation the number of generations to evolve
     * @param trainer builds a regression model from the selected columns
     * @param measure the error to minimize
     * @param x the training samples, one row per sample
     * @param y the response values of {@code x}
     * @param testx the test samples; same columns as {@code x}
     * @param testy the response values of {@code testx}
     * @return the population array after {@code evolve(generation)} returns
     *     (assumes the genetic algorithm updates this array in place — TODO confirm)
     * @throws IllegalArgumentException on invalid sizes, counts or data shapes
     */
    public BitString[] learn(int size, int generation, RegressionTrainer<double[]> trainer, RegressionMeasure measure, double[][] x, double[] y, double[][] testx, double[] testy) {
        checkEvolution(size, generation);
        int p = checkTrainingData(x, y.length);
        checkTestData(testx, testy.length, p);
        return evolve(size, generation, p, new RegressionFitness(trainer, measure, x, y, testx, testy));
    }

    /** Builds {@code size} random chromosomes of length {@code p} and evolves them. */
    private BitString[] evolve(int size, int generation, int p, FitnessMeasure<BitString> fitness) {
        BitString[] seeds = new BitString[size];
        for (int i = 0; i < size; ++i) {
            seeds[i] = new BitString(p, fitness, crossover, crossoverRate, mutationRate);
        }
        GeneticAlgorithm<BitString> ga = new GeneticAlgorithm<BitString>(seeds, selection);
        ga.evolve(generation);
        return seeds;
    }

    /** Validates the population size and the number of generations. */
    private static void checkEvolution(int size, int generation) {
        if (size <= 0) {
            throw new IllegalArgumentException("Invalid population size: " + size);
        }
        if (generation <= 0) {
            throw new IllegalArgumentException("Invalid number of generations to go: " + generation);
        }
    }

    /** Validates the training data and returns its number of columns. */
    private static int checkTrainingData(double[][] x, int ny) {
        if (x.length != ny) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, ny));
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        if (x[0].length == 0) {
            throw new IllegalArgumentException("Training data has no features");
        }
        return x[0].length;
    }

    /** Validates the number of cross-validation folds against the sample count. */
    private static void checkFolds(int k, int n) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid k-fold cross validation: " + k);
        }
        if (k > n) {
            throw new IllegalArgumentException(String.format("k-fold cross validation needs at least k samples: %d > %d", k, n));
        }
    }

    /** Validates that the test data is consistent with itself and with the training columns. */
    private static void checkTestData(double[][] testx, int testny, int p) {
        if (testx.length != testny) {
            throw new IllegalArgumentException(String.format("The sizes of test X and Y don't match: %d != %d", testx.length, testny));
        }
        // The fitness projects test rows with the training bit mask, so the widths must agree.
        if (testx.length > 0 && testx[0].length != p) {
            throw new IllegalArgumentException(String.format("The dimensions of X and test X don't match: %d != %d", p, testx[0].length));
        }
    }

    /** Returns how many bits of {@code bits} are set to 1. */
    private static int countSelected(int[] bits) {
        int p = 0;
        for (int b : bits) {
            if (b == 1) {
                ++p;
            }
        }
        return p;
    }

    /** Copies the columns whose bit is 1 into a new {@code data.length x p} matrix. */
    private static double[][] selectColumns(double[][] data, int[] bits, int p) {
        double[][] selected = new double[data.length][p];
        for (int i = 0; i < data.length; ++i) {
            double[] row = data[i];
            double[] out = selected[i];
            int jj = 0;
            for (int j = 0; j < bits.length; ++j) {
                if (bits[j] == 1) {
                    out[jj++] = row[j];
                }
            }
        }
        return selected;
    }

    /** Fitness of a feature subset for regression: the negated error measure. */
    static final class RegressionFitness
    implements FitnessMeasure<BitString> {
        final RegressionTrainer<double[]> trainer;
        final RegressionMeasure measure;
        final double[][] x;
        final double[] y;
        /** Hold-out data; null in cross-validation mode. */
        final double[][] testx;
        final double[] testy;
        /** Number of folds, or -1 to score on the hold-out data instead. */
        final int k;

        RegressionFitness(RegressionTrainer<double[]> trainer, RegressionMeasure measure, double[][] x, double[] y, int k) {
            this(trainer, measure, x, y, null, null, k);
        }

        RegressionFitness(RegressionTrainer<double[]> trainer, RegressionMeasure measure, double[][] x, double[] y, double[][] testx, double[] testy) {
            this(trainer, measure, x, y, testx, testy, -1);
        }

        private RegressionFitness(RegressionTrainer<double[]> trainer, RegressionMeasure measure, double[][] x, double[] y, double[][] testx, double[] testy, int k) {
            this.trainer = trainer;
            this.measure = measure;
            this.x = x;
            this.y = y;
            this.testx = testx;
            this.testy = testy;
            this.k = k;
        }

        @Override
        public double fit(BitString chromosome) {
            int[] bits = chromosome.bits();
            int p = countSelected(bits);
            if (p == 0) {
                // Returning 0.0 here would be the best possible fitness for any
                // non-negative error, so the empty subset would win every selection.
                // Instead score a feature-free model that predicts the training mean
                // (in CV mode this is an in-sample, slightly optimistic baseline).
                double[] truth = (k != -1) ? y : testy;
                double[] baseline = new double[truth.length];
                Arrays.fill(baseline, mean(y));
                return -measure.measure(truth, baseline);
            }
            double[][] xx = selectColumns(x, bits, p);
            if (k != -1) {
                return -Validation.cv(k, trainer, xx, y, measure);
            }
            Regression<double[]> regression = trainer.train(xx, y);
            double[][] testxx = selectColumns(testx, bits, p);
            double[] prediction = new double[testxx.length];
            for (int i = 0; i < testxx.length; ++i) {
                prediction[i] = regression.predict(testxx[i]);
            }
            return -measure.measure(testy, prediction);
        }

        /** Arithmetic mean of a non-empty array. */
        private static double mean(double[] values) {
            double sum = 0.0;
            for (double v : values) {
                sum += v;
            }
            return sum / values.length;
        }
    }

    /** Fitness of a feature subset for classification: the classification measure. */
    static final class ClassificationFitness
    implements FitnessMeasure<BitString> {
        final ClassifierTrainer<double[]> trainer;
        final ClassificationMeasure measure;
        final double[][] x;
        final int[] y;
        /** Hold-out data; null in cross-validation mode. */
        final double[][] testx;
        final int[] testy;
        /** Number of folds, or -1 to score on the hold-out data instead. */
        final int k;

        ClassificationFitness(ClassifierTrainer<double[]> trainer, ClassificationMeasure measure, double[][] x, int[] y, int k) {
            this(trainer, measure, x, y, null, null, k);
        }

        ClassificationFitness(ClassifierTrainer<double[]> trainer, ClassificationMeasure measure, double[][] x, int[] y, double[][] testx, int[] testy) {
            this(trainer, measure, x, y, testx, testy, -1);
        }

        private ClassificationFitness(ClassifierTrainer<double[]> trainer, ClassificationMeasure measure, double[][] x, int[] y, double[][] testx, int[] testy, int k) {
            this.trainer = trainer;
            this.measure = measure;
            this.x = x;
            this.y = y;
            this.testx = testx;
            this.testy = testy;
            this.k = k;
        }

        @Override
        public double fit(BitString chromosome) {
            int[] bits = chromosome.bits();
            int p = countSelected(bits);
            if (p == 0) {
                // No model can be trained without features; 0.0 is the floor of
                // accuracy-like measures.
                return 0.0;
            }
            double[][] xx = selectColumns(x, bits, p);
            if (k != -1) {
                // Pass the caller's measure so CV and hold-out scoring agree.
                return Validation.cv(k, trainer, xx, y, measure);
            }
            Classifier<double[]> classifier = trainer.train(xx, y);
            double[][] testxx = selectColumns(testx, bits, p);
            int[] prediction = new int[testxx.length];
            for (int i = 0; i < testxx.length; ++i) {
                prediction[i] = classifier.predict(testxx[i]);
            }
            return measure.measure(testy, prediction);
        }
    }
}

