/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.mlplan.metamining.dyadranking;

import ai.libs.hasco.core.HASCOUtil;
import ai.libs.jaicore.basic.algorithm.AlgorithmInitializedEvent;
import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.components.api.IComponent;
import ai.libs.jaicore.components.api.IComponentInstance;
import ai.libs.jaicore.components.api.IComponentRepository;
import ai.libs.jaicore.components.model.ComponentInstance;
import ai.libs.jaicore.logic.fol.structure.Monom;
import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import ai.libs.jaicore.ml.classification.loss.dataset.EClassificationPerformanceMeasure;
import ai.libs.jaicore.ml.core.evaluation.evaluator.FixedSplitClassifierEvaluator;
import ai.libs.jaicore.ml.ranking.dyad.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.ranking.dyad.dataset.SparseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.PLNetDyadRanker;
import ai.libs.jaicore.ml.ranking.dyad.learner.util.DyadMinMaxScaler;
import ai.libs.jaicore.ml.weka.WekaUtil;
import ai.libs.jaicore.ml.weka.dataset.WekaInstances;
import ai.libs.jaicore.planning.hierarchical.algorithms.forwarddecomposition.graphgenerators.tfd.TFDNode;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.events.EvaluatedSearchSolutionCandidateFoundEvent;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.events.FValueEvent;
import ai.libs.jaicore.search.algorithms.standard.bestfirst.nodeevaluation.RandomizedDepthFirstNodeEvaluator;
import ai.libs.jaicore.search.algorithms.standard.gbf.SolutionEventBus;
import ai.libs.jaicore.search.algorithms.standard.random.RandomSearch;
import ai.libs.jaicore.search.model.other.EvaluatedSearchGraphPath;
import ai.libs.jaicore.search.model.other.SearchGraphPath;
import ai.libs.jaicore.search.probleminputs.GraphSearchWithSubpathEvaluationsInput;
import ai.libs.mlplan.core.ILearnerFactory;
import ai.libs.mlplan.metamining.dyadranking.DyadRankingBasedNodeEvaluatorConfig;
import ai.libs.mlplan.metamining.pipelinecharacterizing.ComponentInstanceVectorFeatureGenerator;
import ai.libs.mlplan.metamining.pipelinecharacterizing.IPipelineCharacterizer;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.nio.file.Paths;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.aeonbits.owner.ConfigFactory;
import org.apache.commons.collections.BidiMap;
import org.apache.commons.collections.bidimap.DualHashBidiMap;
import org.api4.java.ai.graphsearch.problem.IPathSearchInput;
import org.api4.java.ai.graphsearch.problem.implicit.graphgenerator.IPathGoalTester;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.IPathEvaluator;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.IPotentiallyGraphDependentPathEvaluator;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.IPotentiallySolutionReportingPathEvaluator;
import org.api4.java.ai.graphsearch.problem.pathsearch.pathevaluation.PathEvaluationException;
import org.api4.java.ai.ml.classification.IClassifier;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.evaluation.supervised.loss.IDeterministicPredictionPerformanceMeasure;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.ai.ml.ranking.IRanking;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingDataset;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.common.attributedobjects.IObjectEvaluator;
import org.api4.java.common.math.IVector;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.api4.java.datastructure.graph.implicit.IGraphGenerator;
import org.openml.webapplication.fantail.dc.LandmarkerCharacterizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.core.Instances;

/**
 * Path evaluator that scores a partial search path via dyad ranking.
 *
 * <p>For every evaluated path it draws several random completions, converts each completion into a
 * {@link ComponentInstance}, ranks those pipelines with a {@link PLNetDyadRanker} (the dataset meta-features form the
 * instance, the pipeline characterizations form the alternatives), evaluates the top-ranked pipelines with the
 * configured pipeline evaluator and returns the minimal score observed. Each successfully evaluated pipeline is also
 * posted as a solution candidate on the internal event bus.
 *
 * @param <T> node type of the search graph (the goal node of a completion is cast to {@link TFDNode})
 * @param <A> arc type of the search graph
 * @param <V> score type; the penalty score is produced as a {@link Double}, so V is presumably Double in practice —
 *            NOTE(review): confirm with callers
 */
public class DyadRankingBasedNodeEvaluator<T, A, V extends Comparable<V>>
implements IPotentiallyGraphDependentPathEvaluator<T, A, V>,
IPotentiallySolutionReportingPathEvaluator<T, A, V> {
    private static final Logger logger = LoggerFactory.getLogger(DyadRankingBasedNodeEvaluator.class);

    /** Score returned when no completion could be ranked or no ranked pipeline could be evaluated successfully. */
    private static final double PENALTY_SCORE = 9000.0;
    /** How long to wait for the next finished pipeline evaluation before giving up on that slot. */
    private static final long EVALUATION_POLL_TIMEOUT_SECONDS = 20L;
    /** Seed of the stratified split that separates landmarker training data from landmarker evaluation data. */
    private static final long LANDMARKER_SPLIT_SEED = 42L;
    /** Portion handed to the stratified split; the first fold is used as landmarker training data. */
    private static final double LANDMARKER_SPLIT_PORTION = 0.8;

    /** Maps each randomly completed path (a List of nodes) to the pipeline it encodes, and back. */
    private final BidiMap pathToPipelines = new DualHashBidiMap();
    /** Random search used to complete partial paths; created in {@link #setGenerator}. */
    private RandomSearch<T, A> randomPathCompleter;
    private IObjectEvaluator<ComponentInstance, V> pipelineEvaluator;
    private final Collection<IComponent> components;
    /** Number of random completions drawn per evaluated path. */
    private final int randomlyCompletedPaths;
    /** Held-out fold on which landmarker classifiers are scored (only set when landmarkers are used). */
    private Instances evaluationDataset;
    private double[] datasetMetaFeatures;
    /** Number of top-ranked pipelines that are actually evaluated per path. */
    private final int evaluatedPaths;
    private final Random random;
    private final PLNetDyadRanker dyadRanker = new PLNetDyadRanker();
    private final IPipelineCharacterizer characterizer;
    private final int landmarkerSampleSize;
    private final int[] landmarkers;
    /** landmarkerSets[i][j] is the j-th random training sample of size landmarkers[i]. */
    private Instances[][] landmarkerSets;
    private ILearnerFactory<IClassifier> classifierFactory;
    private final boolean useLandmarkers;
    private Instant firstEvaluation = null;
    private final SolutionEventBus<T> eventBus = new SolutionEventBus<>();
    private IGraphGenerator<T, A> graphGenerator;
    private IPathGoalTester<T, A> goalTester;
    /** Optional scaler for pipeline characterizations; null if it could not be loaded. */
    private final DyadMinMaxScaler scaler;

    /**
     * Sets the factory used to instantiate classifiers when computing landmarker features.
     *
     * @param classifierFactory factory turning component instances into classifiers
     */
    public void setClassifierFactory(final ILearnerFactory<IClassifier> classifierFactory) {
        this.classifierFactory = classifierFactory;
    }

    /**
     * Creates an evaluator using the default {@link DyadRankingBasedNodeEvaluatorConfig}.
     *
     * @param repository the components from which pipelines are assembled
     */
    public DyadRankingBasedNodeEvaluator(final IComponentRepository repository) {
        this(repository, ConfigFactory.create(DyadRankingBasedNodeEvaluatorConfig.class));
    }

    /**
     * Creates an evaluator and loads the PL-Net model and the (optional) characterization scaler from the paths given
     * in the configuration. Loading failures are logged, not thrown.
     *
     * @param repository the components from which pipelines are assembled
     * @param config     sample sizes, seed, landmarker settings and model/scaler paths
     */
    public DyadRankingBasedNodeEvaluator(final IComponentRepository repository, final DyadRankingBasedNodeEvaluatorConfig config) {
        this.components = repository;
        this.random = new Random(config.getSeed());
        this.evaluatedPaths = config.getNumberOfEvaluations();
        this.randomlyCompletedPaths = config.getNumberOfRandomSamples();
        logger.debug("Initialized DyadRankingBasedNodeEvaluator with evalNum: {} and completionNum: {}", this.evaluatedPaths, this.randomlyCompletedPaths);
        this.characterizer = new ComponentInstanceVectorFeatureGenerator(repository);
        this.landmarkers = config.getLandmarkers();
        this.landmarkerSampleSize = config.getLandmarkerSampleSize();
        this.useLandmarkers = config.useLandmarkers();
        this.loadPlNetModel(config.getPlNetPath());
        this.scaler = loadScaler(config.scalerPath());
    }

    /** Loads the PL-Net weights into {@link #dyadRanker}; logs and continues on I/O failure. */
    private void loadPlNetModel(final String plNetPath) {
        try {
            this.dyadRanker.loadModelFromFile(Paths.get(plNetPath).toString());
        } catch (IOException e) {
            logger.error("Could not load model for plnet in {}", Paths.get(plNetPath), e);
        }
    }

    /**
     * Deserializes the characterization scaler.
     *
     * <p>NOTE(review): this uses Java-native deserialization; it is only acceptable if the configured path is trusted.
     *
     * @return the scaler, or null if it could not be read
     */
    private static DyadMinMaxScaler loadScaler(final String scalerPath) {
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(Paths.get(scalerPath).toFile()))) {
            return (DyadMinMaxScaler) in.readObject();
        } catch (IOException e) {
            logger.error("Could not load scaler for plnet in {}", Paths.get(scalerPath), e);
        } catch (ClassNotFoundException e) {
            logger.error("Could not read scaler.", e);
        }
        return null;
    }

    /**
     * Scores a partial path by ranking random completions of it and evaluating the best-ranked pipelines.
     *
     * @param path the partial path to score
     * @return null for goal paths; the minimal score among the evaluated pipelines; or a penalty of 9000 when nothing
     *         could be ranked or evaluated
     * @throws InterruptedException    if interrupted (or timed out) while completing paths or evaluating pipelines
     * @throws PathEvaluationException if the dyad ranker fails
     */
    @Override
    public V evaluate(final ILabeledPath<T, A> path) throws InterruptedException, PathEvaluationException {
        if (this.firstEvaluation == null) {
            this.firstEvaluation = Instant.now();
        }
        // The completer's search input was built with this very goal tester (see initializeRandomSearch).
        if (this.goalTester.isGoal(path)) {
            return null;
        }
        Instant startOfEvaluation = Instant.now();
        // Check and append atomically so two concurrent callers cannot both append the same path.
        synchronized (this.randomPathCompleter) {
            if (!this.randomPathCompleter.knowsNode(path.getHead())) {
                this.randomPathCompleter.appendPathToNode(path);
            }
        }

        List<List<T>> randomPaths;
        try {
            randomPaths = this.getNRandomPaths(path);
        } catch (InterruptedException | TimeoutException e) {
            logger.error("Interrupted in path completion!");
            throw asInterruption(e);
        }
        if (randomPaths.isEmpty()) {
            // Nothing to rank; same outcome as an empty ranking, without invoking the ranker on an empty instance.
            return this.penaltyScore();
        }

        List<ComponentInstance> allRankedPaths;
        try {
            allRankedPaths = this.getDyadRankedPaths(randomPaths);
        } catch (PredictionException e) {
            throw new PathEvaluationException("Could not rank nodes", e);
        }
        if (allRankedPaths.isEmpty()) {
            return this.penaltyScore();
        }

        List<ComponentInstance> topKRankedPaths = allRankedPaths.subList(0, Math.min(this.evaluatedPaths, allRankedPaths.size()));
        List<Pair<ComponentInstance, V>> allEvaluatedPaths;
        try {
            allEvaluatedPaths = this.evaluateTopKPaths(topKRankedPaths);
        } catch (InterruptedException e) {
            logger.error("Interrupted while predicting next best solution");
            throw e;
        }

        Duration evaluationTime = Duration.between(startOfEvaluation, Instant.now());
        logger.info("Evaluation took {}ms", evaluationTime.toMillis());
        V bestSolution = this.getBestSolution(allEvaluatedPaths);
        logger.info("Best solution is {}, {}", bestSolution, allEvaluatedPaths.stream().map(Pair::getY).collect(Collectors.toList()));
        if (bestSolution == null) {
            return this.penaltyScore();
        }
        this.eventBus.post(new FValueEvent(null, bestSolution, (double) evaluationTime.toMillis()));
        return bestSolution;
    }

    /** Returns {@link #PENALTY_SCORE} as V. NOTE(review): unchecked; only valid when V is Double. */
    @SuppressWarnings("unchecked")
    private V penaltyScore() {
        return (V) Double.valueOf(PENALTY_SCORE);
    }

    /**
     * Wraps a completion failure in a new {@link InterruptedException}, keeping the cause. The interrupt flag is
     * cleared because the thrown exception itself now signals the interruption.
     */
    private static InterruptedException asInterruption(final Exception cause) {
        Thread.interrupted();
        InterruptedException interruption = new InterruptedException(cause.getMessage());
        interruption.initCause(cause);
        return interruption;
    }

    /**
     * Draws up to {@link #randomlyCompletedPaths} random completions of the given path. Stops early when the completer
     * is canceled or finds no further completion.
     *
     * @return the full node lists (given prefix followed by the completion)
     */
    private List<List<T>> getNRandomPaths(final ILabeledPath<T, A> node) throws InterruptedException, TimeoutException {
        List<List<T>> completedPaths = new ArrayList<>(Math.max(0, this.randomlyCompletedPaths));
        for (int drawn = 0; drawn < this.randomlyCompletedPaths; ++drawn) {
            List<T> completedPath = this.completeRandomly(node);
            if (completedPath == null) {
                break;
            }
            completedPaths.add(completedPath);
        }
        logger.info("Returning {} paths", completedPaths.size());
        return completedPaths;
    }

    /** Draws one random completion under the completer's lock; returns null if canceled or none is found. */
    private List<T> completeRandomly(final ILabeledPath<T, A> node) throws InterruptedException, TimeoutException {
        synchronized (this.randomPathCompleter) {
            if (this.randomPathCompleter.isCanceled()) {
                logger.info("Completer has been canceled (perhaps due a cancel on the evaluator). Canceling RDFS");
                return null;
            }
            SearchGraphPath<T, A> completion;
            try {
                completion = this.randomPathCompleter.nextSolutionUnderSubPath(node);
            } catch (AlgorithmExecutionCanceledException e) {
                logger.info("Completer has been canceled. Returning control.");
                return null;
            }
            if (completion == null) {
                logger.info("No completion was found for path {}.", node.getNodes());
                return null;
            }
            List<T> completedPath = new ArrayList<>(node.getNodes());
            List<T> completionNodes = completion.getNodes();
            // The completion starts with the head of the given path, which is already contained in the prefix.
            completedPath.addAll(completionNodes.subList(1, completionNodes.size()));
            return completedPath;
        }
    }

    /**
     * Converts every completed path into a pipeline, characterizes it, and ranks all pipelines with the dyad ranker.
     * Pipelines with identical characterizations collapse into one map entry.
     */
    private List<ComponentInstance> getDyadRankedPaths(final List<List<T>> randomPaths) throws PredictionException, InterruptedException {
        Map<IVector, ComponentInstance> pipelineToCharacterization = new HashMap<>();
        for (List<T> randomPath : randomPaths) {
            TFDNode goalNode = (TFDNode) randomPath.get(randomPath.size() - 1);
            ComponentInstance pipeline = HASCOUtil.getSolutionCompositionFromState(this.components, goalNode.getState(), true);
            this.pathToPipelines.put(randomPath, pipeline);
            IVector characterization = this.useLandmarkers ? this.evaluateLandmarkersForAlgorithm(pipeline) : this.characterizePipeline(pipeline);
            pipelineToCharacterization.put(characterization, pipeline);
        }
        return this.rankRandomPipelines(pipelineToCharacterization);
    }

    /** Characterizes a pipeline and, if a scaler was loaded, scales the characterization. */
    private IVector characterizePipeline(final ComponentInstance pipeline) {
        IVector characterization = new DenseDoubleVector(this.characterizer.characterize(pipeline));
        if (this.scaler != null) {
            // The scaler operates on datasets, so wrap the single alternative. Its return value is ignored, i.e. this
            // relies on transformAlternatives mutating the vector in place — NOTE(review): confirm.
            List<IDyadRankingInstance> wrapper = Arrays.asList(new SparseDyadRankingInstance(new DenseDoubleVector(this.datasetMetaFeatures), Arrays.asList(characterization)));
            this.scaler.transformAlternatives(new DyadRankingDataset(wrapper));
        }
        return characterization;
    }

    /**
     * Builds the pipeline characterization extended by one landmarker feature per configured landmarker size: the
     * error rate averaged over that size's random training samples, measured on {@link #evaluationDataset}.
     */
    private IVector evaluateLandmarkersForAlgorithm(final ComponentInstance cI) {
        double[] y = this.characterizer.characterize(cI);
        double[] yPrime = new double[this.characterizer.getLengthOfCharacterization() + this.landmarkers.length];
        System.arraycopy(y, 0, yPrime, 0, y.length);
        for (int i = 0; i < this.landmarkers.length; ++i) {
            Instances[] subsets = this.landmarkerSets[i];
            double score = 0.0;
            for (Instances train : subsets) {
                FixedSplitClassifierEvaluator evaluator = new FixedSplitClassifierEvaluator(new WekaInstances(train), new WekaInstances(this.evaluationDataset), EClassificationPerformanceMeasure.ERRORRATE);
                try {
                    score += evaluator.evaluate(this.classifierFactory.getComponentInstantiation(cI)).doubleValue();
                } catch (Exception e) {
                    if (e instanceof InterruptedException) {
                        Thread.currentThread().interrupt();
                    }
                    // NOTE(review): a failed evaluation contributes 0 (i.e. a perfect error rate) to the average.
                    logger.error("Couldn't evaluate landmarker for {}", cI, e);
                }
            }
            yPrime[y.length + i] = subsets.length == 0 ? 0.0 : score / subsets.length;
        }
        return new DenseDoubleVector(yPrime);
    }

    /** Ranks the characterized pipelines against the dataset meta-features; best-ranked pipeline comes first. */
    private List<ComponentInstance> rankRandomPipelines(final Map<IVector, ComponentInstance> randomPipelines) throws PredictionException, InterruptedException {
        List<IVector> alternatives = new ArrayList<>(randomPipelines.keySet());
        SparseDyadRankingInstance toRank = new SparseDyadRankingInstance(new DenseDoubleVector(this.datasetMetaFeatures), alternatives);
        IRanking<? extends IDyad> ranking = this.dyadRanker.predict(toRank);
        List<ComponentInstance> rankedPipelines = new ArrayList<>(alternatives.size());
        for (IDyad dyad : ranking) {
            rankedPipelines.add(randomPipelines.get(dyad.getAlternative()));
        }
        return rankedPipelines;
    }

    /**
     * Evaluates the given pipelines on a single worker thread, waiting at most 20 seconds for each next result.
     * The worker pool is always shut down (interrupting unfinished evaluations) before returning.
     *
     * @return (pipeline, score) pairs of all evaluations that finished successfully in time
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private List<Pair<ComponentInstance, V>> evaluateTopKPaths(final List<ComponentInstance> topKRankedPaths) throws InterruptedException {
        ExecutorService executor = Executors.newFixedThreadPool(1);
        try {
            ExecutorCompletionService<Pair<ComponentInstance, V>> completionService = new ExecutorCompletionService<>(executor);
            for (ComponentInstance candidate : topKRankedPaths) {
                completionService.submit(() -> this.evaluateCandidate(candidate));
            }
            List<Pair<ComponentInstance, V>> evaluatedSolutions = new ArrayList<>(topKRankedPaths.size());
            for (int i = 0; i < topKRankedPaths.size(); ++i) {
                logger.info("Got {} solutions. Waiting for iteration {} of max iterations {}", evaluatedSolutions.size(), i + 1, topKRankedPaths.size());
                Future<Pair<ComponentInstance, V>> finished = completionService.poll(EVALUATION_POLL_TIMEOUT_SECONDS, TimeUnit.SECONDS);
                if (finished == null) {
                    logger.info("Didn't receive any futures (expected {} futures)", topKRankedPaths.size());
                    continue;
                }
                try {
                    // The future is already completed (it came from the completion queue), so get() does not block.
                    Pair<ComponentInstance, V> solution = finished.get();
                    if (solution != null) {
                        logger.info("Evaluation was successful. Adding {} to solutions", solution.getY());
                        evaluatedSolutions.add(solution);
                    } else {
                        logger.info("Evaluation of a candidate failed; no score was obtained.");
                    }
                } catch (ExecutionException e) {
                    logger.info("Got exception while evaluating {}", e.getMessage());
                }
            }
            return evaluatedSolutions;
        } finally {
            // Without this, every call leaked a worker thread and left queued evaluations running in the background.
            executor.shutdownNow();
        }
    }

    /** Evaluates one pipeline and posts it as a solution candidate; returns null if evaluation throws. */
    private Pair<ComponentInstance, V> evaluateCandidate(final ComponentInstance candidate) {
        try {
            Instant startTime = Instant.now();
            V score = this.pipelineEvaluator.evaluate(candidate);
            long evaluationMillis = Duration.between(startTime, Instant.now()).toMillis();
            this.postSolution(candidate, evaluationMillis, score);
            return new Pair<>(candidate, score);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            logger.error("Couldn't evaluate {}", candidate, e);
            return null;
        }
    }

    /** Returns the minimal non-null score, or null if there is none. */
    private V getBestSolution(final List<Pair<ComponentInstance, V>> allEvaluatedPaths) {
        return allEvaluatedPaths.stream().map(Pair::getY).filter(score -> score != null).min((a, b) -> a.compareTo(b)).orElse(null);
    }

    /**
     * Stores the graph generator and goal tester and (re)creates the random path completer from them.
     *
     * @param generator  graph generator of the search space
     * @param goalTester goal tester of the search space
     */
    @Override
    public void setGenerator(final IGraphGenerator<T, A> generator, final IPathGoalTester<T, A> goalTester) {
        this.graphGenerator = generator;
        this.goalTester = goalTester;
        this.initializeRandomSearch();
    }

    /** Creates the randomized depth-first completer and steps it until it has emitted its initialization event. */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    private void initializeRandomSearch() {
        RandomizedDepthFirstNodeEvaluator nodeEvaluator = new RandomizedDepthFirstNodeEvaluator(this.random);
        GraphSearchWithSubpathEvaluationsInput completionProblem = new GraphSearchWithSubpathEvaluationsInput(this.graphGenerator, this.goalTester, (IPathEvaluator) nodeEvaluator);
        this.randomPathCompleter = new RandomSearch((IPathSearchInput) completionProblem, null, this.random);
        while (!(this.randomPathCompleter.next() instanceof AlgorithmInitializedEvent)) {
            // intentionally empty: next() advances the algorithm
        }
    }

    /**
     * Computes the dataset meta-features and, if landmarkers are enabled, splits the data and prepares the landmarker
     * training samples.
     *
     * @param dataset the dataset the search is run on
     * @throws IllegalArgumentException if the stratified split fails
     */
    public void setDataset(final Instances dataset) {
        try {
            Instances trainData = null;
            if (this.useLandmarkers) {
                List<?> split = WekaUtil.getStratifiedSplit(dataset, LANDMARKER_SPLIT_SEED, LANDMARKER_SPLIT_PORTION);
                trainData = (Instances) split.get(0);
                this.evaluationDataset = (Instances) split.get(1);
            }
            Map<String, Double> metaFeatures = new LandmarkerCharacterizer().characterize(dataset);
            this.datasetMetaFeatures = metaFeatures.values().stream().mapToDouble(Double::doubleValue).toArray();
            if (this.useLandmarkers) {
                this.setUpLandmarkingDatasets(dataset, trainData);
            }
        } catch (SplitFailedException e) {
            throw new IllegalArgumentException(e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } catch (Exception e) {
            logger.error("Failed to characterize the dataset", e);
        }
    }

    /**
     * Fills {@link #landmarkerSets}: for every landmarker size, {@link #landmarkerSampleSize} samples drawn from
     * {@code trainData} uniformly with replacement, each using the header of {@code dataset}.
     */
    private void setUpLandmarkingDatasets(final Instances dataset, final Instances trainData) {
        this.landmarkerSets = new Instances[this.landmarkers.length][this.landmarkerSampleSize];
        for (int i = 0; i < this.landmarkers.length; ++i) {
            int sampleSize = this.landmarkers[i];
            for (int j = 0; j < this.landmarkerSampleSize; ++j) {
                Instances sample = new Instances(dataset, sampleSize);
                for (int k = 0; k < sampleSize; ++k) {
                    sample.add(trainData.get(this.random.nextInt(trainData.size())));
                }
                this.landmarkerSets[i][j] = sample;
            }
        }
    }

    /**
     * Posts an evaluated pipeline as a solution-candidate event, annotated with its evaluation time, the time since
     * the first evaluation and the number of random completions per path. Failures are logged, never thrown.
     *
     * @param solution the evaluated pipeline
     * @param time     evaluation time in milliseconds
     * @param score    the pipeline's score
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    protected void postSolution(final ComponentInstance solution, final long time, final V score) {
        try {
            List pathToSolution = (List) this.pathToPipelines.getKey(solution);
            EvaluatedSearchGraphPath solutionObject = new EvaluatedSearchGraphPath(pathToSolution, null, score);
            solutionObject.setAnnotation("fTime", time);
            solutionObject.setAnnotation("timeToSolution", Duration.between(this.firstEvaluation, Instant.now()).toMillis());
            solutionObject.setAnnotation("nodesEvaluatedToSolution", this.randomlyCompletedPaths);
            logger.debug("Posting solution {}", solutionObject);
            this.eventBus.post(new EvaluatedSearchSolutionCandidateFoundEvent(null, solutionObject));
        } catch (Exception e) {
            logger.error("Couldn't post solution to event bus.", e);
        }
    }

    /**
     * Sets the evaluator used to score the top-ranked pipelines.
     *
     * @param wrappedSearchBenchmark pipeline evaluator
     */
    public void setPipelineEvaluator(final IObjectEvaluator<ComponentInstance, V> wrappedSearchBenchmark) {
        this.pipelineEvaluator = wrappedSearchBenchmark;
    }

    /** Always true: a graph generator is needed to build the random path completer. */
    @Override
    public boolean requiresGraphGenerator() {
        return true;
    }

    /**
     * Registers a listener on the internal event bus (receives f-value and solution-candidate events).
     *
     * @param listener the listener object
     */
    @Override
    public void registerSolutionListener(final Object listener) {
        this.eventBus.register(listener);
    }

    /** Always true: evaluated pipelines are reported as solution candidates. */
    @Override
    public boolean reportsSolutions() {
        return true;
    }
}

