/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.core.evaluation.evaluator;

import ai.libs.jaicore.logging.LoggerUtil;
import ai.libs.jaicore.ml.core.evaluation.evaluator.LearnerRunReport;
import ai.libs.jaicore.ml.core.evaluation.evaluator.TypelessPredictionDiff;
import java.util.List;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.execution.ILearnerRunReport;
import org.api4.java.ai.ml.core.evaluation.execution.ISupervisedLearnerExecutor;
import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionFailedException;
import org.api4.java.ai.ml.core.evaluation.execution.LearnerExecutionInterruptedException;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs supervised learners and packages the outcome into a {@link LearnerRunReport}.
 *
 * <p>Two entry points exist: one that first fits the learner on training data and then predicts
 * on test data, and one that only predicts with the learner as it is handed in. All timestamps
 * are wall-clock values obtained from {@link System#currentTimeMillis()}.
 */
public class SupervisedLearnerExecutor implements ISupervisedLearnerExecutor, ILoggingCustomizable {

    /** Value passed as training start/end time when this executor does not train the learner. */
    private static final long NO_TRAINING_TIME = -1L;

    /** Not final on purpose: {@link #setLoggerName(String)} replaces it. */
    private Logger logger = LoggerFactory.getLogger(SupervisedLearnerExecutor.class);

    /**
     * Fits the learner on {@code train}, then predicts on {@code test} and builds a run report.
     *
     * @param learner the learner to fit and evaluate
     * @param train the data the learner is fitted on
     * @param test the data the learner predicts on
     * @return a report holding both datasets, the train/test timestamps and the prediction diff
     * @throws LearnerExecutionFailedException if fitting raises a {@link TrainingException} or
     *         predicting raises a {@link PredictionException}; the original exception is passed along
     * @throws LearnerExecutionInterruptedException if fitting or predicting raises an
     *         {@link InterruptedException}
     */
    public <I extends ILabeledInstance, D extends ILabeledDataset<? extends I>> ILearnerRunReport execute(ISupervisedLearner<I, D> learner, D train, D test) throws LearnerExecutionFailedException, LearnerExecutionInterruptedException {
        long startTrainTime = System.currentTimeMillis();
        try {
            this.logger.info("Fitting the learner (class: {}) {} with {} instances, each of which with {} attributes", learner.getClass().getName(), learner, train.size(), train.getNumAttributes());
            learner.fit(train);
        } catch (InterruptedException e) {
            long now = System.currentTimeMillis();
            this.logger.info("Training was interrupted after {}ms, sending respective LearnerExecutionInterruptedException.", now - startTrainTime);
            throw new LearnerExecutionInterruptedException(startTrainTime, now);
        } catch (TrainingException e) {
            long now = System.currentTimeMillis();
            this.logger.info("Training failed due to an {} after {}ms.", e.getClass().getName(), now - startTrainTime);
            throw new LearnerExecutionFailedException(startTrainTime, now, e);
        }
        long endTrainTime = System.currentTimeMillis();
        this.logger.debug("Training finished successfully after {}ms. Now acquiring predictions.", endTrainTime - startTrainTime);

        try {
            return this.getReportForTrainedLearner(learner, train, test, startTrainTime, endTrainTime);
        } catch (InterruptedException e) {
            // Capture the time once so that the logged durations and the timestamp handed to the
            // exception describe the same instant (and are not inflated by the logging below).
            long now = System.currentTimeMillis();
            this.logger.info("Learner was interrupted during prediction after a runtime of {}ms for training and {}ms for testing ({}ms total walltime).", endTrainTime - startTrainTime, now - endTrainTime, now - startTrainTime);
            if (Thread.currentThread().isInterrupted()) {
                this.logger.warn("Observed an InterruptedException while evaluating a learner of type {} ({}) AND the thread is interrupted. This should never happen! Here is the detailed information: {}", learner.getClass(), learner, LoggerUtil.getExceptionInfo(e));
            }
            // The end of training doubles as the start of testing here.
            throw new LearnerExecutionInterruptedException(startTrainTime, endTrainTime, endTrainTime, now);
        } catch (PredictionException e) {
            long now = System.currentTimeMillis();
            this.logger.info("Prediction failed with exception {}.", e.getClass().getName());
            throw new LearnerExecutionFailedException(startTrainTime, endTrainTime, endTrainTime, now, e);
        }
    }

    /**
     * Predicts on {@code test} with the given learner without fitting it first.
     *
     * <p>The resulting report carries {@code null} as training data and {@code -1} for both
     * training timestamps.
     *
     * @param learner the learner used for prediction
     * @param test the data the learner predicts on
     * @return a report holding the test data, the test timestamps and the prediction diff
     * @throws LearnerExecutionFailedException if predicting raises a {@link PredictionException}
     *         or an {@link InterruptedException}; in the latter case the thread's interrupt flag
     *         is set again before throwing
     */
    public <I extends ILabeledInstance, D extends ILabeledDataset<? extends I>> ILearnerRunReport execute(ISupervisedLearner<I, D> learner, D test) throws LearnerExecutionFailedException {
        long startTestTime = System.currentTimeMillis();
        try {
            return this.getReportForTrainedLearner(learner, null, test, NO_TRAINING_TIME, NO_TRAINING_TIME);
        } catch (InterruptedException e) {
            long now = System.currentTimeMillis();
            // This signature cannot signal an interruption as such, so the interrupt flag is
            // restored for the caller and the interruption is reported as a failure.
            Thread.currentThread().interrupt();
            throw new LearnerExecutionFailedException(NO_TRAINING_TIME, NO_TRAINING_TIME, startTestTime, now, e);
        } catch (PredictionException e) {
            long now = System.currentTimeMillis();
            throw new LearnerExecutionFailedException(NO_TRAINING_TIME, NO_TRAINING_TIME, startTestTime, now, e);
        }
    }

    /**
     * Obtains predictions for {@code test} and pairs each one with the label of the test instance
     * at the same index.
     *
     * @param learner the learner that produces the predictions
     * @param train the training data to store in the report; {@code null} when no training took place
     * @param test the data to predict on
     * @param trainingStartTime training start timestamp to store in the report
     * @param trainingEndTime training end timestamp to store in the report
     * @return the report with the timestamps measured around the prediction call and the diff
     * @throws PredictionException if the learner fails to predict
     * @throws InterruptedException if the prediction is interrupted
     */
    private <I extends ILabeledInstance, D extends ILabeledDataset<? extends I>> ILearnerRunReport getReportForTrainedLearner(ISupervisedLearner<I, D> learner, D train, D test, long trainingStartTime, long trainingEndTime) throws PredictionException, InterruptedException {
        long startTestTime = System.currentTimeMillis();
        List<?> predictions = learner.predict(test).getPredictions();
        long endTestTime = System.currentTimeMillis();

        int numPredictions = predictions.size();
        if (numPredictions != test.size()) {
            // The pairing below is purely index-based, so a size mismatch means the diff does not
            // cover the test data one-to-one. Only made visible here; control flow is unchanged.
            this.logger.warn("Number of predictions ({}) differs from number of test instances ({}).", numPredictions, test.size());
        }

        /* build the table of (ground truth, prediction) pairs */
        TypelessPredictionDiff diff = new TypelessPredictionDiff();
        for (int i = 0; i < numPredictions; i++) {
            diff.addPair(test.get(i).getLabel(), predictions.get(i));
        }
        return new LearnerRunReport(train, test, trainingStartTime, trainingEndTime, startTestTime, endTestTime, diff);
    }

    /**
     * @return the name of the logger currently used by this executor
     */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Switches this executor to the logger registered under the given name.
     *
     * @param name the name of the logger to use from now on
     */
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
    }
}

