/*
 * Decompiled with CFR 0.152.
 */
package com.antgroup.geaflow.shuffle.api.writer;

import com.antgroup.geaflow.common.encoder.IEncoder;
import com.antgroup.geaflow.common.metric.ShuffleWriteMetrics;
import com.antgroup.geaflow.shuffle.api.pipeline.buffer.HeapBuffer;
import com.antgroup.geaflow.shuffle.api.pipeline.buffer.OutBuffer;
import com.antgroup.geaflow.shuffle.api.pipeline.buffer.PipeBuffer;
import com.antgroup.geaflow.shuffle.api.pipeline.buffer.PipelineSlice;
import com.antgroup.geaflow.shuffle.api.writer.IWriterContext;
import com.antgroup.geaflow.shuffle.config.ShuffleConfig;
import com.antgroup.geaflow.shuffle.memory.ShuffleMemoryTracker;
import com.antgroup.geaflow.shuffle.message.PipelineBarrier;
import com.antgroup.geaflow.shuffle.serialize.EncoderRecordSerializer;
import com.antgroup.geaflow.shuffle.serialize.IRecordSerializer;
import com.antgroup.geaflow.shuffle.serialize.RecordSerializer;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/**
 * Base class for shuffle writers that serialize records into one output buffer per target
 * channel and hand completed buffers and barriers to per-channel {@link PipelineSlice}s.
 *
 * <p>{@link #init(IWriterContext)} does not populate {@link #resultSlices}; presumably concrete
 * subclasses assign it before any data is emitted — NOTE(review): confirm against implementations.
 *
 * @param <T> type of the records being written
 * @param <R> type of the optional result produced by {@link #finish(long)}
 */
public abstract class ShardBuffer<T, R> {
    protected ShuffleConfig shuffleConfig;
    protected long pipelineId;
    protected int edgeId;
    protected int taskIndex;
    protected String pipelineName;
    /** Number of target channels; one buffer builder and one counter slot per channel. */
    protected int targetChannels;
    /** Task name taken from the writer context. */
    protected String taskLogTag;
    /** Per-channel buffer builders, indexed by channel. */
    protected List<OutBuffer.BufferBuilder> buffers;
    /** Per-channel output slices, indexed by channel; not assigned in this class. */
    protected PipelineSlice[] resultSlices;
    /** Records serialized per channel since the last barrier (reset in {@code notify}). */
    protected int[] batchCounter;
    /** Encoded bytes sent per channel since the last barrier (reset in {@code notify}). */
    protected long[] bytesCounter;
    protected ShuffleMemoryTracker memoryTracker;
    protected ShuffleWriteMetrics writeMetrics;
    /** Buffer size threshold (from the shuffle config) at which a channel buffer is flushed. */
    protected long maxBufferSize;
    protected IRecordSerializer<T> recordSerializer;

    /**
     * Initializes configuration, metrics, per-channel counters and per-channel buffer builders
     * from the given writer context.
     *
     * @param writerContext source of channel count, pipeline/edge/task identity and encoder
     */
    public void init(IWriterContext writerContext) {
        this.shuffleConfig = ShuffleConfig.getInstance();
        this.memoryTracker = ShuffleMemoryTracker.getInstance();
        this.writeMetrics = new ShuffleWriteMetrics();
        this.targetChannels = writerContext.getTargetChannelNum();
        this.pipelineId = writerContext.getPipelineInfo().getPipelineId();
        this.pipelineName = writerContext.getPipelineInfo().getPipelineName();
        this.edgeId = writerContext.getEdgeId();
        this.taskIndex = writerContext.getTaskIndex();
        this.taskLogTag = writerContext.getTaskName();
        this.recordSerializer = ShardBuffer.getRecordSerializer(writerContext);
        this.batchCounter = new int[this.targetChannels];
        this.bytesCounter = new long[this.targetChannels];
        this.maxBufferSize = this.shuffleConfig.getWriteBufferSizeBytes();
        this.buildBufferBuilder(this.targetChannels);
    }

    /** Creates one memory-tracked heap buffer builder per channel. */
    private void buildBufferBuilder(int channels) {
        this.buffers = new ArrayList<OutBuffer.BufferBuilder>(channels);
        for (int i = 0; i < channels; ++i) {
            HeapBuffer.HeapBufferBuilder bufferBuilder = new HeapBuffer.HeapBufferBuilder();
            bufferBuilder.enableMemoryTrack();
            this.buffers.add(bufferBuilder);
        }
    }

    /**
     * Serializes {@code value} into the buffer of every listed channel, flushing each channel
     * buffer whose size has reached {@link #maxBufferSize}.
     *
     * @param batchId   batch id attached to any buffer flushed by this call
     * @param value     record to write
     * @param isRetract retract flag passed through to the serializer
     * @param channels  target channel indices; a record is written once per entry
     * @throws IOException declared for subclasses/serializers that report I/O failures
     */
    public void emit(long batchId, T value, boolean isRetract, int[] channels) throws IOException {
        for (int channel : channels) {
            OutBuffer.BufferBuilder outBuffer = this.buffers.get(channel);
            this.recordSerializer.serialize(value, isRetract, outBuffer);
            this.batchCounter[channel]++;
            if (outBuffer.getBufferSize() >= this.maxBufferSize) {
                this.send(channel, outBuffer.build(), batchId);
            }
        }
    }

    /**
     * Serializes every record of {@code data} (non-retract) into one channel's buffer, then
     * flushes that buffer once if it has reached {@link #maxBufferSize}.
     *
     * <p>The size check happens only after the whole list is serialized, so a large list can
     * grow the buffer beyond the threshold before it is flushed.
     *
     * @param batchId batch id attached to the buffer if it is flushed
     * @param data    records to write, in order
     * @param channel target channel index
     */
    public void emit(long batchId, List<T> data, int channel) {
        OutBuffer.BufferBuilder outBuffer = this.buffers.get(channel);
        // for-each instead of get(i): avoids O(n^2) traversal on non-RandomAccess lists.
        for (T record : data) {
            this.recordSerializer.serialize(record, false, outBuffer);
            this.batchCounter[channel]++;
        }
        if (outBuffer.getBufferSize() >= this.maxBufferSize) {
            this.send(channel, outBuffer.build(), batchId);
        }
    }

    /**
     * Hands a built buffer to the channel's slice and adds its size to the byte counter.
     *
     * @param selectChannel target channel index
     * @param outBuffer     built buffer to send
     * @param batchId       batch id attached to the buffer
     */
    protected void send(int selectChannel, OutBuffer outBuffer, long batchId) {
        // Capture the size before the hand-off: once the buffer is in the slice, a consumer
        // may take ownership of it, so it should not be inspected afterwards.
        long bufferSize = outBuffer.getBufferSize();
        this.sendBuffer(selectChannel, outBuffer, batchId);
        this.bytesCounter[selectChannel] += bufferSize;
    }

    /**
     * Wraps {@code buffer} in a data {@link PipeBuffer} and adds it to the slice at
     * {@code sliceIndex}.
     */
    protected void sendBuffer(int sliceIndex, OutBuffer buffer, long batchId) {
        PipelineSlice resultSlice = this.resultSlices[sliceIndex];
        // NOTE(review): the third PipeBuffer argument is true here and false for barriers;
        // presumably it marks a data buffer — confirm against PipeBuffer.
        resultSlice.add(new PipeBuffer(buffer, batchId, true));
    }

    /**
     * Finishes writing; implemented by subclasses.
     *
     * @param batchId presumably the id of the batch/window being finished — confirm with
     *                implementations
     * @return an optional, implementation-specific result
     * @throws IOException declared so implementations can report I/O failures
     */
    public abstract Optional<R> finish(long batchId) throws IOException;

    /** Returns the metrics accumulated by {@code notify}. */
    public ShuffleWriteMetrics getShuffleWriteMetrics() {
        return this.writeMetrics;
    }

    /** No-op by default; subclasses may override to release resources. */
    public void close() {
    }

    /**
     * Sends {@code barrier} to every target channel (see {@link #notify(PipelineBarrier, int)}).
     *
     * @throws IOException declared for subclasses; not thrown by this implementation
     */
    protected void notify(PipelineBarrier barrier) throws IOException {
        for (int channel = 0; channel < this.targetChannels; ++channel) {
            this.notify(barrier, channel);
        }
    }

    /**
     * Sends a barrier carrying the channel's record count, folds the channel's record and byte
     * counters into {@link #writeMetrics}, then resets both counters. Any data still pending in
     * the channel's buffer builder is not flushed here.
     *
     * @param barrier barrier supplying the window id and finish flag
     * @param channel target channel index
     */
    protected void notify(PipelineBarrier barrier, int channel) {
        long batchId = barrier.getWindowId();
        int recordCount = this.batchCounter[channel];
        this.sendBarrier(channel, batchId, recordCount, barrier.isFinish());
        this.writeMetrics.increaseRecords(recordCount);
        this.writeMetrics.increaseEncodedSize(this.bytesCounter[channel]);
        this.batchCounter[channel] = 0;
        this.bytesCounter[channel] = 0L;
    }

    /** Adds a barrier {@link PipeBuffer} (batch id, record count, finish flag) to the slice. */
    protected void sendBarrier(int sliceIndex, long batchId, int count, boolean isFinish) {
        PipelineSlice resultSlice = this.resultSlices[sliceIndex];
        resultSlice.add(new PipeBuffer(batchId, count, false, isFinish));
    }

    /**
     * Chooses the record serializer: an {@link EncoderRecordSerializer} when the context supplies
     * an encoder, otherwise a plain {@link RecordSerializer}.
     */
    // The concrete serializers are constructed without type arguments (the encoder is only known
    // as IEncoder<?>), so returning them as IRecordSerializer<T> is an unchecked conversion.
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <T> IRecordSerializer<T> getRecordSerializer(IWriterContext writerContext) {
        IEncoder<?> encoder = writerContext.getEncoder();
        if (encoder == null) {
            return new RecordSerializer();
        }
        return new EncoderRecordSerializer(encoder);
    }
}

