/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.mojos.runtime;

import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.api.MojoTransformationGroup;
import ai.h2o.mojos.runtime.transforms.J;
import ai.h2o.mojos.runtime.transforms.M;
import ai.h2o.mojos.runtime.transforms.MojoTransform;
import ai.h2o.mojos.runtime.transforms.n;
import ai.h2o.mojos.runtime.utils.ArrayReaderUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Flattened view of a MOJO pipeline's transformations plus an index from global column index to
 * the transformation that writes that column.
 *
 * <p>Composite transformations (instances of {@code n}) are expanded recursively, in declaration
 * order, into {@link #transformsFlattened}. Every index listed in a transformation's
 * {@code oindices} is then mapped to that transformation. This allows walking the graph backwards
 * from any column towards the pipeline inputs (see {@link #search}, {@link #shapCapableOrigin} and
 * {@link #getGroupInputColumnIndex}).
 *
 * <p>NOTE(review): instances keep mutable state ({@link #shapTransforms} and a premature-stop
 * counter) without any synchronization.
 */
public class PipelineWiring {
    private static final Logger log = LoggerFactory.getLogger(PipelineWiring.class);

    /** Column metadata, indexed by global column index. */
    private final List<MojoColumnMeta> globalColumns;
    /** All non-composite transformations, in depth-first declaration order. */
    public final List<MojoTransform> transformsFlattened = new ArrayList<MojoTransform>();
    /** {@code producers[i]} is the transformation writing column {@code i}, or {@code null} if none does. */
    private final MojoTransform[] producers;
    /** Column names at which group traversal is told to stop (the {@code Pipeline.transformed} hint). */
    private final Set<String> transformedFeatures;
    /** Number of times {@link #getGroupInputColumnIndex} stopped early because of {@link #transformedFeatures}. */
    private int prematureTraversalStops = 0;
    /** Transformations (instances of {@code M}) still selected for SHAP computation. */
    public List<MojoTransform> shapTransforms = new ArrayList<MojoTransform>();

    /**
     * Excludes a transformation from SHAP computation.
     *
     * @param transform transformation to remove from {@link #shapTransforms}
     * @throws IllegalStateException if {@code transform} is not currently in {@link #shapTransforms}
     */
    public void noshap(MojoTransform transform) {
        log.trace("Do not compute SHAP on: {}", (Object) transform);
        if (!this.shapTransforms.remove(transform)) {
            throw new IllegalStateException("Failed to remove transformation from shapTransforms: " + transform);
        }
    }

    /**
     * Builds the wiring for a pipeline.
     *
     * @param globalColumns metadata of every column; its size bounds the valid column indices
     * @param root composite transformation whose children ({@code root.b}) are flattened; the
     *     optional {@code root.c.transformedFeatures} names supply the traversal-stop hint
     */
    public PipelineWiring(List<MojoColumnMeta> globalColumns, n root) {
        this.globalColumns = globalColumns;
        this.transformedFeatures = root.c.transformedFeatures == null
                ? Collections.<String>emptySet()
                : new LinkedHashSet<String>(root.c.transformedFeatures);
        this.addChildrenFlattened(root.b);
        this.producers = new MojoTransform[globalColumns.size()];
        for (MojoTransform transform : this.transformsFlattened) {
            for (int outputIndex : transform.oindices) {
                this.producers[outputIndex] = transform;
            }
            // Every M instance starts out as a SHAP candidate; noshap() can later exclude it.
            if (transform instanceof M) {
                this.shapTransforms.add(transform);
            }
        }
    }

    /** Recursively appends every non-composite transformation of {@code nestedTransforms}, depth first. */
    private void addChildrenFlattened(List<MojoTransform> nestedTransforms) {
        for (MojoTransform transform : nestedTransforms) {
            if (transform instanceof n) {
                this.addChildrenFlattened(((n) transform).b);
            } else {
                this.transformsFlattened.add(transform);
            }
        }
    }

    /**
     * Returns the transformation that writes the given column.
     *
     * @param columnIndex global column index
     * @return the producing transformation, or {@code null} if no flattened transformation outputs it
     */
    public MojoTransform getProducer(int columnIndex) {
        return this.producers[columnIndex];
    }

    /** Returns the column metadata list this wiring was built with (not copied). */
    public List<MojoColumnMeta> getColumns() {
        return this.globalColumns;
    }

    /**
     * Walks the graph backwards, breadth first, from the producers of {@code entryPoints}.
     *
     * <p>Each producer is visited at most once. If the visitor returns a non-null value, that value
     * is collected and the walk does not continue above that producer. If it returns {@code null}
     * or throws {@link IllegalArgumentException}, the producer's input columns are queued instead.
     * Columns without a producer are skipped.
     *
     * @param entryPoints column indices to start from
     * @param visitor callback applied to each reached producer
     * @param <R> result type
     * @return the non-null visitor results, in visiting order
     */
    public <R> List<R> search(Collection<Integer> entryPoints, Visitor<R> visitor) {
        final Set<MojoTransform> visited = new LinkedHashSet<MojoTransform>();
        final List<R> results = new ArrayList<R>();
        final LinkedList<Integer> queue = new LinkedList<Integer>(entryPoints);
        Integer columnIndex;
        while ((columnIndex = queue.pollFirst()) != null) {
            final MojoTransform producer = this.getProducer(columnIndex);
            if (producer == null || !visited.add(producer)) {
                continue;
            }
            try {
                final R result = visitor.visit(this, producer);
                if (result != null) {
                    results.add(result);
                    continue;
                }
            } catch (IllegalArgumentException e) {
                // Deliberately best-effort: a rejecting visitor means "not here", so keep searching upstream.
                log.trace("visitor rejected {}: {}", producer, e.getMessage());
            }
            queue.addAll(ArrayReaderUtils.fromArrayToList(producer.iindices));
        }
        return results;
    }

    /**
     * Follows a column upstream until it reaches a SHAP-capable transformation (an instance of {@code M}).
     *
     * <p>Through a softmax transformation ({@code J}), the output position of the current column
     * is mapped to the input at the same position. Any other transformation on the path must have
     * exactly one input.
     *
     * @param fromIndex global column index to start from
     * @return the first {@code M} instance found upstream
     * @throws IllegalArgumentException if the path reaches a column with no producer, a softmax
     *     does not output the current column, or a non-softmax transformation has more than one input
     */
    public MojoTransform shapCapableOrigin(int fromIndex) {
        int columnIndex = fromIndex;
        while (true) {
            final MojoTransform producer = this.getProducer(columnIndex);
            if (producer == null) {
                // Previously this surfaced as a NullPointerException in the log call below.
                throw new IllegalArgumentException("no ShapCapable model found upstream of column #" + fromIndex
                        + ": column #" + columnIndex + " has no producer");
            }
            log.trace("traversing through {}({}), seeking index {}", producer, producer.getTransformationGroup(), columnIndex);
            if (producer instanceof M) {
                return producer;
            }
            if (producer instanceof J) {
                int mappedInput = -1;
                for (int position = 0; position < producer.oindices.length; ++position) {
                    if (producer.oindices[position] == columnIndex) {
                        mappedInput = producer.iindices[position];
                    }
                }
                if (mappedInput < 0) {
                    throw new IllegalArgumentException("output index not found in softmax: " + columnIndex);
                }
                columnIndex = mappedInput;
                continue;
            }
            if (producer.iindices.length != 1) {
                throw new IllegalArgumentException("only 1:1 transform expected while traversing from blending operations up to the ShapCapable model, but found " + producer.getName());
            }
            columnIndex = producer.iindices[0];
        }
    }

    /** Resolves the group input column for {@code inputIndex}; see {@link #getGroupInputColumnIndex}. */
    private MojoColumnMeta getGroupInputColumn(MojoTransformationGroup tg, int inputIndex) {
        return this.globalColumns.get(this.getGroupInputColumnIndex(tg, inputIndex));
    }

    /**
     * Walks upstream from {@code inputIndex} while the producer belongs to the group {@code tg}
     * (compared by id), and returns the first column outside the group.
     *
     * <p>The walk stops early, on a column inside the group, when that column's name is listed in
     * the transformed-features hint. Each such stop is counted for {@link #reportPrematureTraversals()}.
     *
     * @param tg transformation group; if {@code null}, {@code inputIndex} is returned unchanged
     * @param inputIndex global column index to start from
     * @return the resolved global column index
     * @throws IllegalStateException if a producer inside the group does not have exactly one input
     *     and exactly one output column
     */
    public int getGroupInputColumnIndex(MojoTransformationGroup tg, int inputIndex) {
        int columnIndex = inputIndex;
        if (tg == null) {
            return columnIndex;
        }
        while (true) {
            final MojoTransform producer = this.getProducer(columnIndex);
            if (producer == null) {
                break;
            }
            final MojoTransformationGroup group = producer.getTransformationGroup();
            if (group == null || !group.getId().equals(tg.getId())) {
                break;
            }
            if (producer.iindices.length != 1) {
                throw new IllegalStateException(String.format("producer of column #%d has %d input columns; exactly 1 is required: %s", columnIndex, producer.iindices.length, producer));
            }
            if (producer.oindices.length != 1) {
                throw new IllegalStateException(String.format("producer of column #%d has %d output columns; exactly 1 is required: %s", columnIndex, producer.oindices.length, producer));
            }
            final String columnName = this.globalColumns.get(columnIndex).getColumnName();
            if (this.transformedFeatures.contains(columnName)) {
                log.debug("traversal stops on '{}' PRIOR reaching boundary of group '{}', due to a hint from `Pipeline.transformed`; is it constructed correctly?", (Object) columnName, (Object) tg);
                ++this.prematureTraversalStops;
                break;
            }
            columnIndex = producer.iindices[0];
        }
        return columnIndex;
    }

    /**
     * Resolves each index through {@link #getGroupInputColumnIndex} and collects the column names.
     *
     * @param tg transformation group, may be {@code null}
     * @param inputIndices global column indices to resolve
     * @return the distinct resolved column names, in first-seen order
     */
    public Set<String> getGroupInputColumns(MojoTransformationGroup tg, int[] inputIndices) {
        final Set<String> names = new LinkedHashSet<String>();
        for (int inputIndex : inputIndices) {
            names.add(this.getGroupInputColumn(tg, inputIndex).getColumnName());
        }
        return names;
    }

    /** Logs a single warning if any premature traversal stops have been counted so far. */
    public void reportPrematureTraversals() {
        if (this.prematureTraversalStops > 0) {
            log.warn("Premature traversal stops occurred {} times. See DEBUG log for full list. Columns in `Pipeline.transformed` might need review.", (Object) this.prematureTraversalStops);
        }
    }

    /**
     * Callback used by {@link PipelineWiring#search}.
     *
     * @param <R> result type; returning {@code null} means "keep searching upstream"
     */
    public interface Visitor<R> {
        /**
         * Inspects one producer reached by the search.
         *
         * @param wiring the wiring being searched
         * @param transform the producer being visited
         * @return a result to collect (stops the search above this producer), or {@code null}
         * @throws IllegalArgumentException to reject this producer; the search continues upstream
         */
        R visit(PipelineWiring wiring, MojoTransform transform) throws IllegalArgumentException;
    }
}

