/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.mojos.runtime;

import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.PipelineWiring;
import ai.h2o.mojos.runtime.a.b;
import ai.h2o.mojos.runtime.api.BasePipelineListener;
import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.api.MojoTransformMeta;
import ai.h2o.mojos.runtime.frame.MojoColumn;
import ai.h2o.mojos.runtime.frame.MojoColumnFactoryImpl;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameBuilder;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.frame.StringConverter;
import ai.h2o.mojos.runtime.frame.StringToDateConverter;
import ai.h2o.mojos.runtime.transforms.K;
import ai.h2o.mojos.runtime.transforms.MojoTransformBuilder;
import ai.h2o.mojos.runtime.transforms.l;
import ai.h2o.mojos.runtime.utils.DateParser;
import ai.h2o.mojos.runtime.utils.MojoDateTimeParserFactory;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link MojoPipeline} implementation driven by a decoded pipeline prototype ({@code root}).
 *
 * <p>Every transform runs against a single "global" frame whose columns are described by
 * {@link #globalColumns}. The caller-visible input and output frame layouts are sub-frames of it
 * (or, with SHAP enabled, the set of contribution columns). Per-instance bookkeeping lives in
 * {@link AllocatedBuffers}, which is created lazily by {@link #buffers()}.
 *
 * <p>NOTE(review): {@link #buffers()} performs an unsynchronized lazy initialization, so sharing
 * one instance across threads before buffers exist is unsafe. Confirm how callers use this class.
 */
public class MojoPipelineProtoImpl
extends MojoPipeline {
    private static final Logger log = LoggerFactory.getLogger(MojoPipelineProtoImpl.class);
    /**
     * Column layout of the global frame. This is the caller's list (not a copy). When SHAP is
     * enabled, buffer allocation appends contribution columns to it.
     */
    private final List<MojoColumnMeta> globalColumns;
    /** Decoded pipeline prototype: metadata in {@code root.b}, frame wiring in {@code iindices}/{@code oindices}. */
    private final l root;
    private BasePipelineListener listener = new BasePipelineListener();
    /** String-to-date converters keyed by column name, built from {@code root.b.datetimeStringFormats}. */
    private final Map<String, StringConverter> dateTimeConverters = new HashMap<String, StringConverter>(0);
    private boolean shapEnabled = false;
    /** Contribution columns registered while allocating SHAP buffers, in insertion order. */
    private final Set<MojoColumnMeta> shapContribColumns = new LinkedHashSet<MojoColumnMeta>();
    /** Lazily created by {@link #buffers()}; {@code null} until first needed. */
    private AllocatedBuffers allocatedBuffers;

    /**
     * Creates a pipeline over the given global column layout.
     *
     * @param globalColumns column layout of the global frame. The list is retained, not copied,
     *     and is appended to if SHAP contributions are enabled, so it must be modifiable in that case
     * @param root decoded pipeline prototype supplying uuid, creation time, license, datetime
     *     formats, missing-value markers and input/output column indices
     */
    public MojoPipelineProtoImpl(List<MojoColumnMeta> globalColumns, l root) {
        super(root.b.uuid, root.b.creationTime, root.b.license);
        this.root = root;
        this.globalColumns = globalColumns;
        if (root.b.datetimeStringFormats != null) {
            for (Map.Entry<String, String> entry : root.b.datetimeStringFormats.entrySet()) {
                DateParser dateParser = new DateParser(MojoDateTimeParserFactory.forPattern(entry.getValue()));
                this.dateTimeConverters.put(entry.getKey(), new StringToDateConverter(dateParser));
            }
        }
    }

    /**
     * Returns a frame builder for the requested frame kind. The builder uses the pipeline's
     * missing-value markers and datetime converters.
     */
    @Override
    protected MojoFrameBuilder getFrameBuilder(MojoColumn.Kind kind) {
        return new MojoFrameBuilder(this.getMeta(kind), Arrays.asList(this.root.b.missingValues), this.dateTimeConverters);
    }

    /**
     * Returns the input ({@code Feature}) or output ({@code Output}) frame layout, allocating buffers
     * if needed.
     *
     * @throws UnsupportedOperationException for any other kind
     */
    @Override
    protected MojoFrameMeta getMeta(MojoColumn.Kind kind) {
        switch (kind) {
            case Feature: {
                return this.buffers().inputFrameMeta;
            }
            case Output: {
                return this.buffers().outputFrameMeta;
            }
        }
        throw new UnsupportedOperationException("Cannot generate meta for interim frame");
    }

    /**
     * Assembles the global frame for one batch. Each global column is taken from
     * {@code inputFrame} if present there, otherwise from {@code outputFrame}. If it is in neither,
     * a fresh column with {@code inputFrame.getNrows()} rows is created.
     */
    MojoFrame constructGlobalFrame(MojoFrame inputFrame, MojoFrame outputFrame) {
        // Allocate buffers BEFORE snapshotting globalColumns. With SHAP enabled, allocation appends
        // contribution columns to globalColumns, and the snapshot must match buffers.globalMeta.
        AllocatedBuffers buffers = this.buffers();
        MojoFrameMeta inputMeta = inputFrame.getMeta();
        MojoFrameMeta outputMeta = outputFrame.getMeta();
        List<MojoColumnMeta> columnMetas = new ArrayList<MojoColumnMeta>(this.globalColumns);
        int nrows = inputFrame.getNrows();
        MojoColumnFactoryImpl columnFactory = new MojoColumnFactoryImpl();
        List<MojoColumn> columns = new ArrayList<MojoColumn>(columnMetas.size());
        for (MojoColumnMeta columnMeta : columnMetas) {
            Integer inputIndex = inputMeta.indexOf(columnMeta);
            if (inputIndex != null) {
                columns.add(inputFrame.getColumn(inputIndex));
                continue;
            }
            Integer outputIndex = outputMeta.indexOf(columnMeta);
            if (outputIndex != null) {
                columns.add(outputFrame.getColumn(outputIndex));
                continue;
            }
            // Interim column: neither supplied by the caller nor part of the output.
            columns.add(columnFactory.create(columnMeta.getColumnType(), nrows));
        }
        return MojoFrameBuilder.fromColumns(buffers.globalMeta, columns.toArray(new MojoColumn[0]));
    }

    /**
     * Runs every flattened transform over a global frame built from {@code inputFrame} and
     * {@code outputFrame}, then returns {@code outputFrame}.
     *
     * <p>The global frame reuses the column objects obtained from both frames, so results written
     * by transforms are expected to land in {@code outputFrame}. When SHAP is enabled, each
     * {@link K} transform additionally contributes to the contribution columns.
     */
    @Override
    public MojoFrame transform(MojoFrame inputFrame, MojoFrame outputFrame) {
        assert (outputFrame.getNcols() > 0);
        MojoFrame globalFrame = this.constructGlobalFrame(inputFrame, outputFrame);
        this.listener.onBatchStart(globalFrame);
        AllocatedBuffers buffers = this.buffers();
        for (MojoTransformMeta transformMeta : buffers.wiring.transformsFlattened) {
            MojoTransformBuilder transform = transformMeta.getTransform();
            transform.transform(globalFrame);
            this.listener.onBatchTransform(transformMeta);
            if (this.shapEnabled && transform instanceof K) {
                this.accumulateShapContributions(buffers, transformMeta, transform, globalFrame, inputFrame.getNrows());
            }
        }
        this.listener.onBatchEnd();
        return outputFrame;
    }

    /**
     * Adds one SHAP-capable transform's per-row contributions into its contribution columns.
     * Each contribution is scaled by the blender weight of the corresponding output.
     *
     * <p>Values are added with {@code +=}. NOTE(review): this assumes the contribution columns
     * start zeroed for each batch.
     *
     * @param buffers allocated buffers holding contribution-column indices and blender weights
     * @param transformMeta the transform that has just run
     * @param transform that transform's builder; must be a {@link K}
     * @param globalFrame the global frame the transform operated on
     * @param nrows number of rows to process
     * @throws UnsupportedOperationException if an input column is neither Float32 nor Float64
     * @throws IllegalStateException if the transform leaves any contribution as NaN
     */
    private void accumulateShapContributions(AllocatedBuffers buffers, MojoTransformMeta transformMeta, MojoTransformBuilder transform, MojoFrame globalFrame, int nrows) {
        K shapModel = (K) transform;
        int[] contribIndices = buffers.pcIndicesByTransform.get(transformMeta);
        assert (contribIndices != null);
        double[][] contribColumns = new double[contribIndices.length][];
        for (int c = 0; c < contribColumns.length; ++c) {
            contribColumns[c] = (double[]) globalFrame.getColumnData(contribIndices[c]);
        }
        List<Double> weights = buffers.blender.get(transformMeta);
        assert (weights != null) : "no weights found for " + transformMeta;
        assert (weights.size() == transform.oindices.length) : String.format("scales mismatch on %s(%s): %d scales != %d oindices", transform.getId(), transform.getName(), weights.size(), transform.oindices.length);
        if (nrows <= 0) {
            return;
        }
        int nInputs = transform.iindices.length;
        // Resolve input column arrays once per batch instead of once per row.
        float[][] float32Inputs = new float[nInputs][];
        double[][] float64Inputs = new double[nInputs][];
        for (int j = 0; j < nInputs; ++j) {
            int col = transform.iindices[j];
            MojoColumn.Type type = globalFrame.getColumnType(col);
            switch (type) {
                case Float32: {
                    float32Inputs[j] = (float[]) globalFrame.getColumnData(col);
                    break;
                }
                case Float64: {
                    float64Inputs[j] = (double[]) globalFrame.getColumnData(col);
                    break;
                }
                default: {
                    throw new UnsupportedOperationException(String.format("cannot do SHAP on %s:%s", new Object[]{globalFrame.getColumnName(col), type}));
                }
            }
        }
        double[] features = new double[nInputs];
        // One row per output: a contribution per input plus a trailing bias term.
        double[][] shapOutput = new double[transform.oindices.length][];
        for (int o = 0; o < shapOutput.length; ++o) {
            shapOutput[o] = new double[nInputs + 1];
        }
        for (int row = 0; row < nrows; ++row) {
            for (int j = 0; j < nInputs; ++j) {
                features[j] = float32Inputs[j] != null ? float32Inputs[j][row] : float64Inputs[j][row];
            }
            // Pre-fill with NaN so that entries the model fails to write are detected below.
            for (double[] outputRow : shapOutput) {
                Arrays.fill(outputRow, Double.NaN);
            }
            shapModel.a(features, shapOutput);
            int slot = 0;
            for (int o = 0; o < shapOutput.length; ++o) {
                double weight = weights.get(o);
                for (int i = 0; i < shapOutput[o].length; ++i) {
                    double contribution = shapOutput[o][i];
                    if (Double.isNaN(contribution)) {
                        String columnName = globalFrame.getColumnName(contribIndices[slot]);
                        throw new IllegalStateException(String.format("Row %d: %s(%s) did not compute shapOutput[%d][%d] : `%s`", row, transform.getId(), transform.getClass().getName(), o, i, columnName));
                    }
                    contribColumns[slot][row] += contribution * weight;
                    ++slot;
                }
            }
        }
    }

    /**
     * Enables or disables SHAP contribution output. Only possible before buffers are allocated.
     *
     * @throws IllegalStateException if the flag changes after buffers have been allocated
     */
    @Override
    public void setShapPredictContrib(boolean enable) {
        if (enable == this.shapEnabled) {
            return;
        }
        if (this.allocatedBuffers != null) {
            throw new IllegalStateException("Cannot change SHAP flag after internal buffers have been allocated");
        }
        this.shapEnabled = enable;
    }

    /** Replaces the listener notified at batch start, after each transform, and at batch end. */
    @Override
    public void setListener(BasePipelineListener listener) {
        this.listener = listener;
    }

    /** Returns the per-instance buffers, creating them on first call (not synchronized). */
    private AllocatedBuffers buffers() {
        if (this.allocatedBuffers == null) {
            log.trace("Allocating buffers");
            this.allocatedBuffers = new AllocatedBuffers();
        }
        return this.allocatedBuffers;
    }

    /**
     * Frame layouts, wiring and (when SHAP is enabled) contribution bookkeeping derived from the
     * enclosing pipeline's configuration at allocation time.
     */
    private class AllocatedBuffers {
        final MojoFrameMeta globalMeta;
        /** Global-frame indices of each SHAP transform's contribution columns, output-major. */
        final Map<MojoTransformMeta, int[]> pcIndicesByTransform = new LinkedHashMap<MojoTransformMeta, int[]>();
        final PipelineWiring wiring;
        final MojoFrameMeta inputFrameMeta;
        final MojoFrameMeta outputFrameMeta;
        /** Per-transform, per-output contribution weights; {@code null} when SHAP is disabled. */
        final Map<MojoTransformMeta, List<Double>> blender;

        public AllocatedBuffers() {
            MojoPipelineProtoImpl outer = MojoPipelineProtoImpl.this;
            this.wiring = new PipelineWiring(outer.globalColumns, outer.root);
            if (!outer.shapEnabled) {
                this.globalMeta = new MojoFrameMeta(outer.globalColumns);
                this.outputFrameMeta = this.globalMeta.subFrame(outer.root.oindices);
                this.blender = null;
            } else {
                Map<String, Integer> shapColumnsByName = new LinkedHashMap<String, Integer>();
                for (MojoTransformMeta transformMeta : this.wiring.transformsFlattened) {
                    if (!(transformMeta.getTransform() instanceof K)) {
                        continue;
                    }
                    this.pcIndicesByTransform.put(transformMeta, this.allocateContribColumns(shapColumnsByName, transformMeta));
                }
                if (this.pcIndicesByTransform.size() > 1) {
                    // Several SHAP-capable transforms: derive per-output weights from the wiring.
                    b blenderBuilder = new b(this.wiring);
                    this.blender = blenderBuilder.a(outer.root.oindices);
                } else {
                    this.blender = this.unitWeights(this.pcIndicesByTransform.keySet());
                }
                assert (this.pcIndicesByTransform.size() == this.blender.size()) : String.format("Mismatch between number of models (%d) and blender branches (%d)", this.pcIndicesByTransform.size(), this.blender.size());
                // Output consists solely of the contribution columns registered above.
                this.outputFrameMeta = new MojoFrameMeta(new ArrayList<MojoColumnMeta>(outer.shapContribColumns));
                this.globalMeta = new MojoFrameMeta(outer.globalColumns);
            }
            if (this.outputFrameMeta.size() == 0) {
                throw new IllegalStateException("No columns in output frame");
            }
            this.inputFrameMeta = this.globalMeta.subFrame(outer.root.iindices);
            outer.root.b.consistencyChecks(this.globalMeta);
        }

        /**
         * Registers the contribution columns of one SHAP-capable transform and returns their
         * global-frame indices. For each output, the layout is one column per group input
         * ({@code contrib_<input>}) followed by a bias column ({@code contrib_bias}). With more
         * than one output, each name gets a {@code "." + outputClassLabels.get(output)} suffix.
         */
        private int[] allocateContribColumns(Map<String, Integer> shapColumnsByName, MojoTransformMeta transformMeta) {
            Collection<?> groupInputs = this.wiring.getGroupInputColumns(transformMeta.getTransformationGroup(), transformMeta.getInputIndices());
            int nOutputs = transformMeta.getOutputIndices().length;
            int[] indices = new int[nOutputs * (groupInputs.size() + 1)];
            int next = 0;
            for (int o = 0; o < nOutputs; ++o) {
                String suffix = nOutputs > 1 ? "." + MojoPipelineProtoImpl.this.root.b.outputClassLabels.get(o) : "";
                for (Object input : groupInputs) {
                    indices[next++] = this.shapColumn(shapColumnsByName, "contrib_" + (String) input + suffix);
                }
                indices[next++] = this.shapColumn(shapColumnsByName, "contrib_bias" + suffix);
            }
            return indices;
        }

        /** Builds a blender that weights every output of every given transform by {@code 1.0}. */
        private Map<MojoTransformMeta, List<Double>> unitWeights(Set<MojoTransformMeta> transforms) {
            Map<MojoTransformMeta, List<Double>> weights = new LinkedHashMap<MojoTransformMeta, List<Double>>();
            for (MojoTransformMeta transformMeta : transforms) {
                int nOutputs = transformMeta.getOutputIndices().length;
                List<Double> ones = new ArrayList<Double>(nOutputs);
                for (int o = 0; o < nOutputs; ++o) {
                    ones.add(1.0);
                }
                weights.put(transformMeta, ones);
            }
            return weights;
        }

        /**
         * Returns the global index of the Float64 contribution column {@code name}. On first
         * request, the column is created and appended to both {@code globalColumns} and
         * {@code shapContribColumns}.
         */
        private int shapColumn(Map<String, Integer> shapColumnsByName, String name) {
            Integer existing = shapColumnsByName.get(name);
            if (existing != null) {
                return existing;
            }
            int index = MojoPipelineProtoImpl.this.globalColumns.size();
            MojoColumnMeta columnMeta = MojoColumnMeta.create(name, MojoColumn.Type.Float64);
            shapColumnsByName.put(name, index);
            MojoPipelineProtoImpl.this.globalColumns.add(columnMeta);
            MojoPipelineProtoImpl.this.shapContribColumns.add(columnMeta);
            return index;
        }
    }
}

