/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.mojos.runtime;

import ai.h2o.mojos.runtime.api.MojoColumnMeta;
import ai.h2o.mojos.runtime.api.MojoTransformMeta;
import ai.h2o.mojos.runtime.api.MojoTransformationGroup;
import ai.h2o.mojos.runtime.transforms.H;
import ai.h2o.mojos.runtime.transforms.K;
import ai.h2o.mojos.runtime.transforms.MojoTransformBuilder;
import ai.h2o.mojos.runtime.transforms.l;
import ai.h2o.mojos.runtime.utils.ArrayReaderUtils;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Index over a pipeline's transform graph.
 *
 * <p>Flattens the nested transform tree (container transforms of type {@code l} are replaced by
 * their children) and records, for every global column index, the transform that writes that
 * column. Callers can then walk the graph backwards from any column towards its inputs.
 */
public class PipelineWiring {
    private static final Logger log = LoggerFactory.getLogger(PipelineWiring.class);

    /** All pipeline columns; every column index used by this class points into this list. */
    private final List<MojoColumnMeta> globalColumns;

    /**
     * Leaf transforms in depth-first order, with nested {@code l} containers replaced by their
     * children. Public and mutable only for backward compatibility; treat it as read-only.
     */
    public final List<MojoTransformMeta> transformsFlattened = new ArrayList<MojoTransformMeta>();

    /** {@code producers[i]} is the transform that outputs column {@code i}, or {@code null}. */
    private final MojoTransformMeta[] producers;

    /** Column names at which {@link #getGroupInputColumnIndex} stops walking upstream. */
    private final Set<String> transformedFeatures;

    /**
     * Builds the wiring for a pipeline.
     *
     * @param globalColumns all pipeline columns; its size bounds every column index
     * @param root top-level container transform. Its children ({@code root.a}) are flattened
     *     recursively. Its {@code root.b.transformedFeatures} (may be {@code null}) supplies the
     *     transformed-feature hints.
     * @throws ArrayIndexOutOfBoundsException if a transform declares an output index outside
     *     {@code globalColumns}
     */
    public PipelineWiring(List<MojoColumnMeta> globalColumns, l root) {
        this.globalColumns = globalColumns;
        this.transformedFeatures = root.b.transformedFeatures == null
                ? Collections.<String>emptySet()
                : new LinkedHashSet<String>(root.b.transformedFeatures);
        addChildrenFlattened(root.a);
        this.producers = new MojoTransformMeta[globalColumns.size()];
        for (MojoTransformMeta transform : transformsFlattened) {
            for (int outputIndex : transform.getOutputIndices()) {
                producers[outputIndex] = transform;
            }
        }
    }

    /**
     * Appends the leaf transforms of {@code nestedTransforms} to {@link #transformsFlattened},
     * descending into {@code l} containers in place so the original order is preserved.
     */
    private void addChildrenFlattened(List<MojoTransformMeta> nestedTransforms) {
        for (MojoTransformMeta transform : nestedTransforms) {
            MojoTransformBuilder builder = transform.getTransform();
            if (builder instanceof l) {
                addChildrenFlattened(((l) builder).a);
            } else {
                transformsFlattened.add(transform);
            }
        }
    }

    /**
     * Returns the transform that outputs the given column.
     *
     * @param columnIndex global column index
     * @return the producing transform, or {@code null} if no flattened transform outputs that column
     * @throws ArrayIndexOutOfBoundsException if {@code columnIndex} is outside the column list
     */
    public MojoTransformMeta getProducer(int columnIndex) {
        return this.producers[columnIndex];
    }

    /** Returns the global column list this wiring was built from (not a copy). */
    public List<MojoColumnMeta> getColumns() {
        return this.globalColumns;
    }

    /**
     * Breadth-first search upstream from the given columns.
     *
     * <p>Each producing transform is passed to {@code visitor} at most once. A non-null result is
     * collected and that path is not expanded further. A {@code null} result, or an
     * {@link IllegalArgumentException} thrown by the visitor, continues the search through the
     * transform's input columns. Columns without a producer, and {@code null} entry points, are
     * skipped.
     *
     * @param entryPoints global column indices to start from
     * @param visitor callback applied to each reached transform
     * @param <R> result type
     * @return the collected non-null results, in discovery order
     */
    public <R> List<R> search(Collection<Integer> entryPoints, Visitor<R> visitor) {
        Set<MojoTransformMeta> visited = new LinkedHashSet<MojoTransformMeta>();
        List<R> results = new ArrayList<R>();
        Deque<Integer> pending = new ArrayDeque<Integer>();
        for (Integer entryPoint : entryPoints) {
            // A null entry used to terminate the whole search early; skip it instead.
            if (entryPoint != null) {
                pending.addLast(entryPoint);
            }
        }
        while (!pending.isEmpty()) {
            int columnIndex = pending.pollFirst();
            MojoTransformMeta producer = getProducer(columnIndex);
            if (producer == null || !visited.add(producer)) {
                continue;
            }
            try {
                R result = visitor.visit(this, producer);
                if (result != null) {
                    results.add(result);
                    continue;
                }
            } catch (IllegalArgumentException rejected) {
                // Deliberate best-effort: a rejecting visitor means "not here, look further upstream".
                log.trace("visitor rejected {}; continuing with its inputs", producer, rejected);
            }
            for (int inputIndex : producer.getInputIndices()) {
                pending.addLast(inputIndex);
            }
        }
        return results;
    }

    /**
     * Walks upstream from {@code fromIndex} to the first producer whose transform is a {@code K}.
     * Per the error message below, that is the ShapCapable model.
     *
     * <p>Softmax transforms ({@code H}) are crossed by following the input paired with the current
     * output position. Every other transform on the way must have exactly one input index.
     *
     * @param fromIndex global index of the column to start from
     * @return metadata of the first {@code K} transform reached
     * @throws IllegalArgumentException in any of these cases: a column on the path has no producer;
     *     a softmax does not output the sought column; or another transform on the path does not
     *     have exactly one input
     */
    public MojoTransformMeta shapCapableOrigin(int fromIndex) {
        int columnIndex = fromIndex;
        while (true) {
            MojoTransformMeta producer = getProducer(columnIndex);
            if (producer == null) {
                // Dead end (presumably a pipeline input) without meeting a K. Report it like the
                // other dead ends, so search(...) visitors treat it as "not found here" instead
                // of aborting on a NullPointerException.
                throw new IllegalArgumentException("no producer found for column #" + columnIndex
                        + " while traversing from blending operations up to the ShapCapable model");
            }
            log.trace("traversing through {}, seeking index {}", producer, columnIndex);
            MojoTransformBuilder builder = producer.getTransform();
            if (builder instanceof H) {
                columnIndex = softmaxInputFor(builder, columnIndex);
                continue;
            }
            if (builder instanceof K) {
                return producer;
            }
            if (builder.iindices.length != 1) {
                throw new IllegalArgumentException("only 1:1 transform expected while traversing from blending operations up to the ShapCapable model, but found " + producer.getName());
            }
            columnIndex = builder.iindices[0];
        }
    }

    /**
     * Returns the input index paired with {@code outputIndex} in a softmax transform. If the output
     * occurs more than once, the last pairing wins.
     *
     * @throws IllegalArgumentException if the softmax does not output {@code outputIndex}
     */
    private static int softmaxInputFor(MojoTransformBuilder softmax, int outputIndex) {
        int inputIndex = -1;
        for (int i = 0; i < softmax.oindices.length; ++i) {
            if (softmax.oindices[i] == outputIndex) {
                inputIndex = softmax.iindices[i];
            }
        }
        if (inputIndex < 0) {
            throw new IllegalArgumentException("output index not found in softmax: " + outputIndex);
        }
        return inputIndex;
    }

    /** Returns the column that {@link #getGroupInputColumnIndex} resolves {@code inputIndex} to. */
    private MojoColumnMeta getGroupInputColumn(MojoTransformationGroup tg, int inputIndex) {
        return this.globalColumns.get(getGroupInputColumnIndex(tg, inputIndex));
    }

    /**
     * Resolves {@code inputIndex} to the column that enters group {@code tg} from outside.
     *
     * <p>Follows producers upstream while they belong to {@code tg} (matching {@code getId()}).
     * Each such producer must have exactly one input and one output. The walk also stops, with a
     * warning, at a column whose name is in the transformed-feature hints.
     *
     * @param tg group to walk out of, or {@code null} to return {@code inputIndex} unchanged
     * @param inputIndex global column index to start from
     * @return the first column on the path whose producer is absent or outside {@code tg}, or the
     *     hinted column where the walk stopped
     * @throws IllegalStateException if an in-group producer on the path is not 1:1
     */
    public int getGroupInputColumnIndex(MojoTransformationGroup tg, int inputIndex) {
        if (tg == null) {
            return inputIndex;
        }
        int columnIndex = inputIndex;
        MojoTransformMeta producer;
        while ((producer = getProducer(columnIndex)) != null && isInGroup(producer, tg)) {
            int[] inputs = producer.getInputIndices();
            if (inputs.length != 1) {
                throw new IllegalStateException(String.format("producer of column #%d has %d input columns; exactly 1 is required: %s", columnIndex, inputs.length, producer));
            }
            int outputCount = producer.getOutputIndices().length;
            if (outputCount != 1) {
                throw new IllegalStateException(String.format("producer of column #%d has %d output columns; exactly 1 is required: %s", columnIndex, outputCount, producer));
            }
            String columnName = this.globalColumns.get(columnIndex).getColumnName();
            if (this.transformedFeatures.contains(columnName)) {
                log.warn("traversal stops on '{}' PRIOR reaching boundary of group '{}', due to a hint from `Pipeline.transformed`; is it constructed correctly?", (Object) columnName, (Object) tg);
                break;
            }
            columnIndex = inputs[0];
        }
        return columnIndex;
    }

    /** Tells whether {@code transform} has a transformation group with the same id as {@code group}. */
    private static boolean isInGroup(MojoTransformMeta transform, MojoTransformationGroup group) {
        MojoTransformationGroup ownGroup = transform.getTransformationGroup();
        return ownGroup != null && ownGroup.getId().equals(group.getId());
    }

    /**
     * Maps each of {@code inputIndices} through {@link #getGroupInputColumnIndex} and collects the
     * resulting column names.
     *
     * @return distinct column names, in the order of first occurrence
     */
    public Set<String> getGroupInputColumns(MojoTransformationGroup tg, int[] inputIndices) {
        Set<String> names = new LinkedHashSet<String>();
        for (int inputIndex : inputIndices) {
            names.add(getGroupInputColumn(tg, inputIndex).getColumnName());
        }
        return names;
    }

    /**
     * Callback for {@link PipelineWiring#search}.
     *
     * @param <R> type of the results collected by the search
     */
    public static interface Visitor<R> {
        /**
         * Inspects one transform reached by the search.
         *
         * @param wiring the wiring being searched
         * @param transform a transform that produces a column on the current path
         * @return a result to collect, which stops expansion of this path; or {@code null} to
         *     continue through the transform's inputs
         * @throws IllegalArgumentException handled by the search exactly like returning {@code null}
         */
        public R visit(PipelineWiring wiring, MojoTransformMeta transform) throws IllegalArgumentException;
    }
}

