/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.converter.mining;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.True;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.regression.RegressionModelUtil;

/**
 * Static helpers for assembling PMML {@link MiningModel} instances out of member models.
 *
 * <p>The {@code create*} factory methods build a {@link Segmentation.MultipleModelMethod#MODEL_CHAIN}
 * whose final segment is a {@link RegressionModel}; that regression model reads the prediction of
 * the preceding segment(s) through each member model's last {@link OutputField}.
 */
public final class MiningModelUtil {

    /**
     * Exposes a model's prediction as a {@link ContinuousFeature} that refers to the model's last
     * declared output field (name and data type are copied from that field).
     *
     * <p>Throws {@link IllegalArgumentException} when the model has no output fields, because there
     * is then nothing for a downstream model to reference.
     */
    private static final Function<Model, Feature> MODEL_PREDICTION = new Function<Model, Feature>() {

        @Override
        public Feature apply(Model model) {
            Output output = model.getOutput();
            if (output == null || !output.hasOutputFields()) {
                throw new IllegalArgumentException("Model has no output fields, cannot reference its prediction");
            }

            OutputField outputField = Iterables.getLast(output.getOutputFields());

            // NOTE(review): the first constructor argument is null, exactly as before; presumably it is
            // an encoder reference that is not needed here — confirm against ContinuousFeature.
            return new ContinuousFeature(null, outputField.getName(), outputField.getDataType());
        }
    };

    private MiningModelUtil() {
    }

    /**
     * Chains {@code model} with a regression model that uses the model's last output field as its
     * only term, with coefficient {@code 1.0}.
     *
     * @param model the member model; must declare at least one output field
     * @param schema the schema; its label must be a {@link ContinuousLabel}
     * @return a model-chain {@link MiningModel} of {@code [model, regressionModel]}
     * @throws IllegalArgumentException if {@code model} has no output fields
     * @throws ClassCastException if the schema's label is not a {@link ContinuousLabel}
     */
    public static MiningModel createRegression(Model model, Schema schema) {
        ContinuousLabel continuousLabel = (ContinuousLabel)schema.getLabel();

        Feature feature = MODEL_PREDICTION.apply(model);

        // The intercept argument is passed as null, as in the original implementation.
        RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(feature), null, Collections.singletonList(1.0));

        RegressionModel regressionModel = new RegressionModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel), null)
            .addRegressionTables(regressionTable);

        return createModelChain(Arrays.asList(model, regressionModel), schema);
    }

    /**
     * Chains {@code model} with a binary logistic classification regression model whose only input
     * is the model's last output field.
     *
     * @param model the member model; must declare at least one output field
     * @param intercept passed through to {@link RegressionModelUtil}
     * @param coefficient the coefficient applied to the model's prediction
     * @param normalizationMethod passed through to {@link RegressionModelUtil}
     * @param hasProbabilityDistribution passed through to {@link RegressionModelUtil}
     * @param schema the schema; its label must be a {@link CategoricalLabel} with exactly two categories
     * @return a model-chain {@link MiningModel} of {@code [model, regressionModel]}
     * @throws IllegalArgumentException if the label does not have exactly two categories, or
     *     {@code model} has no output fields
     * @throws ClassCastException if the schema's label is not a {@link CategoricalLabel}
     */
    public static MiningModel createBinaryLogisticClassification(Model model, double intercept, double coefficient, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        if (categoricalLabel.size() != 2) {
            throw new IllegalArgumentException("Expected exactly 2 target categories, got " + categoricalLabel.size());
        }

        Feature feature = MODEL_PREDICTION.apply(model);

        RegressionModel regressionModel = RegressionModelUtil.createBinaryLogisticClassification(Collections.singletonList(feature), intercept, Collections.singletonList(coefficient), normalizationMethod, hasProbabilityDistribution, schema);

        return createModelChain(Arrays.asList(model, regressionModel), schema);
    }

    /**
     * Chains one member model per target category with a classification regression model. The
     * regression table for category {@code i} uses the last output field of {@code models.get(i)} as
     * its only term, with coefficient {@code 1.0}.
     *
     * @param models one model per target category, in the label's category order
     * @param normalizationMethod normalization method set on the final regression model
     * @param hasProbabilityDistribution whether to attach a probability output to the final regression model
     * @param schema the schema; its label must be a {@link CategoricalLabel} with at least three categories
     * @return a model-chain {@link MiningModel} of {@code [models..., regressionModel]}
     * @throws IllegalArgumentException if the label has fewer than three categories, if the number of
     *     models differs from the number of categories, or if any model has no output fields
     * @throws ClassCastException if the schema's label is not a {@link CategoricalLabel}
     */
    public static MiningModel createClassification(List<? extends Model> models, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        if (categoricalLabel.size() < 3 || categoricalLabel.size() != models.size()) {
            throw new IllegalArgumentException("Expected at least 3 target categories and one model per category, got " + categoricalLabel.size() + " categories and " + models.size() + " models");
        }

        List<RegressionTable> regressionTables = new ArrayList<RegressionTable>(categoricalLabel.size());
        for (int i = 0; i < categoricalLabel.size(); i++) {
            Feature feature = MODEL_PREDICTION.apply(models.get(i));

            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(feature), null, Collections.singletonList(1.0))
                .setTargetCategory(categoricalLabel.getValue(i));

            regressionTables.add(regressionTable);
        }

        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
            .setNormalizationMethod(normalizationMethod)
            .setOutput(hasProbabilityDistribution ? ModelUtil.createProbabilityOutput(categoricalLabel) : null);

        // A List<Model> (not List<? extends Model>) is required so that the regression model can be appended.
        List<Model> segmentationModels = new ArrayList<Model>(models.size() + 1);
        segmentationModels.addAll(models);
        segmentationModels.add(regressionModel);

        return createModelChain(segmentationModels, schema);
    }

    /**
     * Wraps {@code models} in a {@link Segmentation.MultipleModelMethod#MODEL_CHAIN} segmentation.
     * The resulting mining model takes its mining function from the last model in the list.
     *
     * @param models the chain members, in evaluation order; the last one determines the mining function
     * @param schema the schema used to build the mining model's mining schema
     * @return the model-chain {@link MiningModel}
     */
    public static MiningModel createModelChain(List<? extends Model> models, Schema schema) {
        Segmentation segmentation = createSegmentation(Segmentation.MultipleModelMethod.MODEL_CHAIN, models);

        // NOTE: an empty list makes Iterables.getLast throw (NoSuchElementException), as before.
        Model lastModel = Iterables.getLast(models);

        return new MiningModel(lastModel.getMiningFunction(), ModelUtil.createMiningSchema(schema))
            .setSegmentation(segmentation);
    }

    /**
     * Builds an unweighted segmentation; equivalent to
     * {@code createSegmentation(multipleModelMethod, models, null)}.
     *
     * @param multipleModelMethod the segmentation's multiple-model method
     * @param models the segment models, in order
     * @return the segmentation
     */
    public static Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, List<? extends Model> models) {
        return createSegmentation(multipleModelMethod, models, null);
    }

    /**
     * Builds a segmentation with one always-true segment per model. Segment ids are the 1-based
     * positions of the models, as strings.
     *
     * <p>A weight is written to a segment only when it is non-null and {@link ValueUtil#isOne}
     * reports it is not one.
     *
     * @param multipleModelMethod the segmentation's multiple-model method
     * @param models the segment models, in order
     * @param weights per-model weights, or {@code null} for none; must match {@code models} in size when given
     * @return the segmentation
     * @throws IllegalArgumentException if {@code weights} is non-null and its size differs from {@code models}
     */
    public static Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, List<? extends Model> models, List<? extends Number> weights) {
        if (weights != null && models.size() != weights.size()) {
            throw new IllegalArgumentException("Expected " + models.size() + " weights (one per model), got " + weights.size());
        }

        List<Segment> segments = new ArrayList<Segment>(models.size());
        for (int i = 0; i < models.size(); i++) {
            Model model = models.get(i);
            Number weight = (weights != null ? weights.get(i) : null);

            Segment segment = new Segment()
                .setId(String.valueOf(i + 1))
                .setPredicate(new True())
                .setModel(model);

            if (weight != null && !ValueUtil.isOne(weight)) {
                segment.setWeight(ValueUtil.asDouble(weight));
            }

            segments.add(segment);
        }

        return new Segmentation(multipleModelMethod, segments);
    }
}

