package ai.h2o.mojos.runtime.h2o3;

import ai.h2o.mojos.runtime.AbstractPipelineLoader;
import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.MojoPipelineMeta;
import ai.h2o.mojos.runtime.MojoPipelineProtoImpl;
import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.api.MojoTransformMeta;
import ai.h2o.mojos.runtime.api.backend.ReaderBackend;
import ai.h2o.mojos.runtime.frame.MojoColumn;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import ai.h2o.mojos.runtime.transforms.MojoTransformBuilder;
import ai.h2o.mojos.runtime.utils.ArrayReaderUtils;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.MojoReaderBackend;
import hex.genmodel.easy.EasyPredictModelWrapper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.joda.time.DateTime;

import static ai.h2o.mojos.runtime.h2o3.H2O3PipelineLoader.wrapModelForPrediction;

/**
 * Pipeline loader for H2O-3 k-LIME MOJO models.
 *
 * <p>The loaded model is exposed as a pipeline with exactly one transformation ({@link KlimeTransform}).
 * That transformation maps the model's input columns to its output columns: the prediction, the cluster and
 * one reason-code column per predictor.
 */
class KLimePipelineLoader extends AbstractPipelineLoader {
    private final MojoPipelineMeta pipelineMeta;

    /**
     * Reads the k-LIME model from {@code backend} and builds the pipeline metadata for it.
     *
     * @param backend source of the MOJO content
     * @throws IOException propagated from loading the H2O-3 model
     */
    public KLimePipelineLoader(ReaderBackend backend) throws IOException {
        super(backend);
        final MojoReaderBackend mojoReader = new H2O3BackendAdapter(backend);
        final MojoModel model = MojoModel.load(mojoReader);
        final EasyPredictModelWrapper easyPredictModelWrapper = wrapModelForPrediction(model);

        final String name = "klime:" + model.getModelCategory().toString();
        // Input columns are registered first and output columns second.
        // Both index arrays point into this shared list.
        final List<MojoColumnMeta> columns = new ArrayList<>();
        final int[] inputIndices = readInputIndices(columns, model);
        final int[] outputIndices = readOutputIndices(columns, model);
        final MojoFrameMeta meta = new MojoFrameMeta(columns);

        final String desc = String.format("algo: %s, balanceClasses: %s, attributes: %s, default threshold: %s",
            model._algoName,
            model._balanceClasses,
            model._modelAttributes,
            model._defaultThreshold);
        final MojoTransformMeta mtm = new MojoTransformMeta(name, desc, inputIndices, outputIndices, 0, null);
        // One transform instance is shared.
        // The builder returns the same object that is also attached directly below.
        final MojoTransform transform = new KlimeTransform(easyPredictModelWrapper, inputIndices, outputIndices);
        mtm.setTransformBuilder(new MojoTransformBuilder(meta, inputIndices, outputIndices) {
            @Override
            public MojoTransform build() {
                return transform;
            }
        });
        mtm.setTransform(transform);
        // TODO: placeholder creation time; replace with the model's real creation time once it is available.
        final DateTime creationTime = new DateTime(1970, 1, 1, 0, 0);
        pipelineMeta = new MojoPipelineMeta(Collections.singletonList(mtm), meta, true,
            ArrayReaderUtils.fromArrayToList(inputIndices),
            ArrayReaderUtils.fromArrayToList(outputIndices),
            model.getUUID(), creationTime);
        pipelineMeta.license = "H2O-3 Opensource";
    }

    @Override
    public List<MojoColumnMeta> getColumns() {
        return pipelineMeta.getColumns();
    }

    @Override
    public List<MojoTransformMeta> getTransformations() {
        return pipelineMeta.transforms;
    }

    @Override
    protected final MojoPipeline internalLoad() {
        return new MojoPipelineProtoImpl(pipelineMeta);
    }

    /**
     * Appends one input column per model feature to {@code columns}.
     *
     * <p>A feature with domain values is registered as {@code Str}; any other feature is registered as
     * {@code Float64}.
     *
     * @param columns  column list to append to (mutated)
     * @param genModel model whose first {@code getNumCols()} names describe the input features
     * @return for each feature, the position at which its column was appended in {@code columns}
     */
    static int[] readInputIndices(final List<MojoColumnMeta> columns, final GenModel genModel) {
        final int numCols = genModel.getNumCols();
        final String[] names = genModel.getNames();
        final int[] inputIndices = new int[numCols];
        for (int i = 0; i < numCols; i += 1) {
            final MojoColumn.Type columnType =
                (genModel.getDomainValues(i) == null) ? MojoColumn.Type.Float64 : MojoColumn.Type.Str;
            inputIndices[i] = columns.size();
            columns.add(MojoColumnMeta.newInput(names[i], columnType));
        }
        return inputIndices;
    }

    /**
     * Appends the k-LIME output columns to {@code columns}. All of them are {@code Float64}.
     *
     * <p>The columns are appended in this order:
     * <ol>
     *   <li>the prediction, named after the model's response column;</li>
     *   <li>{@code cluster};</li>
     *   <li>one {@code rc_<predictor>} reason-code column for each remaining prediction slot.</li>
     * </ol>
     *
     * @param columns  column list to append to (mutated)
     * @param genModel model describing the outputs
     * @return for each prediction slot, the position at which its column was appended in {@code columns}
     * @throws IllegalStateException if the model reports fewer than two prediction slots, or more reason codes
     *                               than there are predictor names
     */
    private static int[] readOutputIndices(final List<MojoColumnMeta> columns, final GenModel genModel) {
        // TODO: unclear why predictor names are derived this way; confirm with MM and/or Navdeep.
        // Predictors are all model column names except the response column.
        final String responseName = genModel.getResponseName();
        final List<String> predictorList = new ArrayList<>(Arrays.asList(genModel.getNames()));
        predictorList.remove(responseName);
        final String[] predictors = predictorList.toArray(new String[0]);

        final int predsSize = genModel.getPredsSize();
        // The first two slots are the prediction and the cluster; every further slot is a reason code.
        final int reasonCodeCount = predsSize - 2;
        if (reasonCodeCount < 0) {
            throw new IllegalStateException(String.format(
                "k-LIME model must produce at least 2 predictions (response and cluster), but reports %d",
                predsSize));
        }
        if (reasonCodeCount > predictors.length) {
            throw new IllegalStateException(String.format(
                "k-LIME model produces %d reason codes, but only %d predictor names are available",
                reasonCodeCount, predictors.length));
        }

        final int[] outputIndices = new int[predsSize];
        outputIndices[0] = columns.size();
        columns.add(MojoColumnMeta.newOutput(responseName, MojoColumn.Type.Float64));
        outputIndices[1] = columns.size();
        columns.add(MojoColumnMeta.newOutput("cluster", MojoColumn.Type.Float64));
        for (int i = 2; i < predsSize; i += 1) {
            outputIndices[i] = columns.size();
            final String outputColumnName = "rc_" + predictors[i - 2]; // "rc" stands for Reason Code
            columns.add(MojoColumnMeta.newOutput(outputColumnName, MojoColumn.Type.Float64));
        }
        return outputIndices;
    }

}
