/*
 * Decompiled with CFR 0.152.
 */
package com.antgroup.geaflow.shuffle.network.netty;

import com.antgroup.geaflow.shuffle.api.pipeline.channel.ChannelId;
import com.antgroup.geaflow.shuffle.api.pipeline.channel.RemoteInputChannel;
import com.antgroup.geaflow.shuffle.network.protocol.CancelRequest;
import com.antgroup.geaflow.shuffle.network.protocol.ErrorResponse;
import com.antgroup.geaflow.shuffle.network.protocol.NettyMessage;
import com.antgroup.geaflow.shuffle.network.protocol.SliceResponse;
import com.antgroup.geaflow.shuffle.util.SliceNotFoundException;
import com.antgroup.geaflow.shuffle.util.TransportException;
import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import java.io.IOException;
import java.net.SocketAddress;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SliceRequestClientHandler
extends SimpleChannelInboundHandler<NettyMessage> {
    private static final Logger LOGGER = LoggerFactory.getLogger(SliceRequestClientHandler.class);
    private final ConcurrentMap<ChannelId, RemoteInputChannel> inputChannels = new ConcurrentHashMap<ChannelId, RemoteInputChannel>();
    private final AtomicReference<Throwable> channelError = new AtomicReference();
    private final ConcurrentMap<ChannelId, ChannelId> cancelled = new ConcurrentHashMap<ChannelId, ChannelId>();
    private volatile ChannelHandlerContext ctx;

    public void addInputChannel(RemoteInputChannel listener) throws IOException {
        this.checkError();
        this.inputChannels.putIfAbsent(listener.getInputChannelId(), listener);
    }

    public void removeInputChannel(RemoteInputChannel listener) {
        this.inputChannels.remove(listener.getInputChannelId());
    }

    public RemoteInputChannel getInputChannel(ChannelId inputChannelId) {
        return (RemoteInputChannel)this.inputChannels.get(inputChannelId);
    }

    public void cancelRequest(ChannelId inputChannelId) {
        if (inputChannelId == null || this.ctx == null) {
            return;
        }
        if (this.cancelled.putIfAbsent(inputChannelId, inputChannelId) == null) {
            this.ctx.writeAndFlush((Object)new CancelRequest(inputChannelId));
        }
    }

    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        if (this.ctx == null) {
            this.ctx = ctx;
        }
        super.channelActive(ctx);
    }

    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        if (!this.inputChannels.isEmpty()) {
            SocketAddress remoteAddr = ctx.channel().remoteAddress();
            this.notifyAllChannelsOfErrorAndClose(new TransportException("Connection unexpectedly closed by remote server '" + remoteAddr + "'. This might indicate that the remote server was lost.", remoteAddr));
        }
        super.channelInactive(ctx);
    }

    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        if (cause instanceof TransportException) {
            this.notifyAllChannelsOfErrorAndClose(cause);
        } else {
            TransportException tex;
            SocketAddress remoteAddr = ctx.channel().remoteAddress();
            if (cause instanceof IOException && "Connection reset by peer".equals(cause.getMessage())) {
                tex = new TransportException("Lost connection to server '" + remoteAddr + "'. This indicates that the remote server was lost.", remoteAddr, cause);
            } else {
                SocketAddress localAddr = ctx.channel().localAddress();
                tex = new TransportException(String.format("%s (connection to '%s')", cause.getMessage(), remoteAddr), localAddr, cause);
            }
            this.notifyAllChannelsOfErrorAndClose(tex);
        }
    }

    protected void channelRead0(ChannelHandlerContext channelHandlerContext, NettyMessage nettyMessage) throws Exception {
        try {
            this.decodeMsg(nettyMessage);
        }
        catch (Throwable t) {
            this.notifyAllChannelsOfErrorAndClose(t);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void notifyAllChannelsOfErrorAndClose(Throwable cause) {
        if (this.channelError.compareAndSet(null, cause)) {
            try {
                for (RemoteInputChannel inputChannel : this.inputChannels.values()) {
                    inputChannel.onError(cause);
                }
            }
            catch (Throwable t) {
                LOGGER.warn("Exception was thrown during error notification of a remote input channel.", t);
            }
            finally {
                this.inputChannels.clear();
                if (this.ctx != null) {
                    this.ctx.close();
                }
            }
        }
    }

    @VisibleForTesting
    void checkError() throws IOException {
        Throwable t = this.channelError.get();
        if (t != null) {
            if (t instanceof IOException) {
                throw (IOException)t;
            }
            throw new IOException("There has been an error in the channel.", t);
        }
    }

    private void decodeMsg(Object msg) {
        Class<?> msgClazz = msg.getClass();
        if (msgClazz == SliceResponse.class) {
            SliceResponse response = (SliceResponse)msg;
            RemoteInputChannel inputChannel = (RemoteInputChannel)this.inputChannels.get(response.getReceiverId());
            if (inputChannel == null || inputChannel.isReleased()) {
                this.cancelRequest(response.getReceiverId());
                return;
            }
            try {
                this.processBuffer(inputChannel, response);
            }
            catch (Throwable t) {
                inputChannel.onError(t);
            }
        } else if (msgClazz == ErrorResponse.class) {
            ErrorResponse error = (ErrorResponse)msg;
            SocketAddress remoteAddr = this.ctx.channel().remoteAddress();
            if (error.isFatalError()) {
                this.notifyAllChannelsOfErrorAndClose(new TransportException("Fatal error at remote server '" + remoteAddr + "'.", remoteAddr, error.getCause()));
            } else {
                RemoteInputChannel inputChannel = (RemoteInputChannel)this.inputChannels.get(error.getChannelId());
                if (inputChannel != null) {
                    if (error.getCause().getClass() == SliceNotFoundException.class) {
                        inputChannel.onFailedFetchRequest();
                    } else {
                        inputChannel.onError(new TransportException("Error at remote server '" + remoteAddr + "'.", remoteAddr, error.getCause()));
                    }
                }
            }
        } else {
            throw new IllegalStateException("Received unknown message from producer: " + msg.getClass());
        }
    }

    private void processBuffer(RemoteInputChannel inputChannel, SliceResponse response) throws Throwable {
        if (response.getBuffer().isData() && response.getBufferSize() == 0) {
            inputChannel.onEmptyBuffer(response.getSequenceNumber());
        } else if (response.getBuffer() != null) {
            inputChannel.onBuffer(response.getBuffer(), response.getSequenceNumber());
        } else {
            throw new IllegalStateException("The read buffer is null in input channel: " + inputChannel.getChannelIndex());
        }
    }
}

