/*
 * Decompiled with CFR 0.152.
 */
package com.antgroup.geaflow.shuffle.api.writer;

import com.antgroup.geaflow.shuffle.api.pipeline.buffer.OutBuffer;
import com.antgroup.geaflow.shuffle.api.pipeline.buffer.PipelineShard;
import com.antgroup.geaflow.shuffle.api.pipeline.buffer.PipelineSlice;
import com.antgroup.geaflow.shuffle.api.writer.IWriterContext;
import com.antgroup.geaflow.shuffle.api.writer.ShardBuffer;
import com.antgroup.geaflow.shuffle.memory.ShuffleDataManager;
import com.antgroup.geaflow.shuffle.message.ISliceMeta;
import com.antgroup.geaflow.shuffle.message.PipelineBarrier;
import com.antgroup.geaflow.shuffle.message.PipelineSliceMeta;
import com.antgroup.geaflow.shuffle.message.Shard;
import com.antgroup.geaflow.shuffle.message.ShuffleId;
import com.antgroup.geaflow.shuffle.message.SliceId;
import com.antgroup.geaflow.shuffle.message.WriterId;
import com.antgroup.geaflow.shuffle.network.IConnectionManager;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ShardBuffer} implementation whose output is organised as one {@link PipelineSlice}
 * per target channel and described to callers as a {@link Shard} of slice metadata.
 *
 * <p>Slices are looked up in (or created for) the {@link ShuffleDataManager} during
 * {@link #init(IWriterContext)}. When a batch is finished, the slices are registered with the
 * manager if at least one channel received records.
 *
 * @param <T> type parameter forwarded to {@link ShardBuffer}
 */
public class SpillableShardBuffer<T> extends ShardBuffer<T, Shard> {

    private static final Logger LOGGER = LoggerFactory.getLogger(SpillableShardBuffer.class);

    /** Cache flag read from the writer context's shuffle descriptor in {@link #init}. */
    protected boolean cacheEnabled;
    /** Threshold read from the shuffle config in {@link #init}; not consulted within this class. */
    protected double cacheSpillThreshold;
    /** Identifier built in {@link #init} from the pipeline id, edge id and task index. */
    protected WriterId writerId;
    /** Never assigned within this class. */
    protected ShuffleId shuffleId;
    /** Context handed to {@link #init}. */
    protected IWriterContext writerContext;
    /** Supplies the shuffle address stored in every slice meta; may be injected after construction. */
    protected IConnectionManager connectionManager;
    /** Task id taken from the writer context in {@link #init}. */
    protected int taskId;

    /** Creates a buffer without a connection manager; see {@link #setConnectionManager}. */
    public SpillableShardBuffer() {
    }

    /**
     * Creates a buffer bound to the given connection manager.
     *
     * @param connectionManager source of the shuffle address used when building slice metadata
     */
    public SpillableShardBuffer(IConnectionManager connectionManager) {
        this.connectionManager = connectionManager;
    }

    /**
     * Replaces the connection manager used when building slice metadata.
     *
     * @param connectionManager source of the shuffle address
     */
    public void setConnectionManager(IConnectionManager connectionManager) {
        this.connectionManager = connectionManager;
    }

    /**
     * Initialises the parent buffer, captures task and cache settings from the context, builds
     * the {@link WriterId} and prepares the result slices.
     *
     * @param writerContext context describing the task, pipeline and shuffle descriptor
     */
    @Override
    public void init(IWriterContext writerContext) {
        super.init(writerContext);
        this.writerContext = writerContext;
        taskId = writerContext.getTaskId();
        cacheEnabled = writerContext.getShuffleDescriptor().isCacheEnabled();
        if (cacheEnabled) {
            LOGGER.info("cache is enabled in {}", taskLogTag);
        }

        int channelCount = writerContext.getTargetChannelNum();
        writerId = new WriterId(writerContext.getPipelineInfo().getPipelineId(), edgeId, taskIndex);

        // New slices get a reference count of Integer.MAX_VALUE when caching is on, otherwise 1.
        int sliceRefCount;
        if (cacheEnabled) {
            sliceRefCount = Integer.MAX_VALUE;
        } else {
            sliceRefCount = 1;
        }
        initResultSlices(channelCount, sliceRefCount);
        cacheSpillThreshold = shuffleConfig.getCacheSpillThreshold();
    }

    /**
     * Points {@code resultSlices} at the slices of the shard already known to the
     * {@link ShuffleDataManager} for this writer, or at freshly created ones when none exists.
     *
     * @param channels number of slices to create when no shard is found
     * @param refCount reference count given to each newly created slice
     */
    private void initResultSlices(int channels, int refCount) {
        PipelineShard shard = ShuffleDataManager.getInstance().getShard(writerId);
        if (shard == null) {
            shard = createShard(channels, refCount);
        }
        resultSlices = shard.getSlices();
    }

    /** Builds a new shard holding one slice per channel index in {@code [0, channels)}. */
    private PipelineShard createShard(int channels, int refCount) {
        PipelineSlice[] created = new PipelineSlice[channels];
        for (int channel = 0; channel < created.length; channel++) {
            created[channel] = new PipelineSlice(taskLogTag, new SliceId(writerId, channel), refCount);
        }
        return new PipelineShard(taskLogTag, created);
    }

    /**
     * Completes the given batch: sends any non-empty pending buffers, builds the slice
     * metadata, closes every buffer builder, records write metrics and drops the per-batch state.
     *
     * @param batchId id of the batch being finished
     * @return a present {@link Optional} holding the shard built from this batch's slice metadata
     * @throws IOException declared by the overridden method
     */
    @Override
    public Optional<Shard> finish(long batchId) throws IOException {
        final long startMs = System.currentTimeMillis();
        flushFloatingBuffers(batchId);
        final List<ISliceMeta> sliceMetas = buildSliceMeta(batchId);

        long largestSliceBytes = 0L;
        final int sliceCount = sliceMetas.size();
        for (int channel = 0; channel < sliceCount; channel++) {
            ISliceMeta meta = sliceMetas.get(channel);
            // Only channels that actually carried records count towards the metrics.
            if (meta.getRecordNum() > 0L) {
                writeMetrics.increaseWrittenChannels();
                largestSliceBytes = Math.max(largestSliceBytes, meta.getEncodedSize());
            }
            ((OutBuffer.BufferBuilder) buffers.get(channel)).close();
        }

        writeMetrics.setMaxSliceKB(largestSliceBytes / 1024L);
        writeMetrics.setNumChannels((long) sliceCount);
        final long elapsedMs = System.currentTimeMillis() - startMs;
        writeMetrics.setFlushMs(elapsedMs);
        LOGGER.info("taskId {} {} flush batchId:{} useTime:{}ms {}",
            taskId, taskLogTag, batchId, elapsedMs, writeMetrics);

        // Drop all per-batch state; these fields are null once finish has returned.
        buffers.clear();
        buffers = null;
        batchCounter = null;
        resultSlices = null;
        bytesCounter = null;
        return Optional.of(new Shard(edgeId, sliceMetas));
    }

    /**
     * Creates one {@link PipelineSliceMeta} per target channel from the record and byte
     * counters. Each channel that holds records is passed a finish barrier via {@code notify}.
     * If any channel had records, the result slices are registered with the
     * {@link ShuffleDataManager}.
     *
     * @param batchId id of the batch the metadata belongs to
     * @return slice metadata ordered by channel index
     */
    private List<ISliceMeta> buildSliceMeta(long batchId) {
        List<ISliceMeta> metas = new ArrayList<>();
        PipelineBarrier finishBarrier = new PipelineBarrier(batchId, edgeId, taskIndex);
        finishBarrier.setFinish(true);

        int nonEmptyChannels = 0;
        for (int channel = 0; channel < targetChannels; channel++) {
            PipelineSliceMeta meta = new PipelineSliceMeta(
                resultSlices[channel].getSliceId(), batchId, connectionManager.getShuffleAddress());
            meta.setRecordNum(batchCounter[channel]);
            meta.setEncodedSize(bytesCounter[channel]);
            metas.add(meta);
            if (meta.getRecordNum() > 0L) {
                notify(finishBarrier, channel);
                nonEmptyChannels++;
            }
        }

        if (nonEmptyChannels > 0) {
            PipelineShard written = new PipelineShard(taskLogTag, resultSlices, nonEmptyChannels);
            ShuffleDataManager.getInstance().register(writerId, written);
        }
        return metas;
    }

    /** No-op in this class. */
    @Override
    public void close() {
    }

    /**
     * Builds and sends the content of every buffer builder that currently holds data.
     *
     * @param batchId id of the batch the buffers are sent under
     */
    protected void flushFloatingBuffers(long batchId) {
        for (int channel = 0; channel < buffers.size(); channel++) {
            OutBuffer.BufferBuilder builder = (OutBuffer.BufferBuilder) buffers.get(channel);
            if (builder.getBufferSize() > 0) {
                send(channel, builder.build(), batchId);
            }
        }
    }
}

