/*
 * Decompiled with CFR 0.152.
 */
package org.drasyl.handler.rmi;

import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.DefaultAddressedEnvelope;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.internal.StringUtil;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.drasyl.handler.rmi.RmiException;
import org.drasyl.handler.rmi.RmiUtil;
import org.drasyl.handler.rmi.annotation.RmiCaller;
import org.drasyl.handler.rmi.message.RmiCancel;
import org.drasyl.handler.rmi.message.RmiError;
import org.drasyl.handler.rmi.message.RmiMessage;
import org.drasyl.handler.rmi.message.RmiRequest;
import org.drasyl.handler.rmi.message.RmiResponse;
import org.drasyl.util.Pair;
import org.drasyl.util.logging.Logger;
import org.drasyl.util.logging.LoggerFactory;

public class RmiServerHandler
extends SimpleChannelInboundHandler<AddressedEnvelope<RmiMessage, SocketAddress>> {
    private static final Logger LOG = LoggerFactory.getLogger(RmiServerHandler.class);
    private static final Map<Class<?>, Optional<Field>> callerFields = new HashMap();
    private final Map<Integer, Object> bindings;
    private final Map<Integer, Map<Integer, Method>> bindingsMethods;
    private final Map<Pair<SocketAddress, UUID>, Future<?>> invocations;

    public RmiServerHandler(Map<Integer, Object> bindings, Map<Integer, Map<Integer, Method>> bindingsMethods, Map<Pair<SocketAddress, UUID>, Future<?>> invocations) {
        super(false);
        this.bindings = Objects.requireNonNull(bindings);
        this.bindingsMethods = Objects.requireNonNull(bindingsMethods);
        this.invocations = Objects.requireNonNull(invocations);
    }

    public RmiServerHandler() {
        this(new HashMap<Integer, Object>(), new HashMap<Integer, Map<Integer, Method>>(), new HashMap());
    }

    public boolean acceptInboundMessage(Object msg) throws Exception {
        return msg instanceof AddressedEnvelope && (((AddressedEnvelope)msg).content() instanceof RmiRequest || ((AddressedEnvelope)msg).content() instanceof RmiCancel);
    }

    protected void channelRead0(ChannelHandlerContext ctx, AddressedEnvelope<RmiMessage, SocketAddress> msg) {
        LOG.trace("Got `{}`.", msg);
        if (msg.content() instanceof RmiRequest) {
            this.handleRequest(ctx, (RmiRequest)msg.content(), msg.sender());
        } else if (msg.content() instanceof RmiCancel) {
            this.handleCancel((RmiCancel)msg.content(), msg.sender());
        } else {
            ctx.fireChannelRead(msg);
        }
    }

    private void handleRequest(ChannelHandlerContext ctx, RmiRequest request, SocketAddress caller) {
        UUID id = request.getId();
        int name = request.getName();
        int methodHash = request.getMethod();
        ByteBuf argsBuf = request.getArguments();
        Object binding = this.bindings.get(name);
        if (binding == null) {
            request.release();
            RmiServerHandler.replyError(ctx, caller, id, new RmiException("Binding not found."));
            return;
        }
        Map<Integer, Method> bindingMethods = this.bindingsMethods.get(name);
        Method method = bindingMethods.get(methodHash);
        if (method == null) {
            request.release();
            RmiServerHandler.replyError(ctx, caller, id, new RmiException("Method not found."));
            return;
        }
        try {
            Object[] args = RmiUtil.unmarshalArgs(method.getParameterTypes(), argsBuf);
            this.invokeMethod(ctx, caller, id, binding, method, args);
        }
        catch (IOException e) {
            RmiServerHandler.replyError(ctx, caller, id, e);
        }
    }

    private void handleCancel(RmiCancel cancel, SocketAddress sender) {
        UUID id = cancel.getId();
        Future<?> invocation = this.invocations.get(Pair.of((Object)sender, (Object)id));
        if (invocation != null) {
            invocation.cancel(false);
        }
    }

    private void invokeMethod(ChannelHandlerContext ctx, SocketAddress caller, UUID id, Object binding, Method method, Object[] args) {
        try {
            Field callerField = RmiServerHandler.getCallerField(binding.getClass());
            if (callerField != null) {
                callerField.set(binding, caller);
            }
            Supplier[] supplierArray = new Supplier[3];
            supplierArray[0] = method::getName;
            supplierArray[1] = () -> Arrays.stream(method.getParameterTypes()).map(StringUtil::simpleClassName).collect(Collectors.joining(","));
            supplierArray[2] = () -> StringUtil.simpleClassName((Object)binding);
            LOG.debug("Invoke `{}({})` on local object `{}`.", supplierArray);
            Object result = method.invoke(binding, args);
            if (result instanceof Future) {
                this.invocations.put((Pair<SocketAddress, UUID>)Pair.of((Object)caller, (Object)id), (Future)result);
                ((Future)result).addListener((GenericFutureListener)((FutureListener)future -> {
                    this.invocations.remove(Pair.of((Object)caller, (Object)id));
                    if (future.isSuccess()) {
                        try {
                            Object response = future.getNow();
                            Supplier[] supplierArray = new Supplier[4];
                            supplierArray[0] = method::getName;
                            supplierArray[1] = () -> Arrays.stream(method.getParameterTypes()).map(Class::getName).collect(Collectors.joining(","));
                            supplierArray[2] = () -> StringUtil.simpleClassName((Object)binding);
                            supplierArray[3] = () -> StringUtil.simpleClassName((Object)response);
                            LOG.debug("Invocation `{}({})` on local object `{}` returned `{}`.", supplierArray);
                            DefaultAddressedEnvelope msg = new DefaultAddressedEnvelope((Object)RmiResponse.of(id, RmiUtil.marshalResult(response, ctx.alloc().buffer())), caller);
                            LOG.trace("Send `{}`.", (Object)msg);
                            ctx.writeAndFlush((Object)msg).addListener((GenericFutureListener)((ChannelFutureListener)future2 -> {
                                if (future.cause() != null) {
                                    LOG.debug("Error", future.cause());
                                }
                            }));
                        }
                        catch (IOException e) {
                            RmiServerHandler.replyError(ctx, caller, id, e);
                        }
                    } else {
                        RmiServerHandler.replyError(ctx, caller, id, future.cause());
                    }
                }));
            }
        }
        catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException e) {
            RmiServerHandler.replyError(ctx, caller, id, e);
        }
    }

    public void bind(String name, Object object) {
        int bindingKey = name.hashCode();
        if (this.bindings.containsKey(bindingKey)) {
            throw new IllegalArgumentException("`" + name + "` has already an associated binding.");
        }
        Class<?> clazz = object.getClass();
        Class<?>[] interfaces = clazz.getInterfaces();
        if (interfaces.length == 0) {
            throw new IllegalArgumentException("Given object did not implement any interfaces whose methods can be made available for remote invocations.");
        }
        HashMap<Integer, Method> bindingMethods = new HashMap<Integer, Method>();
        for (Class<?> iface : interfaces) {
            for (Method method : iface.getMethods()) {
                int methodHash = RmiUtil.computeMethodHash(method);
                bindingMethods.put(methodHash, method);
            }
        }
        LOG.debug("Bound `{}`: {}", (Object)name, object);
        this.bindings.put(bindingKey, object);
        this.bindingsMethods.put(bindingKey, bindingMethods);
    }

    public void unbind(String name) {
        int bindingKey = name.hashCode();
        LOG.debug("Unbound `{}`", (Object)name);
        this.bindings.remove(bindingKey);
        this.bindingsMethods.remove(bindingKey);
    }

    public void rebind(String name, Object object) {
        this.unbind(name);
        this.bind(name, object);
    }

    private static void replyError(ChannelHandlerContext ctx, SocketAddress recipient, UUID id, Throwable cause) {
        LOG.warn("Error:", cause);
        RmiError response = RmiError.of(id, cause);
        DefaultAddressedEnvelope msg = new DefaultAddressedEnvelope((Object)response, recipient);
        LOG.trace("Send `{}`.", (Object)msg);
        ctx.writeAndFlush((Object)msg).addListener((GenericFutureListener)((ChannelFutureListener)future -> {
            if (future.cause() != null) {
                LOG.warn("Error", future.cause());
            }
        }));
    }

    private static Field getCallerField(Class<?> clazz) {
        return callerFields.computeIfAbsent(clazz, key -> {
            Field[] fields;
            for (Field field : fields = clazz.getDeclaredFields()) {
                if (!field.isAnnotationPresent(RmiCaller.class)) continue;
                field.setAccessible(true);
                return Optional.of(field);
            }
            return Optional.empty();
        }).orElse(null);
    }
}

