/*
 * Decompiled with CFR 0.152.
 */
package com.antgroup.geaflow.shuffle.api.reader;

import com.antgroup.geaflow.common.config.Configuration;
import com.antgroup.geaflow.common.exception.GeaflowRuntimeException;
import com.antgroup.geaflow.shuffle.api.pipeline.buffer.PipeFetcherBuffer;
import com.antgroup.geaflow.shuffle.api.pipeline.fetcher.MultiShardFetcher;
import com.antgroup.geaflow.shuffle.api.pipeline.fetcher.OneShardFetcher;
import com.antgroup.geaflow.shuffle.api.pipeline.fetcher.ShardFetcher;
import com.antgroup.geaflow.shuffle.api.reader.AbstractFetcher;
import com.antgroup.geaflow.shuffle.api.reader.FetchContext;
import com.antgroup.geaflow.shuffle.message.FetchRequest;
import com.antgroup.geaflow.shuffle.message.ISliceMeta;
import com.antgroup.geaflow.shuffle.message.PipelineBarrier;
import com.antgroup.geaflow.shuffle.message.PipelineEvent;
import com.antgroup.geaflow.shuffle.message.PipelineMessage;
import com.antgroup.geaflow.shuffle.message.PipelineSliceMeta;
import com.antgroup.geaflow.shuffle.message.SliceId;
import com.antgroup.geaflow.shuffle.network.IConnectionManager;
import com.antgroup.geaflow.shuffle.serialize.IMessageIterator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class PipelineFetcher
extends AbstractFetcher<PipelineSliceMeta> {
    private static final Logger LOGGER = LoggerFactory.getLogger(PipelineFetcher.class);
    private static final int INFINITE_BATCHES = -1;
    private IConnectionManager connectionManager;
    private Set<Integer> inputEdgeSet;
    private ShardFetcher inputFetcher;
    private volatile boolean isRunning;

    @Override
    public void setup(IConnectionManager connectionManager, Configuration config) {
        super.setup(connectionManager, config);
        this.connectionManager = connectionManager;
        this.inputEdgeSet = new HashSet<Integer>();
        this.isRunning = true;
    }

    @Override
    public void init(FetchContext<PipelineSliceMeta> fetchContext) {
        super.init(fetchContext);
        ShardFetcher previous = this.inputFetcher;
        this.inputFetcher = this.getInputFetcher(fetchContext);
        try {
            this.inputFetcher.requestSlices(this.targetBatchId);
            if (previous != null && previous != this.inputFetcher) {
                previous.close();
            }
        }
        catch (IOException e) {
            LOGGER.error(e.getMessage(), e.getCause());
            throw new GeaflowRuntimeException("fetch error", (Throwable)e);
        }
    }

    @Override
    public PipelineEvent next() {
        long startTime = System.currentTimeMillis();
        try {
            Optional<PipeFetcherBuffer> next = this.inputFetcher.getNext();
            if (next.isPresent()) {
                PipeFetcherBuffer buffer = next.get();
                if (buffer.isBarrier()) {
                    if (buffer.getBatchId() == this.targetBatchId || buffer.isFinish()) {
                        ++this.processedNum;
                    }
                    SliceId sliceId = buffer.getSliceId();
                    PipelineBarrier barrier = new PipelineBarrier(buffer.getBatchId(), sliceId.getEdgeId(), sliceId.getShardIndex(), sliceId.getSliceIndex(), buffer.getBatchCount());
                    barrier.setFinish(buffer.isFinish());
                    PipelineBarrier pipelineBarrier = barrier;
                    return pipelineBarrier;
                }
                int edgeId = buffer.getSliceId().getEdgeId();
                this.readMetrics.increaseDecodeBytes((long)buffer.getBufferSize());
                IMessageIterator<?> msgIterator = this.getMessageIterator(edgeId, buffer.getBuffer());
                PipelineMessage pipelineMessage = new PipelineMessage(buffer.getBatchId(), buffer.getStreamName(), msgIterator);
                return pipelineMessage;
            }
            if (!this.isRunning) {
                PipelineEvent pipelineEvent = null;
                return pipelineEvent;
            }
            try {
                throw new GeaflowRuntimeException(this.taskName + " get null");
            }
            catch (IOException | InterruptedException e) {
                LOGGER.error(e.getMessage(), e.getCause());
                throw new GeaflowRuntimeException((Throwable)e);
            }
        }
        finally {
            this.readMetrics.incFetchWaitMs(System.currentTimeMillis() - startTime);
        }
    }

    @Override
    public boolean hasNext() {
        return !this.inputFetcher.isFinished() && (this.totalSliceNum == -1 || this.processedNum < this.totalSliceNum);
    }

    @Override
    public void close() {
        this.isRunning = false;
        if (this.inputFetcher != null) {
            this.inputFetcher.close();
        }
    }

    private ShardFetcher getInputFetcher(FetchContext fetchContext) {
        ShardFetcher shardFetcher = this.inputFetcher;
        Map<Integer, List<Object>> inputSlices = fetchContext.getInputSliceMap();
        if (inputSlices == null) {
            inputSlices = fetchContext.getRequest().getInputSlices();
        }
        if (inputSlices != null && !inputSlices.isEmpty()) {
            shardFetcher = this.createShardFetcher(fetchContext.getRequest(), inputSlices);
        }
        return shardFetcher;
    }

    private ShardFetcher createShardFetcher(FetchRequest req, Map<Integer, List<ISliceMeta>> inputSlices) {
        Set<Integer> edgeSet = inputSlices.keySet();
        if (this.checkInputUnchanged(inputSlices)) {
            return this.inputFetcher;
        }
        this.inputEdgeSet.clear();
        this.inputEdgeSet.addAll(edgeSet);
        int channels = 0;
        int fetcherIndex = 0;
        long batchId = req.getTargetBatchId();
        Map<Integer, String> inputStreamMap = req.getInputStreamMap();
        ArrayList<OneShardFetcher> fetchers = new ArrayList<OneShardFetcher>(inputSlices.size());
        for (Map.Entry<Integer, List<ISliceMeta>> entry : inputSlices.entrySet()) {
            int edgeId = entry.getKey();
            String streamName = inputStreamMap.get(edgeId);
            OneShardFetcher inputFetcher = new OneShardFetcher(req.getVertexId(), this.taskName, fetcherIndex, edgeId, streamName, entry.getValue(), batchId, this.connectionManager);
            fetchers.add(inputFetcher);
            ++fetcherIndex;
            channels += entry.getValue().size();
        }
        int n = this.totalSliceNum = batchId < 0L ? -1 : channels;
        if (fetchers.size() == 1) {
            return (ShardFetcher)fetchers.get(0);
        }
        return new MultiShardFetcher(fetchers.toArray(new OneShardFetcher[0]));
    }

    private boolean checkInputUnchanged(Map<Integer, List<ISliceMeta>> inputSlices) {
        Set<Integer> edgeSet = inputSlices.keySet();
        if (edgeSet.size() == this.inputEdgeSet.size()) {
            return this.inputEdgeSet.containsAll(edgeSet);
        }
        return false;
    }
}

