/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.mojos.runtime;

import ai.h2o.mojos.runtime.MojoPipeline;
import ai.h2o.mojos.runtime.OriginalMatrix;
import ai.h2o.mojos.runtime.PipelineWiring;
import ai.h2o.mojos.runtime.a.a;
import ai.h2o.mojos.runtime.api.BasePipelineListener;
import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.c.b;
import ai.h2o.mojos.runtime.frame.MojoColumn;
import ai.h2o.mojos.runtime.frame.MojoColumnFactoryImpl;
import ai.h2o.mojos.runtime.frame.MojoFrame;
import ai.h2o.mojos.runtime.frame.MojoFrameBuilder;
import ai.h2o.mojos.runtime.frame.MojoFrameMeta;
import ai.h2o.mojos.runtime.frame.StringConverter;
import ai.h2o.mojos.runtime.frame.StringToDateConverter;
import ai.h2o.mojos.runtime.transforms.L;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import ai.h2o.mojos.runtime.transforms.MojoTransformExecPipeBuilder;
import ai.h2o.mojos.runtime.utils.DateParser;
import ai.h2o.mojos.runtime.utils.MojoDateTimeParserFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link MojoPipeline} implementation that runs a flattened list of transforms over a single
 * "global" frame holding every column known to the pipeline: inputs, outputs, interim columns
 * and, when SHAP is enabled, the contribution columns.
 *
 * <p>Internal buffers (wiring, frame metadata, SHAP column layout) are allocated lazily on first
 * use. The SHAP flags can only be changed before that allocation happens.
 */
public class MojoPipelineProtoImpl
extends MojoPipeline {
    private static final Logger log = LoggerFactory.getLogger(MojoPipelineProtoImpl.class);
    /** All columns of the global frame; SHAP contribution columns are appended here during buffer allocation. */
    private final List<MojoColumnMeta> globalColumns;
    private final MojoTransformExecPipeBuilder root;
    private BasePipelineListener listener = BasePipelineListener.NOOP;
    /** Date converters built from {@code pipelineMeta.datetimeStringFormats}, keyed the same way. */
    private final Map<String, StringConverter> dateTimeConverters = new HashMap<String, StringConverter>(0);
    private boolean shapEnabled = false;
    /** Columns exposed in the output frame when SHAP is enabled; filled by {@code AllocatedBuffers}. */
    private final Set<MojoColumnMeta> shapContribColumns = new LinkedHashSet<MojoColumnMeta>();
    private AllocatedBuffers allocatedBuffers;
    private boolean shapOriginal = false;

    /**
     * Creates the pipeline.
     *
     * @param globalColumns mutable list of all pipeline columns; SHAP columns may be appended to it
     * @param root pipeline builder supplying metadata, input/output indices and transforms
     */
    public MojoPipelineProtoImpl(List<MojoColumnMeta> globalColumns, MojoTransformExecPipeBuilder root) {
        super(root.pipelineMeta.uuid, root.pipelineMeta.creationTime, root.pipelineMeta.license);
        this.root = root;
        this.globalColumns = globalColumns;
        if (root.pipelineMeta.datetimeStringFormats != null) {
            for (Map.Entry<String, String> entry : root.pipelineMeta.datetimeStringFormats.entrySet()) {
                DateParser dateParser = new DateParser(MojoDateTimeParserFactory.forPattern(entry.getValue(), false));
                this.dateTimeConverters.put(entry.getKey(), new StringToDateConverter(dateParser));
            }
        }
    }

    /** Returns a frame builder for the given frame kind, wired with missing-value tokens and date converters. */
    protected MojoFrameBuilder getFrameBuilder(MojoColumn.Kind kind) {
        return new MojoFrameBuilder(this.getMeta(kind), Arrays.asList(this.root.pipelineMeta.missingValues), this.dateTimeConverters);
    }

    /**
     * Returns the metadata of the input ({@code Feature}) or output ({@code Output}) frame.
     *
     * @throws UnsupportedOperationException for any other kind
     */
    protected MojoFrameMeta getMeta(MojoColumn.Kind kind) {
        switch (kind) {
            case Feature: {
                return this.buffers().inputFrameMeta;
            }
            case Output: {
                return this.buffers().outputFrameMeta;
            }
        }
        throw new UnsupportedOperationException("Cannot generate meta for interim frame");
    }

    /**
     * Builds the global frame. Columns present in {@code inputFrame} or {@code outputFrame} are
     * reused directly; every other column is freshly allocated with {@code inputFrame}'s row count.
     */
    MojoFrame constructGlobalFrame(MojoFrame inputFrame, MojoFrame outputFrame) {
        // Allocate buffers first. With SHAP enabled, allocation appends contribution columns to
        // globalColumns, and the column list built below must line up with globalMeta.
        MojoFrameMeta globalMeta = this.buffers().globalMeta;
        MojoFrameMeta inputMeta = inputFrame.getMeta();
        MojoFrameMeta outputMeta = outputFrame.getMeta();
        int nrows = inputFrame.getNrows();
        MojoColumnFactoryImpl columnFactory = new MojoColumnFactoryImpl();
        List<MojoColumn> columns = new ArrayList<MojoColumn>(this.globalColumns.size());
        for (MojoColumnMeta columnMeta : this.globalColumns) {
            Integer inputIndex = inputMeta.indexOf(columnMeta);
            if (inputIndex != null) {
                columns.add(inputFrame.getColumn(inputIndex.intValue()));
                continue;
            }
            Integer outputIndex = outputMeta.indexOf(columnMeta);
            if (outputIndex != null) {
                columns.add(outputFrame.getColumn(outputIndex.intValue()));
                continue;
            }
            columns.add(columnFactory.create(columnMeta.getColumnType(), nrows));
        }
        return MojoFrameBuilder.fromColumns(globalMeta, columns.toArray(new MojoColumn[0]));
    }

    /**
     * Runs all transforms over a global frame that shares columns with {@code inputFrame} and
     * {@code outputFrame}. When SHAP is enabled, it then accumulates contributions into the
     * contribution columns.
     *
     * @return {@code outputFrame}, presumably populated through the shared column objects
     */
    public MojoFrame transform(MojoFrame inputFrame, MojoFrame outputFrame) {
        assert (outputFrame.getNcols() > 0);
        MojoFrame globalFrame = this.constructGlobalFrame(inputFrame, outputFrame);
        this.listener.onBatchStart(globalFrame, this.root.iindices);
        AllocatedBuffers buffers = this.buffers();
        for (MojoTransform step : buffers.wiring.transformsFlattened) {
            this.listener.onTransformHead(step);
            step.transform(globalFrame);
            this.listener.onTransformResult(step);
        }
        if (this.shapEnabled) {
            this.computeShapContributions(globalFrame, inputFrame.getNrows(), buffers);
        }
        this.listener.onBatchEnd();
        return outputFrame;
    }

    /** Runs every SHAP-capable transform and adds its scaled contributions into the global frame. */
    private void computeShapContributions(MojoFrame globalFrame, int nrows, AllocatedBuffers buffers) {
        for (MojoTransform step : buffers.wiring.shapTransforms) {
            L shapTransform = (L) step;
            int[] contribIndices = buffers.pcIndicesByTransform.get(step);
            assert (contribIndices != null);
            this.accumulateShap(globalFrame, nrows, step, shapTransform, contribIndices, buffers.blender);
        }
    }

    /**
     * Computes one transform's SHAP values row by row and adds them, scaled by the blender, into
     * the contribution columns at {@code contribIndices}. The values are laid out per output as
     * one entry per transform input followed by one bias entry.
     */
    private void accumulateShap(MojoFrame frame, int nrows, MojoTransform step, L shapTransform,
                                int[] contribIndices, Map<Integer, b> blender) {
        double[][] contribData = new double[contribIndices.length][];
        for (int k = 0; k < contribIndices.length; ++k) {
            contribData[k] = (double[]) frame.getColumnData(contribIndices[k]);
        }
        double[] inputRow = new double[step.iindices.length];
        double[][] shapOutput = new double[step.oindices.length][];
        for (int o = 0; o < shapOutput.length; ++o) {
            shapOutput[o] = new double[step.iindices.length + 1];
        }
        OriginalMatrix originalMatrix = shapTransform.a();
        if (this.shapOriginal && originalMatrix == null) {
            throw new UnsupportedOperationException("Missing original matrix - cannot compute original SHAP for " + step);
        }
        for (int row = 0; row < nrows; ++row) {
            readInputRow(frame, step.iindices, row, inputRow);
            // Pre-fill with NaN so that any entry the transform fails to write is detected below.
            for (double[] out : shapOutput) {
                Arrays.fill(out, Double.NaN);
            }
            shapTransform.a(inputRow, shapOutput);
            int contribPos = 0;
            for (int o = 0; o < shapOutput.length; ++o) {
                int outputCol = step.oindices[o];
                b scaler = blender.get(outputCol);
                if (scaler == null) {
                    throw new IllegalStateException(String.format("Error in blender - no scaler found for column %d('%s')", outputCol, this.globalColumns.get(outputCol).getColumnName()));
                }
                for (int c = 0; c < shapOutput[o].length; ++c) {
                    double value = shapOutput[o][c];
                    String contribName = frame.getColumnName(contribIndices[contribPos]);
                    if (Double.isNaN(value)) {
                        throw new IllegalStateException(String.format("Row %d: %s(%s) did not compute shapOutput[%d][%d] : `%s`", row, step.getId(), step.getClass().getName(), o, c, contribName));
                    }
                    double scaled = scaler.a(value, frame, row);
                    contribData[contribPos][row] += scaled;
                    if (this.shapOriginal) {
                        // Redistribute this contribution onto original-column contributions,
                        // using the weights from the transform's original matrix.
                        Map<String, Double> weights = originalMatrix.getRow(contribName);
                        for (Map.Entry<String, Double> weight : weights.entrySet()) {
                            int targetIndex = frame.getMeta().getColumnIndex(weight.getKey());
                            double[] target = (double[]) frame.getColumnData(targetIndex);
                            target[row] += scaled * weight.getValue();
                        }
                    }
                    ++contribPos;
                }
            }
        }
    }

    /**
     * Copies row {@code row} of the given columns into {@code dest} as doubles.
     *
     * @throws UnsupportedOperationException if a column is neither Float32 nor Float64
     */
    private static void readInputRow(MojoFrame frame, int[] columns, int row, double[] dest) {
        for (int k = 0; k < columns.length; ++k) {
            int col = columns[k];
            MojoColumn.Type type = frame.getColumnType(col);
            switch (type) {
                case Float32: {
                    dest[k] = ((float[]) frame.getColumnData(col))[row];
                    break;
                }
                case Float64: {
                    dest[k] = ((double[]) frame.getColumnData(col))[row];
                    break;
                }
                default: {
                    throw new UnsupportedOperationException(String.format("cannot do SHAP on %s:%s", frame.getColumnName(col), type));
                }
            }
        }
    }

    /**
     * Enables or disables SHAP contribution output.
     *
     * @throws IllegalStateException if the value would change after buffers were allocated
     */
    public void setShapPredictContrib(boolean enable) {
        if (enable == this.shapEnabled) {
            return;
        }
        if (this.allocatedBuffers != null) {
            throw new IllegalStateException("Cannot change SHAP flag after internal buffers have been allocated");
        }
        this.shapEnabled = enable;
    }

    /**
     * Enables SHAP contributions and selects whether they are mapped back onto the pipeline's
     * original input columns ({@code true}) or reported per SHAP transform input ({@code false}).
     *
     * @throws IllegalStateException if either flag would change after buffers were allocated
     */
    public void setShapPredictContribOriginal(boolean enable) {
        this.setShapPredictContrib(true);
        if (enable == this.shapOriginal) {
            return;
        }
        // Buffer layout depends on this flag, so it must be fixed before allocation.
        if (this.allocatedBuffers != null) {
            throw new IllegalStateException("Cannot change SHAP flag after internal buffers have been allocated");
        }
        // Store the requested value so that the flag can also be switched off again.
        this.shapOriginal = enable;
    }

    /**
     * Sets the listener notified by {@link #transform}.
     * NOTE(review): traceable transforms receive the listener only at buffer allocation time, so
     * calling this afterwards does not update them; confirm whether that is intended.
     */
    public void setListener(BasePipelineListener listener) {
        this.listener = listener;
    }

    /** Returns the internal buffers, allocating them on first call. */
    private AllocatedBuffers buffers() {
        if (this.allocatedBuffers == null) {
            log.trace("Allocating buffers");
            this.allocatedBuffers = new AllocatedBuffers();
        }
        return this.allocatedBuffers;
    }

    /** Frame metadata and SHAP layout derived from the pipeline and the SHAP flags at allocation time. */
    private class AllocatedBuffers {
        final MojoFrameMeta globalMeta;
        /** For each SHAP transform, the global indices of its contribution columns. */
        final Map<MojoTransform, int[]> pcIndicesByTransform = new LinkedHashMap<MojoTransform, int[]>();
        final PipelineWiring wiring;
        final MojoFrameMeta inputFrameMeta;
        final MojoFrameMeta outputFrameMeta;
        /** Scaler per output column index; {@code null} when SHAP is disabled. */
        final Map<Integer, b> blender;

        public AllocatedBuffers() {
            MojoPipelineProtoImpl outer = MojoPipelineProtoImpl.this;
            this.wiring = new PipelineWiring(outer.globalColumns, outer.root);
            if (!outer.shapEnabled) {
                this.globalMeta = new MojoFrameMeta(outer.globalColumns);
                this.outputFrameMeta = this.globalMeta.subFrame(outer.root.oindices);
                this.blender = null;
            } else {
                if (this.wiring.isTreeMetalearner()) {
                    throw new UnsupportedOperationException("Computing SHAP contributions is not supported for pipelines with tree metalearner");
                }
                Map<String, Integer> shapColumnsByName = new LinkedHashMap<String, Integer>();
                for (MojoTransform step : this.wiring.transformsFlattened) {
                    if (!(step instanceof L)) {
                        continue;
                    }
                    L shapTransform = (L) step;
                    @SuppressWarnings("unchecked")
                    Set<String> groupInputs = (Set<String>) this.wiring.getGroupInputColumns(shapTransform.getTransformationGroup(), step.iindices);
                    int[] contribIndices = this.buildShapColumns(shapColumnsByName, groupInputs, step.oindices.length);
                    this.pcIndicesByTransform.put(step, contribIndices);
                }
                this.wiring.reportPrematureTraversals();
                if (this.pcIndicesByTransform.size() > 1) {
                    this.blender = a.a(this.wiring, outer.root.oindices);
                } else {
                    // A single SHAP transform: one shared scaler for all of its output columns.
                    this.blender = new LinkedHashMap<Integer, b>();
                    b sharedScaler = new b();
                    for (MojoTransform step : this.pcIndicesByTransform.keySet()) {
                        for (int outputCol : step.oindices) {
                            this.blender.put(outputCol, sharedScaler);
                        }
                    }
                }
                if (outer.shapOriginal) {
                    Set<String> originalColumns = new LinkedHashSet<String>();
                    for (int inputIndex : outer.root.iindices) {
                        originalColumns.add(outer.globalColumns.get(inputIndex).getColumnName());
                    }
                    // Expose only the original-column contributions in the output frame.
                    outer.shapContribColumns.clear();
                    int outputCount = outer.root.pipelineMeta.probabilityComplementDetected ? 1 : outer.root.oindices.length;
                    int[] originalIndices = this.buildShapColumns(shapColumnsByName, originalColumns, outputCount);
                    log.trace("Original SHAP column indices are: {}", (Object) Arrays.toString(originalIndices));
                }
                this.outputFrameMeta = new MojoFrameMeta(new ArrayList<MojoColumnMeta>(outer.shapContribColumns));
                // Built last so that it includes every contribution column appended above.
                this.globalMeta = new MojoFrameMeta(outer.globalColumns);
            }
            if (this.outputFrameMeta.size() == 0) {
                throw new IllegalStateException("No columns in output frame");
            }
            this.inputFrameMeta = this.globalMeta.subFrame(outer.root.iindices);
            outer.root.pipelineMeta.consistencyChecks(this.globalMeta);
            for (MojoTransform step : this.wiring.transformsFlattened) {
                if (!(step instanceof ai.h2o.mojos.runtime.transforms.a)) {
                    continue;
                }
                log.trace("Steps are traceable in {}", (Object) step);
                ((ai.h2o.mojos.runtime.transforms.a) step).a(outer.listener);
            }
        }

        /**
         * Registers the contribution columns {@code contrib_<input>[.<label>]} and
         * {@code contrib_bias[.<label>]} for each of {@code ocnt} outputs. The label suffix is used
         * only when {@code ocnt > 1}.
         *
         * @return global column indices, grouped per output as the inputs followed by the bias
         */
        private int[] buildShapColumns(Map<String, Integer> shapColumnsByName, Set<String> inputColNames, int ocnt) {
            int[] indices = new int[ocnt * (inputColNames.size() + 1)];
            int pos = 0;
            for (int o = 0; o < ocnt; ++o) {
                String suffix = ocnt > 1 ? "." + MojoPipelineProtoImpl.this.root.pipelineMeta.outputClassLabels.get(o) : "";
                for (String inputColName : inputColNames) {
                    indices[pos++] = this.shapColumn(shapColumnsByName, "contrib_" + inputColName + suffix);
                }
                indices[pos++] = this.shapColumn(shapColumnsByName, "contrib_bias" + suffix);
            }
            return indices;
        }

        /**
         * Returns the global index of the Float64 column {@code name}, appending it to
         * {@code globalColumns} if it has not been created yet. The column is also recorded in
         * {@code shapContribColumns}.
         */
        private int shapColumn(Map<String, Integer> shapColumnsByName, String name) {
            List<MojoColumnMeta> columns = MojoPipelineProtoImpl.this.globalColumns;
            Integer index = shapColumnsByName.get(name);
            MojoColumnMeta columnMeta;
            if (index != null) {
                columnMeta = columns.get(index);
            } else {
                index = columns.size();
                columnMeta = MojoColumnMeta.create(name, MojoColumn.Type.Float64);
                shapColumnsByName.put(name, index);
                columns.add(columnMeta);
            }
            MojoPipelineProtoImpl.this.shapContribColumns.add(columnMeta);
            return index;
        }
    }
}

