/*
 * Decompiled with CFR 0.152.
 */
package io.unitycatalog.server.auth.decorator;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.linecorp.armeria.common.HttpData;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.internal.server.annotation.AnnotatedService;
import com.linecorp.armeria.server.DecoratingHttpServiceFunction;
import com.linecorp.armeria.server.HttpService;
import com.linecorp.armeria.server.Service;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.SimpleDecoratingHttpService;
import com.linecorp.armeria.server.annotation.Param;
import io.unitycatalog.server.auth.UnityCatalogAuthorizer;
import io.unitycatalog.server.auth.annotation.AuthorizeExpression;
import io.unitycatalog.server.auth.annotation.AuthorizeKey;
import io.unitycatalog.server.auth.annotation.AuthorizeKeys;
import io.unitycatalog.server.auth.decorator.KeyLocator;
import io.unitycatalog.server.auth.decorator.KeyMapper;
import io.unitycatalog.server.auth.decorator.UnityAccessEvaluator;
import io.unitycatalog.server.exception.BaseException;
import io.unitycatalog.server.exception.ErrorCode;
import io.unitycatalog.server.model.SecurableType;
import io.unitycatalog.server.persist.Repositories;
import io.unitycatalog.server.persist.UserRepository;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Armeria decorator that enforces the authorization annotations ({@link AuthorizeExpression},
 * {@link AuthorizeKey}, {@link AuthorizeKeys}) declared on the annotated-service method handling a
 * request.
 *
 * <p>Resource keys are gathered from three sources ({@link KeyLocator.Source}): a fixed
 * {@code "metastore"} key for a method-level {@code @AuthorizeKey}, path/query parameters, and fields
 * of a JSON request body. The keys are mapped to resource ids by {@link KeyMapper} and the expression
 * is evaluated by {@link UnityAccessEvaluator}; a negative result raises a {@link BaseException} with
 * {@link ErrorCode#PERMISSION_DENIED}.
 *
 * <p>NOTE(review): several paths still serve the request without any authorization check (service
 * method cannot be resolved, no keys found, non-JSON body for payload keys, body that never parses).
 * Confirm that this fail-open behaviour is intended.
 */
public class UnityAccessDecorator implements DecoratingHttpServiceFunction {
    private static final Logger LOGGER = LoggerFactory.getLogger(UnityAccessDecorator.class);
    private static final ObjectMapper MAPPER = new ObjectMapper();
    /** Target type for JSON request bodies; immutable, so it is created once. */
    private static final TypeReference<Map<String, Object>> PAYLOAD_TYPE = new TypeReference<>() {};

    private final KeyMapper keyMapper;
    private final UserRepository userRepository;
    private final UnityAccessEvaluator evaluator;

    /**
     * Creates the decorator.
     *
     * @param authorizer authorizer handed to the expression evaluator
     * @param repositories used to build the {@link KeyMapper} and to obtain the user repository
     * @throws BaseException with {@link ErrorCode#INTERNAL} if the access evaluator cannot be created
     */
    public UnityAccessDecorator(UnityCatalogAuthorizer authorizer, Repositories repositories) throws BaseException {
        try {
            this.evaluator = new UnityAccessEvaluator(authorizer);
        } catch (IllegalAccessException | NoSuchMethodException e) {
            throw new BaseException(ErrorCode.INTERNAL, "Error initializing access evaluator.", e);
        }
        this.keyMapper = new KeyMapper(repositories);
        this.userRepository = repositories.getUserRepository();
    }

    /**
     * Resolves the target service method, and if it carries an authorization expression plus at least
     * one resource key, authorizes the request before (or, for payload keys, while) it is served.
     * Otherwise the request is passed to {@code delegate} unchanged.
     */
    @Override
    public HttpResponse serve(HttpService delegate, ServiceRequestContext ctx, HttpRequest req) throws Exception {
        LOGGER.debug("AccessDecorator checking {}", req.path());
        Method method = findServiceMethod(ctx.config().service());
        if (method == null) {
            LOGGER.warn("Couldn't unwrap service.");
            return delegate.serve(ctx, req);
        }
        String expression = findAuthorizeExpression(method);
        List<KeyLocator> locators = findAuthorizeKeys(method);
        if (expression == null) {
            LOGGER.debug("No authorization expression found.");
            return delegate.serve(ctx, req);
        }
        if (locators.isEmpty()) {
            LOGGER.warn("No authorization resource(s) found.");
            return delegate.serve(ctx, req);
        }
        UUID principal = userRepository.findPrincipalId();
        return authorizeByRequest(delegate, ctx, req, principal, locators, expression);
    }

    /**
     * Collects resource keys from the request and evaluates {@code expression}.
     *
     * <p>Without payload locators the check runs synchronously before the delegate is invoked. With
     * payload locators the check runs inside a {@code peekData} callback once the buffered JSON body
     * can be parsed.
     */
    private HttpResponse authorizeByRequest(HttpService delegate, ServiceRequestContext ctx, HttpRequest req,
            UUID principal, List<KeyLocator> locators, String expression) throws Exception {
        Map<SecurableType, Object> resourceKeys = new HashMap<>();
        // Sources are applied in a fixed order (SYSTEM, then PARAM) so a PARAM key for the same
        // securable type overrides the SYSTEM placeholder.
        locators.stream()
                .filter(l -> l.getSource() == KeyLocator.Source.SYSTEM)
                .forEach(l -> resourceKeys.put(l.getType(), "metastore"));
        locators.stream()
                .filter(l -> l.getSource() == KeyLocator.Source.PARAM)
                .forEach(l -> resourceKeys.put(l.getType(), findParamValue(ctx, l.getKey())));
        List<KeyLocator> payloadLocators = locators.stream()
                .filter(l -> l.getSource() == KeyLocator.Source.PAYLOAD)
                .toList();

        if (payloadLocators.isEmpty()) {
            LOGGER.debug("Checking authorization before method.");
            checkAuthorization(principal, expression, resourceKeys);
            return delegate.serve(ctx, req);
        }

        LOGGER.debug("Checking authorization before in peekData.");
        PeekDataHandler peekDataHandler = new PeekDataHandler(req.contentType(), payloadLocators, resourceKeys);
        HttpRequest peekReq = req.peekData(data -> {
            LOGGER.debug("Authorization peekData invoked.");
            if (peekDataHandler.processPeekData(data)) {
                checkAuthorization(principal, expression, resourceKeys);
            }
        });
        return delegate.serve(ctx, peekReq);
    }

    /** Returns the path parameter named {@code key}, falling back to the query parameter of that name. */
    private static String findParamValue(ServiceRequestContext ctx, String key) {
        String pathValue = ctx.pathParam(key);
        return pathValue != null ? pathValue : ctx.queryParam(key);
    }

    /**
     * Looks up a dotted key (e.g. {@code "a.b.c"}) in a parsed JSON object.
     *
     * @return the value, or {@code null} if any segment is missing or an intermediate value is not an
     *     object
     */
    private static Object findPayloadValue(String key, Map<String, Object> payload) {
        String[] parts = key.split("[.]", 2);
        Object head = payload.get(parts[0]);
        if (parts.length == 1) {
            return head;
        }
        if (head instanceof Map<?, ?> nested) {
            // Jackson deserializes nested JSON objects into Map<String, Object>.
            @SuppressWarnings("unchecked")
            Map<String, Object> child = (Map<String, Object>) nested;
            return findPayloadValue(parts[1], child);
        }
        return null;
    }

    /**
     * Maps resource keys to ids and evaluates {@code expression} for {@code principal}.
     *
     * @throws BaseException with {@link ErrorCode#PERMISSION_DENIED} if the evaluation fails
     */
    private void checkAuthorization(UUID principal, String expression, Map<SecurableType, Object> resourceKeys) {
        LOGGER.debug("resourceKeys = {}", resourceKeys);
        Map<SecurableType, Object> resourceIds = keyMapper.mapResourceKeys(resourceKeys);
        if (!resourceIds.keySet().containsAll(resourceKeys.keySet())) {
            LOGGER.warn("Some resource keys have unresolved ids.");
        }
        LOGGER.debug("resourceIds = {}", resourceIds);
        if (!evaluator.evaluate(principal, expression, resourceIds)) {
            throw new BaseException(ErrorCode.PERMISSION_DENIED, "Access denied.");
        }
    }

    /** Returns the value of the method's {@link AuthorizeExpression}, or {@code null} if absent. */
    private static String findAuthorizeExpression(Method method) {
        AuthorizeExpression annotation = method.getAnnotation(AuthorizeExpression.class);
        if (annotation == null) {
            LOGGER.debug("authorize = (none found)");
            return null;
        }
        LOGGER.debug("authorize expression = {}", annotation.value());
        return annotation.value();
    }

    /**
     * Builds key locators from the method's annotations.
     *
     * <ul>
     *   <li>method-level {@code @AuthorizeKey} becomes a SYSTEM locator;
     *   <li>a parameter key with a non-empty {@code key()} becomes a PAYLOAD locator;
     *   <li>otherwise the parameter's {@link Param} name becomes a PARAM locator.
     * </ul>
     */
    private static List<KeyLocator> findAuthorizeKeys(Method method) {
        List<KeyLocator> locators = new ArrayList<>();
        AuthorizeKey methodKey = method.getAnnotation(AuthorizeKey.class);
        if (methodKey != null) {
            locators.add(KeyLocator.builder().source(KeyLocator.Source.SYSTEM).type(methodKey.value()).build());
        }
        for (Parameter parameter : method.getParameters()) {
            AuthorizeKey paramKey = parameter.getAnnotation(AuthorizeKey.class);
            AuthorizeKeys paramKeys = parameter.getAnnotation(AuthorizeKeys.class);
            if (paramKey != null && paramKeys != null) {
                LOGGER.warn("Both AuthorizeKey and AuthorizeKeys present");
            }
            List<AuthorizeKey> allKeys = new ArrayList<>();
            if (paramKey != null) {
                allKeys.add(paramKey);
            }
            if (paramKeys != null) {
                allKeys.addAll(Arrays.asList(paramKeys.value()));
            }
            Param param = parameter.getAnnotation(Param.class);
            for (AuthorizeKey key : allKeys) {
                if (!key.key().isEmpty()) {
                    locators.add(KeyLocator.builder()
                            .source(KeyLocator.Source.PAYLOAD).type(key.value()).key(key.key()).build());
                } else if (param != null) {
                    locators.add(KeyLocator.builder()
                            .source(KeyLocator.Source.PARAM).type(key.value()).key(param.value()).build());
                } else {
                    LOGGER.warn("Couldn't find param key for authorization key");
                }
            }
        }
        return locators;
    }

    /**
     * Resolves the Java method behind an annotated service wrapped in a
     * {@link SimpleDecoratingHttpService}.
     *
     * @return the method, or {@code null} if the service shape is not recognised or the method name is
     *     not unique on its class (overloads)
     */
    private static Method findServiceMethod(HttpService httpService) throws ClassNotFoundException {
        Object unwrapped = httpService.unwrap();
        if (!(unwrapped instanceof SimpleDecoratingHttpService decorating)) {
            return null;
        }
        if (!(decorating.unwrap() instanceof AnnotatedService annotated)) {
            return null;
        }
        LOGGER.debug("serviceName = {}, methodName = {}", annotated.serviceName(), annotated.methodName());
        Class<?> clazz = Class.forName(annotated.serviceName());
        List<Method> methods = findMethodsByName(clazz, annotated.methodName());
        return methods.size() == 1 ? methods.get(0) : null;
    }

    /** Returns all methods declared directly on {@code clazz} whose name equals {@code methodName}. */
    private static List<Method> findMethodsByName(Class<?> clazz, String methodName) {
        return Arrays.stream(clazz.getDeclaredMethods())
                .filter(m -> m.getName().equals(methodName))
                .toList();
    }

    /**
     * Per-request buffer that accumulates body chunks seen via {@code peekData} and, once the buffered
     * bytes form a JSON object, extracts payload resource keys into the shared {@code resourceKeys} map.
     */
    private static class PeekDataHandler {
        private final MediaType contentType;
        private final List<KeyLocator> payloadLocators;
        private final Map<SecurableType, Object> resourceKeys;
        private final ByteArrayOutputStream dataStream = new ByteArrayOutputStream();

        private PeekDataHandler(MediaType contentType, List<KeyLocator> payloadLocators,
                Map<SecurableType, Object> resourceKeys) {
            this.contentType = contentType;
            this.payloadLocators = payloadLocators;
            this.resourceKeys = resourceKeys;
        }

        /**
         * Buffers one body chunk and tries to extract payload keys.
         *
         * @return {@code true} if the buffered body was parsed and {@code resourceKeys} populated, meaning
         *     the caller should now run the authorization check
         */
        private boolean processPeekData(HttpData data) {
            // Accept any JSON media type, including ones with parameters such as
            // "application/json; charset=utf-8"; an exact equals() silently skipped those bodies.
            if (contentType == null || !contentType.isJson()) {
                LOGGER.warn("Skipping content-type: {}", contentType);
                return false;
            }
            byte[] chunk = data.array();
            if (chunk.length == 0) {
                // Nothing new buffered; the previous chunk already decided whether the body was complete.
                return false;
            }
            dataStream.writeBytes(chunk);
            byte[] buffered = dataStream.toByteArray();
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("Payload: {}", new String(buffered, StandardCharsets.UTF_8));
            }
            // Cheap completeness heuristic: only attempt a parse once the buffer ends with '}' (ignoring
            // trailing JSON whitespace). A parse failure here may just mean more chunks are coming.
            if (!endsWithClosingBrace(buffered)) {
                return false;
            }
            try {
                Map<String, Object> payload = MAPPER.readValue(buffered, PAYLOAD_TYPE);
                if (payload == null) {
                    return false;
                }
                payloadLocators.forEach(l -> resourceKeys.put(l.getType(), findPayloadValue(l.getKey(), payload)));
                return true;
            } catch (IOException e) {
                LOGGER.warn("Error parsing payload: {}", e.getMessage());
                return false;
            }
        }

        /** Returns whether the last non-whitespace byte (per the JSON whitespace set) is {@code '}'}. */
        private static boolean endsWithClosingBrace(byte[] bytes) {
            for (int i = bytes.length - 1; i >= 0; i--) {
                byte b = bytes[i];
                if (b == ' ' || b == '\t' || b == '\r' || b == '\n') {
                    continue;
                }
                return b == '}';
            }
            return false;
        }
    }
}

