/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.mlplan.core;

import ai.libs.jaicore.components.api.IComponentInstance;
import ai.libs.jaicore.ml.core.learner.ASupervisedLearner;
import ai.libs.mlplan.core.ITimeTrackingLearner;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.api4.java.ai.ml.core.dataset.IInstance;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPrediction;
import org.api4.java.ai.ml.core.evaluation.IPredictionBatch;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.core.learner.ISupervisedLearner;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Wraps an {@link ISupervisedLearner}, delegating {@code fit} and {@code predict} to it while
 * recording how long each call took, in milliseconds.
 * <p>
 * Besides the measured times, the wrapper carries externally supplied metadata: the
 * {@link IComponentInstance} it was created for, predicted induction/inference times, and a score.
 * <p>
 * NOTE(review): the recorded time lists are plain {@link ArrayList}s with no synchronization, so
 * concurrent calls on one wrapper are not safe as written.
 */
public class TimeTrackingLearnerWrapper
extends ASupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>, IPrediction, IPredictionBatch>
implements ITimeTrackingLearner,
ILoggingCustomizable {
    private Logger logger = LoggerFactory.getLogger(TimeTrackingLearnerWrapper.class);
    private final ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> wrappedSLClassifier;
    private final IComponentInstance ci;

    /** Duration (ms) of every successful {@link #fit} call, in call order. */
    private final List<Long> fitTimes = new ArrayList<>();
    /** Duration (ms) of every successful batch {@code predict} call, in call order. */
    private final List<Long> batchPredictTimes = new ArrayList<>();
    /**
     * Per-instance prediction times (ms): one entry per single-instance {@code predict} call, plus
     * the averaged per-instance time of each non-empty batch {@code predict} call.
     */
    private final List<Long> perInstancePredictionTimes = new ArrayList<>();

    private Double predictedInductionTime = null;
    private Double predictedInferenceTime = null;
    private Double score;

    /**
     * Creates a time-tracking wrapper.
     *
     * @param ci             the component instance the wrapped learner was built from
     * @param wrappedLearner the learner all {@code fit}/{@code predict} calls are delegated to
     */
    public TimeTrackingLearnerWrapper(IComponentInstance ci, ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> wrappedLearner) {
        this.ci = ci;
        this.wrappedSLClassifier = wrappedLearner;
    }

    /**
     * Trains the wrapped learner and records the elapsed time in {@link #getFitTimes()}.
     * If the wrapped learner throws, no time is recorded.
     */
    @Override
    public void fit(ILabeledDataset<? extends ILabeledInstance> dTrain) throws TrainingException, InterruptedException {
        TimeTracker tracker = new TimeTracker();
        this.wrappedSLClassifier.fit(dTrain);
        this.fitTimes.add(tracker.stop());
    }

    /**
     * Predicts a single instance with the wrapped learner and records the elapsed time in
     * {@link #getInstancePredictionTimesInMS()}.
     */
    @Override
    public IPrediction predict(ILabeledInstance xTest) throws PredictionException, InterruptedException {
        TimeTracker tracker = new TimeTracker();
        IPrediction prediction = this.wrappedSLClassifier.predict(xTest);
        this.perInstancePredictionTimes.add(tracker.stop());
        return prediction;
    }

    /**
     * Predicts a batch with the wrapped learner. The whole batch time is recorded in
     * {@link #getBatchPredictionTimesInMS()}. For a non-empty batch, the rounded average time per
     * instance is also recorded in {@link #getInstancePredictionTimesInMS()}.
     */
    @Override
    public IPredictionBatch predict(ILabeledInstance[] dTest) throws PredictionException, InterruptedException {
        TimeTracker tracker = new TimeTracker();
        IPredictionBatch prediction = this.wrappedSLClassifier.predict(dTest);
        long time = tracker.stop();
        this.batchPredictTimes.add(time);
        // An empty batch has no meaningful per-instance time. Dividing by zero would record
        // Long.MAX_VALUE (Math.round(Infinity)) or 0 (Math.round(NaN)) as a bogus measurement.
        if (dTest.length > 0) {
            this.perInstancePredictionTimes.add(Math.round((double) time / dTest.length));
        }
        return prediction;
    }

    /** Returns the live (mutable) list of recorded fit durations in ms. */
    @Override
    public List<Long> getFitTimes() {
        return this.fitTimes;
    }

    /** Returns the live (mutable) list of recorded batch prediction durations in ms. */
    @Override
    public List<Long> getBatchPredictionTimesInMS() {
        return this.batchPredictTimes;
    }

    /** Returns the live (mutable) list of recorded per-instance prediction durations in ms. */
    @Override
    public List<Long> getInstancePredictionTimesInMS() {
        return this.perInstancePredictionTimes;
    }

    @Override
    public IComponentInstance getComponentInstance() {
        return this.ci;
    }

    /**
     * Sets the predicted induction time from its textual form. If the value is {@code null} or not
     * parseable as a double, a warning is logged and the previous value is kept.
     */
    @Override
    public void setPredictedInductionTime(String inductionTime) {
        Double parsed = this.parseTimeOrWarn(inductionTime, "induction");
        if (parsed != null) {
            this.predictedInductionTime = parsed;
        }
    }

    /**
     * Sets the predicted inference time from its textual form. If the value is {@code null} or not
     * parseable as a double, a warning is logged and the previous value is kept.
     */
    @Override
    public void setPredictedInferenceTime(String inferenceTime) {
        Double parsed = this.parseTimeOrWarn(inferenceTime, "inference");
        if (parsed != null) {
            this.predictedInferenceTime = parsed;
        }
    }

    /** Parses {@code raw} as a double; logs a warning and returns {@code null} when that fails. */
    private Double parseTimeOrWarn(String raw, String kind) {
        if (raw == null) {
            this.logger.warn("Could not parse double from provided {} time {}.", kind, raw);
            return null;
        }
        try {
            return Double.parseDouble(raw);
        } catch (NumberFormatException e) {
            this.logger.warn("Could not parse double from provided {} time {}.", kind, raw, e);
            return null;
        }
    }

    @Override
    public Double getPredictedInductionTime() {
        return this.predictedInductionTime;
    }

    @Override
    public Double getPredictedInferenceTime() {
        return this.predictedInferenceTime;
    }

    /** Sets the score; a {@code null} argument is ignored and the current score is kept. */
    @Override
    public void setScore(Double score) {
        if (score == null) {
            return;
        }
        this.score = score;
    }

    @Override
    public Double getScore() {
        return this.score;
    }

    @Override
    public ISupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>> getLearner() {
        return this.wrappedSLClassifier;
    }

    @Override
    public String toString() {
        return this.getClass().getName() + " -> " + this.wrappedSLClassifier.toString();
    }

    @Override
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Switches this wrapper's logger to {@code name}. If the wrapped learner is itself
     * {@link ILoggingCustomizable}, its logger is renamed to {@code name + ".bl"}.
     */
    @Override
    public void setLoggerName(String name) {
        this.logger = LoggerFactory.getLogger(name);
        if (this.wrappedSLClassifier instanceof ILoggingCustomizable) {
            ((ILoggingCustomizable) this.wrappedSLClassifier).setLoggerName(name + ".bl");
        } else {
            this.logger.info("Underlying learner {} is not {}, so not customizing its logger.", this.wrappedSLClassifier.getClass(), ILoggingCustomizable.class);
        }
    }

    /**
     * Stopwatch that reports the milliseconds elapsed since construction. It is based on the
     * monotonic {@link System#nanoTime()}, so wall-clock adjustments (NTP, manual changes) cannot
     * yield negative or inflated durations.
     */
    static final class TimeTracker {
        private final long startNanos = System.nanoTime();

        private TimeTracker() {
        }

        public long stop() {
            return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - this.startNanos);
        }
    }
}

