/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.evaluation.evaluator;

import ai.libs.jaicore.ml.core.evaluation.evaluator.LearnerRunReport;
import ai.libs.jaicore.ml.core.evaluation.evaluator.SupervisedLearnerExecutor;
import ai.libs.jaicore.ml.core.evaluation.evaluator.events.TrainTestSplitEvaluationCompletedEvent;
import ai.libs.jaicore.ml.core.evaluation.evaluator.events.TrainTestSplitEvaluationFailedEvent;
import com.google.common.eventbus.EventBus;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.api4.java.ai.ml.classification.IClassifierEvaluator;
import org.api4.java.ai.ml.core.dataset.splitter.SplitFailedException;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.execution.IAggregatedPredictionPerformanceMeasure;
import org.api4.java.ai.ml.core.evaluation.execution.IDatasetSplitSet;
import org.api4.java.ai.ml.core.evaluation.execution.IFixedDatasetSplitSetGenerator;
import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionFailedException;
import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionInterruptedException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.common.attributedobjects.ObjectEvaluationFailedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.event.IEventEmitter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Classifier evaluator that scores a supervised learner on the splits delivered by a fixed split
 * set generator.
 *
 * <p>For every split, the learner is handed to a {@link SupervisedLearnerExecutor} together with
 * fold 0 (used as training data) and fold 1 (used as test data). The prediction diff lists of all
 * run reports are then passed to the configured metric, whose {@code loss} value is the result of
 * the evaluation.
 *
 * <p>While evaluating, {@link TrainTestSplitEvaluationCompletedEvent}s and
 * {@link TrainTestSplitEvaluationFailedEvent}s are posted to an internal event bus, but only once
 * at least one listener has been registered via {@link #registerListener(Object)}.
 */
public class TrainPredictionBasedClassifierEvaluator implements IClassifierEvaluator, ILoggingCustomizable, IEventEmitter<Object> {

    /** Neither static nor final: {@link #setLoggerName(String)} replaces it. */
    private Logger logger = LoggerFactory.getLogger(TrainPredictionBasedClassifierEvaluator.class);
    private final IFixedDatasetSplitSetGenerator<ILabeledDataset<? extends ILabeledInstance>> splitGenerator;
    private final SupervisedLearnerExecutor executor = new SupervisedLearnerExecutor();
    /* Kept as a raw type: evaluate() invokes loss(...) through the raw type. */
    private final IAggregatedPredictionPerformanceMeasure metric;
    private final EventBus eventBus = new EventBus();
    /*
     * Set as soon as a listener has been registered. Volatile so that a registration performed on
     * another thread is visible to a concurrently running evaluation.
     */
    private volatile boolean hasListeners;

    /**
     * Creates an evaluator for the given split generator and performance measure.
     *
     * @param splitGenerator generator providing the split set; each split must consist of exactly two folds
     * @param metric aggregated measure whose loss over all splits is returned by {@link #evaluate}
     */
    public TrainPredictionBasedClassifierEvaluator(IFixedDatasetSplitSetGenerator<ILabeledDataset<?>> splitGenerator, IAggregatedPredictionPerformanceMeasure<?, ?> metric) {
        this.splitGenerator = splitGenerator;
        this.metric = metric;
    }

    /**
     * Executes the learner on every split of the next split set and aggregates the results with the
     * configured metric.
     *
     * @param learner the learner to evaluate
     * @return the loss computed by the metric over the prediction diff lists of all splits
     * @throws InterruptedException declared for interrupts during the evaluation; a
     *         {@link LearnerExecutionInterruptedException} raised by the executor is logged,
     *         reported as a failure event (if listeners exist) and rethrown
     * @throws ObjectEvaluationFailedException if the split generation or a learner execution fails;
     *         the original exception is attached as cause
     * @throws IllegalStateException if the split set does not have exactly two folds per split
     */
    // The project API is used through raw types here; the unchecked calls mirror the original code.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public Double evaluate(ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> learner) throws InterruptedException, ObjectEvaluationFailedException {
        try {
            long evaluationStart = System.currentTimeMillis();
            this.logger.info("Using {} to split the given data into two folds.", this.splitGenerator.getClass().getName());
            IDatasetSplitSet splitSet = this.splitGenerator.nextSplitSet();
            if (splitSet.getNumberOfFoldsPerSplit() != 2) {
                throw new IllegalStateException("Number of folds for each split should be 2 but is " + splitSet.getNumberOfFoldsPerSplit() + "! Split generator: " + this.splitGenerator);
            }
            int n = splitSet.getNumberOfSplits();
            List<ILearnerRunReport> reports = new ArrayList<>(n);
            for (int i = 0; i < n; ++i) {
                List folds = splitSet.getFolds(i);
                // Resolve both folds once; they are needed for logging, execution and failure reports.
                ILabeledDataset train = (ILabeledDataset) folds.get(0);
                ILabeledDataset test = (ILabeledDataset) folds.get(1);
                this.logger.debug("Executing learner {} on folds of sizes {} (train) and {} (test) using {}.", learner, train.size(), test.size(), this.executor.getClass().getName());

                ILearnerRunReport report;
                try {
                    report = this.executor.execute(learner, train, test);
                    // Only build the ground truth / prediction lists if they are actually logged.
                    if (this.logger.isTraceEnabled()) {
                        this.logger.trace("Obtained report. Training times was {}ms, testing time {}ms. Ground truth vector: {}, prediction vector: {}. Pipeline: {}", report.getTrainEndTime() - report.getTrainStartTime(), report.getTestEndTime() - report.getTestStartTime(), report.getPredictionDiffList().getGroundTruthAsList(), report.getPredictionDiffList().getPredictionsAsList(), learner);
                    }
                } catch (LearnerExecutionInterruptedException e) {
                    this.logger.info("Received interrupt of training in iteration #{} after a total evaluation time of {}ms. Sending an event over the bus and forwarding the exception.", i + 1, System.currentTimeMillis() - evaluationStart);
                    // Same guard as for completed events: no report/event is built when nobody listens.
                    if (this.hasListeners) {
                        LearnerRunReport failReport = new LearnerRunReport(train, test, e.getTrainTimeStart(), e.getTrainTimeEnd(), e.getTestTimeStart(), e.getTestTimeEnd(), e);
                        this.eventBus.post(new TrainTestSplitEvaluationFailedEvent<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>>(learner, failReport));
                    }
                    throw e;
                } catch (LearnerExecutionFailedException e) {
                    this.logger.info("Catching {} in iteration #{} after a total evaluation time of {}ms. Sending an event over the bus and forwarding the exception.", e.getClass().getName(), i + 1, System.currentTimeMillis() - evaluationStart);
                    if (this.hasListeners) {
                        LearnerRunReport failReport = new LearnerRunReport(train, test, e.getTrainTimeStart(), e.getTrainTimeEnd(), e.getTestTimeStart(), e.getTestTimeEnd(), e);
                        this.eventBus.post(new TrainTestSplitEvaluationFailedEvent<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>>(learner, failReport));
                    }
                    // Rethrown here and wrapped into an ObjectEvaluationFailedException by the outer handler.
                    throw e;
                }

                if (this.hasListeners) {
                    this.eventBus.post(new TrainTestSplitEvaluationCompletedEvent<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>>(learner, report));
                }
                reports.add(report);
            }
            this.logger.debug("Compute metric ({}) for the diff of predictions and ground truth.", this.metric.getClass().getName());
            double score = this.metric.loss(reports.stream().map(ILearnerRunReport::getPredictionDiffList).collect(Collectors.toList()));
            this.logger.info("Computed value for metric {} of {} executions. Metric value is: {}. Pipeline: {}", this.metric, n, score, learner);
            return score;
        } catch (SplitFailedException | LearnerExecutionFailedException e) {
            // Rendering the stack trace is comparatively expensive, so only do it when it is logged.
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Failed to evaluate the learner {}. Exception: {}", learner, ExceptionUtils.getStackTrace(e));
            }
            throw new ObjectEvaluationFailedException(e);
        }
    }

    /**
     * @return the split generator this evaluator draws its split sets from
     */
    public IFixedDatasetSplitSetGenerator<ILabeledDataset<? extends ILabeledInstance>> getSplitGenerator() {
        return this.splitGenerator;
    }

    /**
     * @return the name of the logger currently used by this evaluator
     */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Switches this evaluator to the logger with the given name and propagates derived names to
     * its collaborators: {@code name + ".splitgen"} to the split generator (only if it implements
     * {@link ILoggingCustomizable}) and {@code name + ".executor"} to the learner executor.
     *
     * @param name the new logger name
     */
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
        if (this.splitGenerator instanceof ILoggingCustomizable) {
            ((ILoggingCustomizable) this.splitGenerator).setLoggerName(name + ".splitgen");
            this.logger.trace("Setting logger of split generator {} to {}.splitgen", this.splitGenerator.getClass().getName(), name);
        } else {
            this.logger.trace("Split generator {} is not configurable for logging, so not configuring it.", this.splitGenerator.getClass().getName());
        }
        this.executor.setLoggerName(name + ".executor");
        this.logger.trace("Setting logger of learner executor {} to {}.executor", this.executor.getClass().getName(), name);
    }

    /**
     * Registers a listener on the internal event bus and enables event posting during evaluation.
     *
     * @param listener the object to register on the event bus
     */
    public void registerListener(Object listener) {
        this.eventBus.register(listener);
        // Set only after a successful registration, so a rejected listener does not enable posting.
        this.hasListeners = true;
    }

    /**
     * @return the performance measure used to aggregate the results of all splits
     */
    public IAggregatedPredictionPerformanceMeasure getMetric() {
        return this.metric;
    }
}

