/*
 * Decompiled with CFR 0.152.
 */
package eva2.optimization.strategies;

import eva2.gui.BeanInspector;
import eva2.optimization.enums.BOAScoringMethods;
import eva2.optimization.individuals.AbstractEAIndividual;
import eva2.optimization.individuals.GAIndividualBinaryData;
import eva2.optimization.individuals.InterfaceDataTypeBinary;
import eva2.optimization.individuals.InterfaceGAIndividual;
import eva2.optimization.population.InterfaceSolutionSet;
import eva2.optimization.population.Population;
import eva2.optimization.population.SolutionSet;
import eva2.optimization.strategies.AbstractOptimizer;
import eva2.problems.AbstractOptimizationProblem;
import eva2.problems.BKnapsackProblem;
import eva2.tools.Pair;
import eva2.tools.math.BayNet;
import eva2.tools.math.RNG;
import eva2.util.annotation.Description;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.Serializable;
import java.io.Writer;
import java.text.DateFormat;
import java.util.BitSet;
import java.util.Date;
import java.util.LinkedList;
import java.util.logging.Level;
import java.util.logging.Logger;

@Description(value="Basic implementation of the Bayesian Optimization Algorithm based on the works by Martin Pelikan and David E. Goldberg.")
/**
 * Bayesian Optimization Algorithm (BOA) for binary-encoded problems.
 * <p>
 * Each call to {@link #optimize()} does the following:
 * <ol>
 *   <li>Learns a Bayesian network greedily from the best {@code learningSetRatio} fraction
 *       of the population.</li>
 *   <li>Samples {@code resampleRatio * popSize} new individuals from that network.</li>
 *   <li>Replaces the same number of worst individuals with the new ones.</li>
 * </ol>
 * Optionally, networks, edge rates, scoring metrics and timestamps are written below
 * {@code netFolder}.
 */
public class BOA extends AbstractOptimizer implements Serializable {
    private static final Logger LOGGER = Logger.getLogger(BOA.class.getName());

    // NOTE: field names are kept unchanged (e.g. PopSize) because the class is Serializable
    // without an explicit serialVersionUID; renaming fields would change the serialized form.

    /** Number of binary variables; overwritten from the problem in defaultInit() when available. */
    private int probDim = 8;
    /** Criterion argument passed to getBestNIndividuals/getWorstNIndividuals. */
    private int fitCrit = -1;
    /** Target population size. */
    private int PopSize = 50;
    /** Maximum number of parents any node of the learned network may have. */
    private int numberOfParents = 3;
    /** The most recently learned network (not serialized). */
    private transient BayNet network = null;
    private AbstractOptimizationProblem optimizationProblem = new BKnapsackProblem();
    /** Individual template obtained from the problem; cloned for every sampled individual. */
    private AbstractEAIndividual template = null;
    /** Fraction of the population (best individuals) used to learn the network. */
    private double learningSetRatio = 0.5;
    /** Fraction of the population replaced by newly sampled individuals per generation. */
    private double resampleRatio = 0.5;
    private double upperProbLimit = 0.9;
    private double lowerProbLimit = 0.1;
    /** Number of generations performed since the last (re-)initialisation. */
    private int count = 0;
    /** Output directory for the optional print* features. */
    private String netFolder = "BOAOutput";
    /** Accumulated edge usage, probDim x probDim; updated only when printEdgeRate is set. */
    private int[][] edgeRate = null;
    private BOAScoringMethods scoringMethod = BOAScoringMethods.BDM;
    private boolean printNetworks = false;
    private boolean printEdgeRate = false;
    private boolean printTimestamps = false;
    private boolean printMetrics = false;

    public BOA() {
    }

    public BOA(int numberOfParents, int popSize, BOAScoringMethods method, double learningSetRatio, double resampleRatio, String outputFolder, double upperProbLimit, double lowerProbLimit, boolean printNetworks, boolean printEdgeRate, boolean printMetrics, boolean printTimestamps) {
        this.numberOfParents = numberOfParents;
        this.PopSize = popSize;
        this.scoringMethod = method;
        this.learningSetRatio = learningSetRatio;
        this.resampleRatio = resampleRatio;
        this.netFolder = outputFolder;
        this.upperProbLimit = upperProbLimit;
        this.lowerProbLimit = lowerProbLimit;
        this.printEdgeRate = printEdgeRate;
        this.printNetworks = printNetworks;
        this.printMetrics = printMetrics;
        this.printTimestamps = printTimestamps;
    }

    /**
     * Copy constructor used by {@link #clone()}.
     * <p>
     * Mutable state is deep-copied. Fields that are still {@code null} are tolerated: this
     * happens before {@link #initialize()}, and after deserialization for the transient
     * {@code network}.
     *
     * @param b the instance to copy
     */
    public BOA(BOA b) {
        this.probDim = b.probDim;
        this.fitCrit = b.fitCrit;
        this.PopSize = b.PopSize;
        this.numberOfParents = b.numberOfParents;
        this.network = b.network == null ? null : (BayNet) b.network.clone();
        this.population = b.population == null ? null : (Population) b.population.clone();
        this.optimizationProblem = b.optimizationProblem == null
                ? null : (AbstractOptimizationProblem) b.optimizationProblem.clone();
        this.template = b.template == null ? null : (AbstractEAIndividual) b.template.clone();
        this.learningSetRatio = b.learningSetRatio;
        this.resampleRatio = b.resampleRatio;
        this.upperProbLimit = b.upperProbLimit;
        this.lowerProbLimit = b.lowerProbLimit;
        this.count = b.count;
        this.netFolder = b.netFolder;
        this.scoringMethod = b.scoringMethod;
        if (b.edgeRate != null) {
            this.edgeRate = new int[b.edgeRate.length][];
            for (int i = 0; i < b.edgeRate.length; ++i) {
                this.edgeRate[i] = b.edgeRate[i].clone();
            }
        }
        this.printNetworks = b.printNetworks;
        this.printMetrics = b.printMetrics;
        this.printEdgeRate = b.printEdgeRate;
        this.printTimestamps = b.printTimestamps;
    }

    @Override
    public Object clone() {
        return new BOA(this);
    }

    @Override
    public String getName() {
        return "Bayesian Optimization Algorithm";
    }

    /**
     * Creates {@code directoryName}, including missing parent directories, if it does not exist.
     * Failure is logged rather than thrown; the subsequent file write then reports the error.
     */
    private void createDirectoryIfNeeded(String directoryName) {
        File theDir = new File(directoryName);
        if (!theDir.exists()) {
            LOGGER.log(Level.INFO, "creating directory: " + directoryName);
            if (!theDir.mkdirs() && !theDir.isDirectory()) {
                LOGGER.log(Level.WARNING, "could not create directory: " + directoryName);
            }
        }
    }

    /**
     * Returns the binary genotype of {@code indy}.
     *
     * @throws RuntimeException if indy implements neither InterfaceGAIndividual nor
     *                          InterfaceDataTypeBinary
     */
    private static BitSet getBinaryData(AbstractEAIndividual indy) {
        if (indy instanceof InterfaceGAIndividual) {
            return ((InterfaceGAIndividual) indy).getBGenotype();
        }
        if (indy instanceof InterfaceDataTypeBinary) {
            return ((InterfaceDataTypeBinary) indy).getBinaryData();
        }
        throw new RuntimeException("Unable to get binary representation for " + indy.getClass());
    }

    /**
     * Evaluates one individual and increments the population's function-call counter.
     * A {@code null} individual is logged and ignored.
     */
    private void evaluate(AbstractEAIndividual indy) {
        if (indy == null) {
            LOGGER.log(Level.WARNING, "tried to evaluate null");
            return;
        }
        this.optimizationProblem.evaluate(indy);
        this.population.incrFunctionCalls();
    }

    /**
     * Resets counters and prepares the population, the template, a fresh network and the
     * edge-rate matrix for the current problem.
     */
    private void defaultInit() {
        this.count = 0;
        if (this.printTimestamps) {
            this.printTimeStamp();
        }
        if (this.population == null) {
            this.population = new Population(this.PopSize);
        } else {
            this.population.setTargetPopSize(this.PopSize);
        }
        this.template = this.optimizationProblem.getIndividualTemplate();
        if (this.template instanceof InterfaceDataTypeBinary) {
            Object dim = BeanInspector.callIfAvailable(this.optimizationProblem, "getProblemDimension", null);
            if (dim instanceof Number) {
                this.probDim = ((Number) dim).intValue();
            } else {
                // Previously this fell through and unboxed null (NPE); keep the current dimension instead.
                LOGGER.log(Level.WARNING, "Couldn't get problem dimension! Keeping " + this.probDim);
            }
            ((InterfaceDataTypeBinary) this.template).setBinaryGenotype(new BitSet(this.probDim));
        } else {
            LOGGER.log(Level.WARNING, "Requiring binary data!");
        }
        this.network = new BayNet(this.probDim, this.upperProbLimit, this.lowerProbLimit);
        this.network.setScoringMethod(this.scoringMethod);
        this.edgeRate = new int[this.probDim][this.probDim];
    }

    @Override
    public void initialize() {
        this.defaultInit();
        this.optimizationProblem.initializePopulation(this.population);
        this.evaluatePopulation(this.population);
        this.firePropertyChangedEvent("NextGenerationPerformed");
    }

    /** Evaluates every individual of {@code pop}. */
    private void evaluatePopulation(Population pop) {
        for (int i = 0; i < pop.size(); ++i) {
            this.evaluate(pop.getEAIndividual(i));
        }
    }

    /**
     * Initializes the optimizer with the given population.
     *
     * @param pop   the population to adopt
     * @param reset if true, {@code pop} is adopted and then re-initialized and evaluated via
     *              {@link #initialize()}; otherwise it is adopted as-is
     */
    @Override
    public void initializeByPopulation(Population pop, boolean reset) {
        if (reset) {
            // Previously pop was ignored on reset; adopt it before re-initialising.
            if (pop != null) {
                this.population = pop;
            }
            this.initialize();
        } else {
            this.defaultInit();
            this.population = pop;
        }
    }

    /**
     * Greedily learns a new network from {@code pop}.
     * <p>
     * Each round tries every admissible edge i->j: not already present, not a self-loop,
     * within the parent limit, and keeping the graph acyclic. Among the edges with the best
     * score that differs from the round's starting score, one is picked uniformly at random
     * and added. Learning stops after a round without a strict score improvement.
     */
    private void generateGreedy(Population pop) {
        this.network = new BayNet(this.probDim, this.upperProbLimit, this.lowerProbLimit);
        this.network.setScoringMethod(this.scoringMethod);
        boolean improvement = true;
        this.network.initScoreArray(pop);
        double score = this.network.getNewScore(pop, -1);
        double score1 = score;
        LinkedList<Pair<Integer, Integer>> bestNetworks = new LinkedList<>();
        while (improvement) {
            improvement = false;
            for (int i = 0; i < this.probDim; ++i) {
                for (int j = 0; j < this.probDim; ++j) {
                    if (this.network.hasEdge(i, j) || i == j
                            || this.network.getNode(j).getNumberOfParents() >= this.numberOfParents) {
                        continue;
                    }
                    // Tentatively add the edge, score it, then always undo it.
                    this.network.addEdge(i, j);
                    if (this.network.isACyclic(i, j)) {
                        double tmpScore = this.network.getNewScore(pop, j);
                        if (tmpScore >= score && tmpScore != score1) {
                            if (tmpScore == score) {
                                bestNetworks.add(new Pair<Integer, Integer>(i, j));
                            } else {
                                bestNetworks.clear();
                                bestNetworks.add(new Pair<Integer, Integer>(i, j));
                                score = tmpScore;
                                improvement = true;
                            }
                        }
                    }
                    this.network.removeEdge(i, j);
                }
            }
            if (!bestNetworks.isEmpty()) {
                Pair<Integer, Integer> pair = bestNetworks.get(RNG.randomInt(bestNetworks.size()));
                this.network.addEdge(pair.getHead(), pair.getTail());
                this.network.updateScoreArray(pop, pair.getTail());
            }
            score = this.network.getNewScore(pop, -1);
            score1 = score;
            bestNetworks.clear();
        }
        // NOTE(review): the result was previously stored in a dead local; the call is retained
        // only in case BayNet.getScore has side effects - confirm and drop if it does not.
        this.network.getScore(pop);
    }

    private void constructNetwork(Population pop) {
        this.generateGreedy(pop);
    }

    /** Samples and evaluates {@code sampleSetSize} new individuals from the current network. */
    private Population generateNewIndys(int sampleSetSize) {
        Population pop = new Population(sampleSetSize);
        LOGGER.log(Level.CONFIG, "Resampling " + sampleSetSize + " indies...");
        while (pop.size() < sampleSetSize) {
            AbstractEAIndividual indy = (AbstractEAIndividual) this.template.clone();
            BitSet data = this.network.sample(BOA.getBinaryData(indy));
            ((InterfaceDataTypeBinary) indy).setBinaryGenotype(data);
            this.evaluate(indy);
            pop.add(indy);
        }
        return pop;
    }

    /** Number of individuals replaced per generation, clamped to [1, PopSize]. */
    private int calcResampleSetSize() {
        return (int) Math.min((double) this.PopSize, Math.max(1.0, (double) this.PopSize * this.resampleRatio));
    }

    /** Number of best individuals used for learning, clamped to [1, PopSize]. */
    private int calcLearningSetSize() {
        return (int) Math.min((double) this.PopSize, Math.max(1.0, (double) this.PopSize * this.learningSetRatio));
    }

    /** Removes every individual contained in {@code pop} from the current population. */
    public void remove(Population pop) {
        for (Object indy : pop) {
            this.population.remove(indy);
        }
    }

    /**
     * Writes the normalized edge-rate matrix to {@code netFolder/edgeRate.m}, formatted as
     * {@code edgeRate<method> = [r00,r01,...;r10,...];}.
     */
    private void printEdgeRate() {
        String filename = this.netFolder + "/edgeRate.m";
        StringBuilder message = new StringBuilder("edgeRate").append(this.scoringMethod).append(" = [");
        this.createDirectoryIfNeeded(this.netFolder);
        // NOTE(review): rates are divided by count + 1 although count generations have been
        // accumulated at this point - confirm against BayNet.adaptEdgeRate.
        double divisor = (double) (this.count + 1);
        for (int i = 0; i < this.edgeRate.length; ++i) {
            if (i > 0) {
                message.append(';');
            }
            for (int j = 0; j < this.edgeRate.length; ++j) {
                if (j > 0) {
                    message.append(',');
                }
                message.append((double) this.edgeRate[i][j] / divisor);
            }
        }
        message.append("];");
        // Previously a PrintWriter was built around a still-null Writer, which always threw NPE.
        try (Writer w = new FileWriter(filename)) {
            w.write(message.toString());
        } catch (IOException e) {
            LOGGER.log(Level.WARNING, "Could not write " + filename, e);
        }
    }

    /** Writes the current network (BayNet.generateYFilesCode) to {@code netFolder/network_<i>.graphml}. */
    private void printNetworkToFile(String i) {
        String filename = this.netFolder + "/network_" + i + ".graphml";
        String message = this.network.generateYFilesCode();
        this.createDirectoryIfNeeded(this.netFolder);
        // Previously a PrintWriter was built around a still-null Writer, which always threw NPE.
        try (Writer w = new FileWriter(filename)) {
            w.write(message);
        } catch (IOException e) {
            LOGGER.log(Level.WARNING, "Could not write " + filename, e);
        }
    }

    /**
     * Appends "count TAB time" to {@code netFolder/timestamps.txt}. A newly created file starts
     * with an empty line.
     */
    private void printTimeStamp() {
        String fileName = this.netFolder + "/timestamps.txt";
        DateFormat df = DateFormat.getTimeInstance(DateFormat.MEDIUM);
        String message = this.count + "\t" + df.format(new Date()) + "\n";
        this.createDirectoryIfNeeded(this.netFolder);
        boolean exists = new File(fileName).exists();
        try (BufferedWriter out = new BufferedWriter(new FileWriter(fileName, exists))) {
            if (!exists) {
                out.newLine();
            }
            out.write(message);
            out.newLine();
        } catch (IOException e) {
            LOGGER.log(Level.WARNING, "Error: ", e);
        }
    }

    /**
     * Scores the current network on {@code pop} with BDM, K2 and BIC and appends a CSV row to
     * {@code netFolder/metrics.csv}. A header line is written when the file is new. The
     * configured scoring method is restored afterwards, even if scoring fails.
     */
    private void printMetrics(Population pop) {
        double bdmMetric;
        double k2Metric;
        double bicMetric;
        try {
            this.network.setScoringMethod(BOAScoringMethods.BDM);
            bdmMetric = this.network.getScore(pop);
            this.network.setScoringMethod(BOAScoringMethods.K2);
            k2Metric = this.network.getScore(pop);
            this.network.setScoringMethod(BOAScoringMethods.BIC);
            bicMetric = this.network.getScore(pop);
        } finally {
            this.network.setScoringMethod(this.scoringMethod);
        }
        String fileName = this.netFolder + "/" + "metrics.csv";
        this.createDirectoryIfNeeded(this.netFolder);
        boolean exists = new File(fileName).exists();
        try (BufferedWriter out = new BufferedWriter(new FileWriter(fileName, exists))) {
            if (!exists) {
                out.write("BDMMetric,  K2Metric, BIC");
                out.newLine();
            }
            out.write("" + bdmMetric + "," + k2Metric + "," + bicMetric);
            out.newLine();
        } catch (IOException e) {
            LOGGER.log(Level.WARNING, "Error: ", e);
        }
    }

    /**
     * Performs one generation: learn a network from the best individuals, sample new ones,
     * replace the worst ones, then run the enabled print* outputs.
     */
    @Override
    public void optimize() {
        this.optimizationProblem.evaluatePopulationStart(this.population);
        Population best = this.population.getBestNIndividuals(this.calcLearningSetSize(), this.fitCrit);
        this.constructNetwork(best);
        if (this.printEdgeRate) {
            this.edgeRate = this.network.adaptEdgeRate(this.edgeRate);
        }
        int resampleSize = this.calcResampleSetSize();
        Population newlyGenerated = this.generateNewIndys(resampleSize);
        Population toRemove = this.population.getWorstNIndividuals(resampleSize, this.fitCrit);
        this.remove(toRemove);
        this.population.addAll(newlyGenerated);
        ++this.count;
        this.firePropertyChangedEvent("NextGenerationPerformed");
        this.optimizationProblem.evaluatePopulationEnd(this.population);
        if (this.printNetworks) {
            this.printNetworkToFile("" + this.count);
        }
        if (this.printEdgeRate) {
            this.printEdgeRate();
        }
        if (this.printMetrics) {
            this.printMetrics(best);
        }
        if (this.printTimestamps) {
            this.printTimeStamp();
        }
    }

    @Override
    public InterfaceSolutionSet getAllSolutions() {
        return new SolutionSet(this.population);
    }

    @Override
    public String getStringRepresentation() {
        return "Bayesian Network";
    }

    public int getNumberOfParents() {
        return this.numberOfParents;
    }

    public void setNumberOfParents(int i) {
        this.numberOfParents = i;
    }

    public String numberOfParentsTipText() {
        return "The maximum number of parents a node in the Bayesian Network can have";
    }

    public String replaceNetworkTipText() {
        return "if set, the network will be completely replaced. If not, it will be tried to improve the last network, if that is not possible, it will be replaced";
    }

    public BOAScoringMethods getNetworkGenerationMethod() {
        return this.scoringMethod;
    }

    public void setNetworkGenerationMethod(BOAScoringMethods n) {
        this.scoringMethod = n;
    }

    public String networkGenerationMethodTipText() {
        return "The Method with which the Bayesian Network will be generated";
    }

    public int getPopulationSize() {
        return this.PopSize;
    }

    public void setPopulationSize(int popSize) {
        this.PopSize = popSize;
    }

    public String populationSizeTipText() {
        return "Define the pool size used by BOA";
    }

    public double getResamplingRatio() {
        return this.resampleRatio;
    }

    public void setResamplingRatio(double resampleRat) {
        this.resampleRatio = resampleRat;
    }

    public String resamplingRatioTipText() {
        return "Ratio of individuals to be resampled from the Bayesian network per iteration";
    }

    public double getLearningRatio() {
        return this.learningSetRatio;
    }

    public void setLearningRatio(double rat) {
        this.learningSetRatio = rat;
    }

    public String learningRatioTipText() {
        return "Ratio of individuals to be used to learn the Bayesian network";
    }

    public double getProbLimitHigh() {
        return this.upperProbLimit;
    }

    public void setProbLimitHigh(double upperProbLimit) {
        this.upperProbLimit = upperProbLimit;
    }

    public String probLimitHighTipText() {
        return "the upper limit of the probability to set one Bit to 1";
    }

    public double getProbLimitLow() {
        return this.lowerProbLimit;
    }

    public void setProbLimitLow(double lowerProbLimit) {
        this.lowerProbLimit = lowerProbLimit;
    }

    public String probLimitLowTipText() {
        return "the lower limit of the probability to set one Bit to 1";
    }

    public String[] customPropertyOrder() {
        return new String[]{"learningRatio", "resamplingRatio"};
    }

    public boolean isPrintNetworks() {
        return this.printNetworks;
    }

    public void setPrintNetworks(boolean b) {
        this.printNetworks = b;
    }

    public String printNetworksTipText() {
        return "Print the underlying networks of each generation";
    }

    public boolean isPrintEdgeRate() {
        return this.printEdgeRate;
    }

    public void setPrintEdgeRate(boolean b) {
        this.printEdgeRate = b;
    }

    public String printEdgeRateTipText() {
        return "Print the rate with which each edge is used in the optimization run";
    }

    public boolean isPrintMetrics() {
        return this.printMetrics;
    }

    public void setPrintMetrics(boolean b) {
        this.printMetrics = b;
    }

    public String printMetricsTipText() {
        return "Print the values of all the metrics for every network";
    }

    public boolean isPrintTimestamps() {
        return this.printTimestamps;
    }

    public void setPrintTimestamps(boolean b) {
        this.printTimestamps = b;
    }

    public String printTimestampsTipText() {
        return "Print the time starting time and a timestamp after each generation";
    }

    /** Sets bit {@code k} of {@code bits} to {@code pattern[k]} for every index of pattern. */
    private static void applyBits(BitSet bits, boolean[] pattern) {
        for (int k = 0; k < pattern.length; ++k) {
            bits.set(k, pattern[k]);
        }
    }

    /**
     * Small manual demo: builds a population of five 8-bit individuals, then runs BOA with its
     * default problem for five generations.
     * <p>
     * NOTE(review): the hand-built {@code pop} is never handed to the optimizer;
     * {@code initialize()} creates its own population from the problem.
     */
    public static void main(String[] args) {
        Population pop = new Population();
        GAIndividualBinaryData indy1 = new GAIndividualBinaryData();
        indy1.setBinaryDataLength(8);
        GAIndividualBinaryData indy2 = (GAIndividualBinaryData) indy1.clone();
        GAIndividualBinaryData indy3 = (GAIndividualBinaryData) indy1.clone();
        GAIndividualBinaryData indy4 = (GAIndividualBinaryData) indy1.clone();
        GAIndividualBinaryData indy5 = (GAIndividualBinaryData) indy1.clone();
        BitSet data1 = indy1.getBinaryData();
        BitSet data2 = indy2.getBinaryData();
        BitSet data3 = indy3.getBinaryData();
        BitSet data4 = indy4.getBinaryData();
        // NOTE(review): data5..data16 are all taken from indy5; depending on whether
        // getBinaryData returns a copy, later writes either alias data5 or are discarded.
        BitSet data5 = indy5.getBinaryData();
        BitSet data6 = indy5.getBinaryData();
        BitSet data7 = indy5.getBinaryData();
        BitSet data8 = indy5.getBinaryData();
        BitSet data9 = indy5.getBinaryData();
        BitSet data10 = indy5.getBinaryData();
        BitSet data11 = indy5.getBinaryData();
        BitSet data12 = indy5.getBinaryData();
        BitSet data13 = indy5.getBinaryData();
        BitSet data14 = indy5.getBinaryData();
        BitSet data15 = indy5.getBinaryData();
        BitSet data16 = indy5.getBinaryData();
        final boolean[] patternA = {true, false, true, false, true, true, false, false};
        final boolean[] patternB = {true, false, true, false, false, true, true, true};
        final boolean[] patternC = {true, false, true, false, false, false, true, true};
        // Same write order as the original demo (matters if the BitSets alias each other).
        applyBits(data1, patternA);
        applyBits(data5, patternB);
        applyBits(data6, patternA);
        applyBits(data7, patternA);
        applyBits(data2, patternA);
        applyBits(data8, patternA);
        applyBits(data9, patternA);
        applyBits(data10, patternA);
        applyBits(data3, patternC);
        applyBits(data11, patternC);
        applyBits(data12, patternC);
        applyBits(data13, patternC);
        applyBits(data4, patternC);
        applyBits(data14, patternC);
        applyBits(data15, patternC);
        applyBits(data16, patternC);
        indy1.setBinaryGenotype(data1);
        indy2.setBinaryGenotype(data2);
        indy3.setBinaryGenotype(data3);
        indy4.setBinaryGenotype(data4);
        indy5.setBinaryGenotype(data5);
        pop.add(indy1);
        pop.add(indy2);
        pop.add(indy3);
        pop.add(indy4);
        pop.add(indy5);
        BOA b = new BOA();
        b.initialize();
        for (int generation = 0; generation < 5; ++generation) {
            b.optimize();
        }
    }
}

