/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.exec.tez;

import io.prestosql.hive.$internal.org.slf4j.Logger;
import io.prestosql.hive.$internal.org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.hive.ql.exec.tez.CustomEdgeConfiguration;
import org.apache.hadoop.hive.ql.exec.tez.DataInputByteBuffer;
import org.apache.tez.dag.api.EdgeManagerPlugin;
import org.apache.tez.dag.api.EdgeManagerPluginContext;
import org.apache.tez.runtime.api.events.DataMovementEvent;
import org.apache.tez.runtime.api.events.InputReadErrorEvent;

/**
 * Edge manager plugin that routes events between the tasks of a source vertex
 * and the tasks of a destination vertex using a routing table.
 *
 * <p>The routing table and the bucket count are carried in a
 * {@link CustomEdgeConfiguration} that {@link #initialize()} deserializes from
 * the edge's user payload. Until {@code initialize()} has run, {@link #conf} is
 * {@code null} and every method that reads it will fail with a
 * {@link NullPointerException}.
 *
 * <p>NOTE(review): the public methods below are presumably the
 * {@link EdgeManagerPlugin} callbacks; the {@code @Override} annotations appear
 * to have been lost in decompilation - TODO restore them once verified against
 * the Tez API in use.
 */
public class CustomPartitionEdge
extends EdgeManagerPlugin {
    private static final Logger LOG = LoggerFactory.getLogger(CustomPartitionEdge.class);

    /** Deserialized edge configuration; {@code null} until {@link #initialize()} has run. */
    CustomEdgeConfiguration conf;

    /** Plugin context handed to the constructor; source of vertex sizes and the user payload. */
    final EdgeManagerPluginContext context;

    /**
     * Creates the edge manager. No configuration is read here; that happens in
     * {@link #initialize()}.
     *
     * @param context plugin context, stored and also passed to the superclass
     */
    public CustomPartitionEdge(EdgeManagerPluginContext context) {
        super(context);
        this.context = context;
    }

    /**
     * Returns the number of tasks in the source vertex.
     *
     * @param destinationTaskIndex not used by this implementation
     * @return the source vertex task count reported by the context
     */
    public int getNumDestinationTaskPhysicalInputs(int destinationTaskIndex) {
        return this.context.getSourceVertexNumTasks();
    }

    /**
     * Returns the configured number of buckets.
     *
     * @param sourceTaskIndex not used by this implementation
     * @return the bucket count from the edge configuration
     */
    public int getNumSourceTaskPhysicalOutputs(int sourceTaskIndex) {
        return this.conf.getNumBuckets();
    }

    /**
     * Returns the number of tasks in the destination vertex.
     *
     * @param sourceTaskIndex not used by this implementation
     * @return the destination vertex task count reported by the context
     */
    public int getNumDestinationConsumerTasks(int sourceTaskIndex) {
        return this.context.getDestinationVertexNumTasks();
    }

    /**
     * Deserializes the {@link CustomEdgeConfiguration} from the edge's user
     * payload and stores it in {@link #conf}.
     *
     * @throws RuntimeException if the payload buffer is {@code null}
     *         ("Invalid payload"), or wrapping the {@link IOException} raised
     *         while reading the configuration fields
     */
    public void initialize() {
        ByteBuffer payload = this.context.getUserPayload().getPayload();
        LOG.info("Initializing the edge, payload: {}", payload);
        if (payload == null) {
            throw new RuntimeException("Invalid payload");
        }

        DataInputByteBuffer input = new DataInputByteBuffer();
        input.reset(payload);
        this.conf = new CustomEdgeConfiguration();
        try {
            this.conf.readFields(input);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        LOG.info("Routing table: {} num Buckets: {}", this.conf.getRoutingTable(), this.conf.getNumBuckets());
    }

    /**
     * Fills {@code mapDestTaskIndices} with the destination tasks that the
     * routing table lists for {@code sourceOutputIndex}.
     *
     * <p>Each listed destination task index is mapped to the single-element
     * list {@code [sourceTaskIndex]}. When the routing table has no entry for
     * the output, a single mapping from {@code -1} to an empty list is added
     * instead.
     *
     * @param event the data movement event; not inspected by this implementation
     * @param sourceTaskIndex index of the source task that produced the event
     * @param sourceOutputIndex output index used as the routing-table key
     * @param mapDestTaskIndices output map that receives the routing result
     */
    public void routeDataMovementEventToDestination(DataMovementEvent event, int sourceTaskIndex, int sourceOutputIndex, Map<Integer, List<Integer>> mapDestTaskIndices) {
        // Look the destinations up once and reuse the result for both the check and the loop.
        Collection<Integer> destinations = this.conf.getRoutingTable().get(sourceOutputIndex);
        if (destinations.isEmpty()) {
            // No destination for this output: record an empty index list under key -1.
            mapDestTaskIndices.put(-1, new ArrayList<Integer>());
            return;
        }
        // The same immutable singleton list is shared by every destination entry.
        List<Integer> sourceIndexList = Collections.singletonList(sourceTaskIndex);
        for (Integer destIndex : destinations) {
            mapDestTaskIndices.put(destIndex, sourceIndexList);
        }
    }

    /**
     * Maps every destination task index, from {@code 0} up to the destination
     * vertex task count (exclusive), to the single-element list
     * {@code [sourceTaskIndex]}.
     *
     * @param sourceTaskIndex index of the failed source task
     * @param mapDestTaskIndices output map that receives the routing result
     */
    public void routeInputSourceTaskFailedEventToDestination(int sourceTaskIndex, Map<Integer, List<Integer>> mapDestTaskIndices) {
        List<Integer> sourceIndexList = Collections.singletonList(sourceTaskIndex);
        int destinationTaskCount = this.context.getDestinationVertexNumTasks();
        for (int destIndex = 0; destIndex < destinationTaskCount; ++destIndex) {
            mapDestTaskIndices.put(destIndex, sourceIndexList);
        }
    }

    /**
     * Returns the index carried by the read-error event.
     *
     * @param event the input read error event
     * @param destinationTaskIndex not used by this implementation
     * @param destinationFailedInputIndex not used by this implementation
     * @return {@code event.getIndex()}
     */
    public int routeInputErrorEventToSource(InputReadErrorEvent event, int destinationTaskIndex, int destinationFailedInputIndex) {
        return event.getIndex();
    }
}

