/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.evaluator.spark;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.feature.ColumnPruner;
import org.dmg.pmml.ResultFeature;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.OutputField;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.spark.ColumnExploder;
import org.jpmml.evaluator.spark.ColumnProducer;
import org.jpmml.evaluator.spark.OutputColumnProducer;
import org.jpmml.evaluator.spark.PMMLTransformer;
import org.jpmml.evaluator.spark.ProbabilityColumnProducer;
import org.jpmml.evaluator.spark.ScalaUtil;
import org.jpmml.evaluator.spark.TargetColumnProducer;

/**
 * Fluent builder that assembles a Spark {@link Transformer} around a JPMML {@link Evaluator}.
 *
 * <p>Each {@code with...} method registers one or more {@link ColumnProducer}s derived from the
 * evaluator's target and output fields. {@link #build()} wraps the registered producers in a
 * {@link PMMLTransformer}, optionally followed by an explode/prune stage.
 *
 * <p>The builder holds mutable state without any synchronization; confine each instance to a
 * single thread.
 */
public class TransformerBuilder {
    private final Evaluator evaluator;
    private final List<ColumnProducer<?>> columnProducers = new ArrayList<>();
    private boolean exploded = false;

    /**
     * Creates a builder for the given evaluator.
     *
     * @param evaluator the PMML evaluator whose fields determine the produced columns
     * @throws NullPointerException if {@code evaluator} is {@code null}
     */
    public TransformerBuilder(Evaluator evaluator) {
        this.evaluator = Objects.requireNonNull(evaluator, "evaluator");
    }

    /**
     * Registers one column per target field of the model.
     * A {@code null} column name is passed to each producer.
     *
     * @return this builder
     */
    public TransformerBuilder withTargetCols() {
        for (TargetField targetField : getEvaluator().getTargetFields()) {
            this.columnProducers.add(new TargetColumnProducer(targetField, null));
        }
        return this;
    }

    /**
     * Registers a column for the model's single target field under an explicit name.
     *
     * @param columnName name of the column to produce
     * @return this builder
     * @throws IllegalArgumentException if the model does not have exactly one target field
     */
    public TransformerBuilder withLabelCol(String columnName) {
        TargetField targetField = getSingleTargetField();
        this.columnProducers.add(new TargetColumnProducer(targetField, columnName));
        return this;
    }

    /**
     * Registers a probability column labelled by the model's own probability output values.
     *
     * @param columnName name of the column to produce
     * @return this builder
     * @throws IllegalArgumentException see {@link #withProbabilityCol(String, List)}
     */
    public TransformerBuilder withProbabilityCol(String columnName) {
        return this.withProbabilityCol(columnName, null);
    }

    /**
     * Registers a probability column for the model's single target field.
     *
     * <p>The candidate labels are the {@code value} attributes of all PMML output fields whose
     * result feature is {@code PROBABILITY}. Output fields without a {@code value} are ignored.
     * If {@code labels} is supplied, it replaces that list. It must have the same size and
     * contain every collected value; its order is the one passed on.
     *
     * @param columnName name of the column to produce
     * @param labels optional explicit label ordering; {@code null} to use the model's order
     * @return this builder
     * @throws IllegalArgumentException if the model does not have exactly one target field,
     *     exposes no probability output values, or {@code labels} does not match those values
     */
    public TransformerBuilder withProbabilityCol(String columnName, List<String> labels) {
        TargetField targetField = getSingleTargetField();

        List<String> values = new ArrayList<>();
        for (OutputField outputField : getEvaluator().getOutputFields()) {
            org.dmg.pmml.OutputField pmmlOutputField = outputField.getOutputField();
            if (pmmlOutputField.getResultFeature() != ResultFeature.PROBABILITY) {
                continue;
            }
            String value = pmmlOutputField.getValue();
            if (value != null) {
                values.add(value);
            }
        }

        if (values.isEmpty()) {
            throw new IllegalArgumentException(
                    "Expected at least one PROBABILITY output field with a value attribute");
        }
        if (labels != null && (labels.size() != values.size() || !labels.containsAll(values))) {
            throw new IllegalArgumentException(
                    "Labels " + labels + " do not match probability output values " + values);
        }

        // Copy caller-supplied labels so later mutation of their list cannot leak into the producer.
        List<String> probabilityLabels = (labels != null) ? new ArrayList<String>(labels) : values;
        this.columnProducers.add(new ProbabilityColumnProducer(targetField, columnName, probabilityLabels));
        return this;
    }

    /**
     * Registers one column per output field of the model.
     * A {@code null} column name is passed to each producer.
     *
     * @return this builder
     */
    public TransformerBuilder withOutputCols() {
        for (OutputField outputField : getEvaluator().getOutputFields()) {
            this.columnProducers.add(new OutputColumnProducer(outputField, null));
        }
        return this;
    }

    /**
     * Controls whether {@link #build()} appends the explode/prune stages.
     *
     * @param exploded {@code true} to wrap the transformer in a three-stage pipeline
     * @return this builder
     */
    public TransformerBuilder exploded(boolean exploded) {
        this.exploded = exploded;
        return this;
    }

    /**
     * Builds the transformer from the column producers registered so far.
     *
     * <p>When not exploded, the result is a {@link PMMLTransformer}. When exploded, the result is
     * a {@link PipelineModel} of three stages:
     * <ol>
     *   <li>the {@code PMMLTransformer};</li>
     *   <li>a {@link ColumnExploder} on its output column;</li>
     *   <li>a {@link ColumnPruner} that drops that output column.</li>
     * </ol>
     *
     * <p>The producer list is snapshotted, so further builder calls do not affect the returned
     * transformer.
     *
     * @return the assembled transformer
     */
    public Transformer build() {
        List<ColumnProducer<?>> producers = new ArrayList<>(this.columnProducers);
        PMMLTransformer pmmlTransformer = new PMMLTransformer(getEvaluator(), producers);
        if (!this.exploded) {
            return pmmlTransformer;
        }

        ColumnExploder columnExploder = new ColumnExploder(pmmlTransformer.getOutputCol());
        ColumnPruner columnPruner = new ColumnPruner(ScalaUtil.singletonSet(pmmlTransformer.getOutputCol()));
        // NOTE(review): the pipeline uid is null; confirm this is acceptable for persistence/logging.
        return new PipelineModel(null, new Transformer[]{pmmlTransformer, columnExploder, columnPruner});
    }

    /** Returns the sole target field of the model, rejecting models with zero or several targets. */
    private TargetField getSingleTargetField() {
        List<? extends TargetField> targetFields = getEvaluator().getTargetFields();
        if (targetFields.size() != 1) {
            throw new IllegalArgumentException(
                    "Expected exactly one target field, got " + targetFields.size());
        }
        return targetFields.get(0);
    }

    private Evaluator getEvaluator() {
        return this.evaluator;
    }
}

