/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.mojos.runtime;

import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.PipelineWiring;
import ai.h2o.mojos.runtime.a.a;
import ai.h2o.mojos.runtime.api.BasePipelineListener;
import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.c.b;
import ai.h2o.mojos.runtime.frame.MojoColumn;
import ai.h2o.mojos.runtime.frame.MojoColumnFactoryImpl;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameBuilder;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.frame.StringConverter;
import ai.h2o.mojos.runtime.frame.StringToDateConverter;
import ai.h2o.mojos.runtime.transforms.M;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import ai.h2o.mojos.runtime.transforms.n;
import ai.h2o.mojos.runtime.utils.DateParser;
import ai.h2o.mojos.runtime.utils.MojoDateTimeParserFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MojoPipelineProtoImpl
extends MojoPipeline {
    private static final Logger log = LoggerFactory.getLogger(MojoPipelineProtoImpl.class);
    private final List<MojoColumnMeta> globalColumns;
    private final n root;
    private BasePipelineListener listener = BasePipelineListener.NOOP;
    private final Map<String, StringConverter> dateTimeConverters = new HashMap<String, StringConverter>(0);
    private boolean shapEnabled = false;
    private final Set<MojoColumnMeta> shapContribColumns = new LinkedHashSet<MojoColumnMeta>();
    private AllocatedBuffers allocatedBuffers;

    public MojoPipelineProtoImpl(List<MojoColumnMeta> globalColumns, n root) {
        super(root.c.uuid, root.c.creationTime, root.c.license);
        this.root = root;
        this.globalColumns = globalColumns;
        if (root.c.datetimeStringFormats != null) {
            for (Map.Entry<String, String> entry : root.c.datetimeStringFormats.entrySet()) {
                DateParser dateParser = new DateParser(MojoDateTimeParserFactory.forPattern(entry.getValue()));
                this.dateTimeConverters.put(entry.getKey(), new StringToDateConverter(dateParser));
            }
        }
    }

    @Override
    protected MojoFrameBuilder getFrameBuilder(MojoColumn.Kind kind) {
        return new MojoFrameBuilder(this.getMeta(kind), Arrays.asList(this.root.c.missingValues), this.dateTimeConverters);
    }

    @Override
    protected MojoFrameMeta getMeta(MojoColumn.Kind kind) {
        switch (kind) {
            case Feature: {
                return this.buffers().inputFrameMeta;
            }
            case Output: {
                return this.buffers().outputFrameMeta;
            }
        }
        throw new UnsupportedOperationException("Cannot generate meta for interim frame");
    }

    MojoFrame constructGlobalFrame(MojoFrame inputFrame, MojoFrame outputFrame) {
        ArrayList<MojoColumn> arrayList = new ArrayList<MojoColumn>();
        MojoFrameMeta mojoFrameMeta = inputFrame.getMeta();
        MojoFrameMeta mojoFrameMeta2 = outputFrame.getMeta();
        ArrayList<MojoColumnMeta> arrayList2 = new ArrayList<MojoColumnMeta>(this.globalColumns);
        int n2 = inputFrame.getNrows();
        MojoColumnFactoryImpl mojoColumnFactoryImpl = new MojoColumnFactoryImpl();
        for (MojoColumnMeta mojoColumnMeta : arrayList2) {
            Integer n3 = mojoFrameMeta.indexOf(mojoColumnMeta);
            if (n3 != null) {
                arrayList.add(inputFrame.getColumn(n3));
                continue;
            }
            Integer n4 = mojoFrameMeta2.indexOf(mojoColumnMeta);
            if (n4 != null) {
                arrayList.add(outputFrame.getColumn(n4));
                continue;
            }
            MojoColumn mojoColumn = mojoColumnFactoryImpl.create(mojoColumnMeta.getColumnType(), n2);
            arrayList.add(mojoColumn);
        }
        return MojoFrameBuilder.fromColumns(this.buffers().globalMeta, arrayList.toArray(new MojoColumn[0]));
    }

    @Override
    public MojoFrame transform(MojoFrame inputFrame, MojoFrame outputFrame) {
        assert (outputFrame.getNcols() > 0);
        MojoFrame mojoFrame = this.constructGlobalFrame(inputFrame, outputFrame);
        this.listener.onBatchStart(mojoFrame);
        AllocatedBuffers allocatedBuffers = this.buffers();
        for (MojoTransform mojoTransform : allocatedBuffers.wiring.transformsFlattened) {
            this.listener.onTransformHead(mojoTransform);
            mojoTransform.transform(mojoFrame);
            this.listener.onTransformResult(mojoTransform);
        }
        if (this.shapEnabled) {
            for (MojoTransform mojoTransform : allocatedBuffers.wiring.shapTransforms) {
                int n2;
                int n3;
                M m2 = (M)((Object)mojoTransform);
                int[] nArray = allocatedBuffers.pcIndicesByTransform.get(mojoTransform);
                assert (nArray != null);
                double[][] dArrayArray = new double[nArray.length][];
                for (n3 = 0; n3 < dArrayArray.length; ++n3) {
                    dArrayArray[n3] = (double[])mojoFrame.getColumnData(nArray[n3]);
                }
                n3 = inputFrame.getNrows();
                double[] dArray = new double[mojoTransform.iindices.length];
                double[][] dArrayArray2 = new double[mojoTransform.oindices.length][];
                for (n2 = 0; n2 < dArrayArray2.length; ++n2) {
                    dArrayArray2[n2] = new double[mojoTransform.iindices.length + 1];
                }
                for (n2 = 0; n2 < n3; ++n2) {
                    Object object;
                    int n4;
                    Object object2;
                    int n5;
                    int n6;
                    for (n6 = 0; n6 < mojoTransform.iindices.length; ++n6) {
                        double d2;
                        n5 = mojoTransform.iindices[n6];
                        object2 = mojoFrame.getColumnType(n5);
                        switch (1.$SwitchMap$ai$h2o$mojos$runtime$frame$MojoColumn$Type[((Enum)object2).ordinal()]) {
                            case 1: {
                                d2 = ((float[])mojoFrame.getColumnData(n5))[n2];
                                break;
                            }
                            case 2: {
                                d2 = ((double[])mojoFrame.getColumnData(n5))[n2];
                                break;
                            }
                            default: {
                                throw new UnsupportedOperationException(String.format("cannot do SHAP on %s:%s", mojoFrame.getColumnName(n5), object2));
                            }
                        }
                        dArray[n6] = d2;
                    }
                    double[][] dArrayArray3 = dArrayArray2;
                    n5 = dArrayArray2.length;
                    for (n4 = 0; n4 < n5; ++n4) {
                        object = dArrayArray3[n4];
                        Arrays.fill((double[])object, Double.NaN);
                    }
                    m2.a(dArray, dArrayArray2);
                    n6 = 0;
                    for (n5 = 0; n5 < dArrayArray2.length; ++n5) {
                        n4 = mojoTransform.oindices[n5];
                        object = allocatedBuffers.blender.get(n4);
                        object2 = this.globalColumns.get(n4);
                        log.trace("Scaler for column {}('{}') is {}", n4, object2, object);
                        if (object == null) {
                            throw new IllegalStateException(String.format("Error in blender - no scaler found for column %d('%s')", n4, ((MojoColumnMeta)object2).getColumnName()));
                        }
                        for (int i2 = 0; i2 < dArrayArray2[n5].length; ++i2) {
                            double d3 = dArrayArray2[n5][i2];
                            if (Double.isNaN(d3)) {
                                String string = mojoFrame.getColumnName(nArray[n6]);
                                throw new IllegalStateException(String.format("Row %d: %s(%s) did not compute shapOutput[%d][%d] : `%s`", n2, mojoTransform.getId(), mojoTransform.getClass().getName(), n5, i2, string));
                            }
                            double d4 = ((b)object).a(d3, mojoFrame, n2);
                            double[] dArray2 = dArrayArray[n6];
                            int n7 = n2;
                            dArray2[n7] = dArray2[n7] + d4;
                            ++n6;
                        }
                    }
                }
            }
        }
        this.listener.onBatchEnd();
        return outputFrame;
    }

    @Override
    public void setShapPredictContrib(boolean enable) {
        if (enable == this.shapEnabled) {
            return;
        }
        if (this.allocatedBuffers != null) {
            throw new IllegalStateException("Cannot change SHAP flag after internal buffers have been allocated");
        }
        this.shapEnabled = enable;
    }

    @Override
    public void setListener(BasePipelineListener listener) {
        this.listener = listener;
    }

    private AllocatedBuffers buffers() {
        if (this.allocatedBuffers == null) {
            log.trace("Allocating buffers");
            this.allocatedBuffers = new AllocatedBuffers();
        }
        return this.allocatedBuffers;
    }

    private class AllocatedBuffers {
        final MojoFrameMeta globalMeta;
        final Map<MojoTransform, int[]> pcIndicesByTransform = new LinkedHashMap<MojoTransform, int[]>();
        final PipelineWiring wiring;
        final MojoFrameMeta inputFrameMeta;
        final MojoFrameMeta outputFrameMeta;
        final Map<Integer, b> blender;

        public AllocatedBuffers() {
            this.wiring = new PipelineWiring(MojoPipelineProtoImpl.this.globalColumns, MojoPipelineProtoImpl.this.root);
            if (!MojoPipelineProtoImpl.this.shapEnabled) {
                this.globalMeta = new MojoFrameMeta(MojoPipelineProtoImpl.this.globalColumns);
                this.outputFrameMeta = this.globalMeta.subFrame(((MojoPipelineProtoImpl)MojoPipelineProtoImpl.this).root.oindices);
                this.blender = null;
            } else {
                int n2;
                int n3;
                LinkedHashMap linkedHashMap = new LinkedHashMap();
                for (MojoTransform mojoTransform : this.wiring.transformsFlattened) {
                    if (!(mojoTransform instanceof M)) continue;
                    Set<String> set = this.wiring.getGroupInputColumns(mojoTransform.getTransformationGroup(), mojoTransform.iindices);
                    int n4 = mojoTransform.oindices.length;
                    int[] nArray = new int[n4 * (set.size() + 1)];
                    n3 = 0;
                    for (n2 = 0; n2 < n4; ++n2) {
                        String string = n4 > 1 ? "." + ((MojoPipelineProtoImpl)MojoPipelineProtoImpl.this).root.c.outputClassLabels.get(n2) : "";
                        for (String string2 : set) {
                            nArray[n3] = this.shapColumn(linkedHashMap, "contrib_" + string2 + string);
                            ++n3;
                        }
                        nArray[n3] = this.shapColumn(linkedHashMap, "contrib_bias" + string);
                        ++n3;
                    }
                    this.pcIndicesByTransform.put(mojoTransform, nArray);
                }
                this.wiring.reportPrematureTraversals();
                if (this.pcIndicesByTransform.size() > 1) {
                    this.blender = a.a(this.wiring, ((MojoPipelineProtoImpl)MojoPipelineProtoImpl.this).root.oindices);
                } else {
                    this.blender = new LinkedHashMap<Integer, b>();
                    b b2 = new b();
                    for (MojoTransform mojoTransform : this.pcIndicesByTransform.keySet()) {
                        int[] nArray = mojoTransform.oindices;
                        int n5 = mojoTransform.oindices.length;
                        for (n3 = 0; n3 < n5; ++n3) {
                            n2 = nArray[n3];
                            this.blender.put(n2, b2);
                        }
                    }
                }
                this.outputFrameMeta = new MojoFrameMeta(new ArrayList<MojoColumnMeta>(MojoPipelineProtoImpl.this.shapContribColumns));
                this.globalMeta = new MojoFrameMeta(MojoPipelineProtoImpl.this.globalColumns);
            }
            if (this.outputFrameMeta.size() == 0) {
                throw new IllegalStateException("No columns in output frame");
            }
            this.inputFrameMeta = this.globalMeta.subFrame(((MojoPipelineProtoImpl)MojoPipelineProtoImpl.this).root.iindices);
            ((MojoPipelineProtoImpl)MojoPipelineProtoImpl.this).root.c.consistencyChecks(this.globalMeta);
            for (MojoTransform mojoTransform : this.wiring.transformsFlattened) {
                if (!(mojoTransform instanceof ai.h2o.mojos.runtime.transforms.a)) continue;
                log.trace("Steps are traceable in {}", (Object)mojoTransform);
                ai.h2o.mojos.runtime.transforms.a a2 = (ai.h2o.mojos.runtime.transforms.a)((Object)mojoTransform);
                a2.a(MojoPipelineProtoImpl.this.listener);
            }
        }

        private int shapColumn(Map<String, Integer> shapColumnsByName, String name) {
            Integer n2 = shapColumnsByName.get(name);
            if (n2 != null) {
                return n2;
            }
            n2 = MojoPipelineProtoImpl.this.globalColumns.size();
            MojoColumnMeta mojoColumnMeta = MojoColumnMeta.create(name, MojoColumn.Type.Float64);
            shapColumnsByName.put(name, n2);
            MojoPipelineProtoImpl.this.globalColumns.add(mojoColumnMeta);
            MojoPipelineProtoImpl.this.shapContribColumns.add(mojoColumnMeta);
            return n2;
        }
    }
}

