/*
 * Decompiled with CFR 0.152.
 */
package ai.libs.jaicore.ml.regression.learner;

import ai.libs.jaicore.ml.core.learner.ASupervisedLearner;
import ai.libs.jaicore.ml.regression.singlelabel.SingleTargetRegressionPrediction;
import ai.libs.jaicore.ml.regression.singlelabel.SingleTargetRegressionPredictionBatch;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Objects;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.ai.ml.core.evaluation.IPrediction;
import org.api4.java.ai.ml.core.evaluation.IPredictionBatch;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.regression.evaluation.IRegressionPrediction;
import org.api4.java.ai.ml.regression.evaluation.IRegressionResultBatch;

/**
 * A trivial regression baseline that ignores the features of the instances it is asked about and
 * always predicts the same value: the arithmetic mean of the non-null target values seen in the
 * most recent successful call to {@link #fit}.
 */
public class ConstantRegressor
extends ASupervisedLearner<ILabeledInstance, ILabeledDataset<? extends ILabeledInstance>, IPrediction, IPredictionBatch> {

    /** Mean of the training targets; {@code null} until {@link #fit} has completed successfully. */
    private Double constantValue;

    /**
     * Learns the constant prediction as the mean of all non-null labels in the training set.
     * A failed call leaves any previously learned value untouched.
     *
     * @param dTrain the training data; non-null labels are expected to be {@link Number}s
     * @throws NullPointerException if {@code dTrain} is {@code null}
     * @throws IllegalArgumentException if {@code dTrain} is empty or contains no non-null label
     * @throws ClassCastException if a non-null label is not a {@link Number}
     */
    public void fit(ILabeledDataset<? extends ILabeledInstance> dTrain) throws TrainingException, InterruptedException {
        Objects.requireNonNull(dTrain, "dTrain");
        if (dTrain.isEmpty()) {
            throw new IllegalArgumentException("Cannot train constant regressor with empty training set.");
        }
        // Null labels are skipped before any unboxing happens; unboxing them first would throw an NPE.
        // Accepting any Number (not only Double) lets integer-valued targets be used as well.
        this.constantValue = dTrain.stream()
                .map(ILabeledInstance::getLabel)
                .filter(Objects::nonNull)
                .mapToDouble(label -> ((Number) label).doubleValue())
                .average()
                .orElseThrow(() -> new IllegalArgumentException(
                        "Cannot train constant regressor: training set contains no non-null labels."));
    }

    /**
     * Returns the learned constant, regardless of the given instance.
     *
     * @param xTest the instance to predict for (its features are ignored)
     * @return a prediction holding the mean target value learned in {@link #fit}
     * @throws IllegalStateException if the model has not been fitted yet
     */
    @Override
    public IRegressionPrediction predict(ILabeledInstance xTest) throws PredictionException, InterruptedException {
        if (this.constantValue == null) {
            throw new IllegalStateException("ConstantRegressor has not been fitted yet; call fit(...) before predict(...).");
        }
        return new SingleTargetRegressionPrediction(this.constantValue);
    }

    /**
     * Predicts every instance of the given array, preserving its order.
     *
     * @param dTest the instances to predict for
     * @return a batch with one prediction per element of {@code dTest}
     * @throws NullPointerException if {@code dTest} is {@code null}
     * @throws IllegalStateException if the model has not been fitted yet
     */
    @Override
    public IRegressionResultBatch predict(ILabeledInstance[] dTest) throws PredictionException, InterruptedException {
        Objects.requireNonNull(dTest, "dTest");
        Collection<IRegressionPrediction> predictions = new ArrayList<>(dTest.length);
        for (ILabeledInstance instance : dTest) {
            predictions.add(this.predict(instance));
        }
        return new SingleTargetRegressionPredictionBatch(predictions);
    }
}

