/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.server.security;

import com.facebook.airlift.http.server.AuthenticationException;
import com.facebook.airlift.http.server.Authenticator;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.server.security.SecurityConfig;
import com.facebook.presto.spi.ClientRequestFilter;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.ByteStreams;
import com.google.common.net.MediaType;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.security.Principal;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import javax.inject.Inject;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;

public class AuthenticationFilter
implements Filter {
    private static final String HTTPS_PROTOCOL = "https";
    private final List<Authenticator> authenticators;
    private final boolean allowForwardedHttps;
    private final ClientRequestFilterManager clientRequestFilterManager;
    private final List<String> headersBlockList = ImmutableList.of((Object)"X-Presto-Transaction-Id", (Object)"X-Presto-Started-Transaction-Id", (Object)"X-Presto-Clear-Transaction-Id", (Object)"X-Presto-Trace-Token");

    @Inject
    public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager) {
        this.authenticators = ImmutableList.copyOf((Collection)Objects.requireNonNull(authenticators, "authenticators is null"));
        this.allowForwardedHttps = Objects.requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps();
        this.clientRequestFilterManager = Objects.requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null");
    }

    public void init(FilterConfig filterConfig) {
    }

    public void destroy() {
    }

    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain nextFilter) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest)servletRequest;
        HttpServletResponse response = (HttpServletResponse)servletResponse;
        if (!this.doesRequestSupportAuthentication(request)) {
            nextFilter.doFilter((ServletRequest)request, (ServletResponse)response);
            return;
        }
        LinkedHashSet<String> messages = new LinkedHashSet<String>();
        LinkedHashSet authenticateHeaders = new LinkedHashSet();
        for (Authenticator authenticator : this.authenticators) {
            Principal principal;
            try {
                principal = authenticator.authenticate(request);
            }
            catch (AuthenticationException e) {
                if (e.getMessage() != null) {
                    messages.add(e.getMessage());
                }
                e.getAuthenticateHeader().ifPresent(authenticateHeaders::add);
                continue;
            }
            HttpServletRequest wrappedRequest = this.mergeExtraHeaders(request, principal);
            nextFilter.doFilter(AuthenticationFilter.withPrincipal(wrappedRequest, principal), (ServletResponse)response);
            return;
        }
        AuthenticationFilter.skipRequestBody(request);
        for (String value : authenticateHeaders) {
            response.addHeader("WWW-Authenticate", value);
        }
        if (messages.isEmpty()) {
            messages.add("Unauthorized");
        }
        String error = Joiner.on((String)" | ").join(messages);
        response.setStatus(401, error);
        response.setContentType(MediaType.PLAIN_TEXT_UTF_8.toString());
        try (PrintWriter writer = response.getWriter();){
            writer.write(error);
        }
    }

    public HttpServletRequest mergeExtraHeaders(HttpServletRequest request, Principal principal) {
        List clientRequestFilters = this.clientRequestFilterManager.getClientRequestFilters();
        if (clientRequestFilters.isEmpty()) {
            return request;
        }
        ImmutableMap.Builder extraHeadersMapBuilder = ImmutableMap.builder();
        HashSet<String> addedHeaders = new HashSet<String>();
        for (ClientRequestFilter requestFilter : clientRequestFilters) {
            Map extraHeaderValueMap;
            boolean headersPresent = requestFilter.getExtraHeaderKeys().stream().allMatch(headerName -> request.getHeader(headerName) != null);
            if (headersPresent || (extraHeaderValueMap = requestFilter.getExtraHeaders(principal)).isEmpty()) continue;
            for (Map.Entry extraHeaderEntry : extraHeaderValueMap.entrySet()) {
                String headerKey = (String)extraHeaderEntry.getKey();
                if (this.headersBlockList.contains(headerKey)) {
                    throw new PrestoException((ErrorCodeSupplier)StandardErrorCode.HEADER_MODIFICATION_ATTEMPT, "Modification attempt detected: The header " + headerKey + " is not allowed to be modified. The following headers cannot be modified: " + String.join((CharSequence)", ", this.headersBlockList));
                }
                if (addedHeaders.contains(headerKey)) {
                    throw new PrestoException((ErrorCodeSupplier)StandardErrorCode.HEADER_MODIFICATION_ATTEMPT, "Header conflict detected: " + headerKey + " already added by another filter.");
                }
                if (request.getHeader(headerKey) != null || !requestFilter.getExtraHeaderKeys().contains(headerKey)) continue;
                extraHeadersMapBuilder.put((Object)headerKey, (Object)((String)extraHeaderEntry.getValue()));
                addedHeaders.add(headerKey);
            }
        }
        return new ModifiedHttpServletRequest(request, (Map<String, String>)extraHeadersMapBuilder.build());
    }

    private boolean doesRequestSupportAuthentication(HttpServletRequest request) {
        if (this.authenticators.isEmpty()) {
            return false;
        }
        if (request.isSecure()) {
            return true;
        }
        if (this.allowForwardedHttps) {
            return Strings.nullToEmpty((String)request.getHeader("X-Forwarded-Proto")).equalsIgnoreCase(HTTPS_PROTOCOL);
        }
        return false;
    }

    private static ServletRequest withPrincipal(HttpServletRequest request, final Principal principal) {
        Objects.requireNonNull(principal, "principal is null");
        return new HttpServletRequestWrapper(request){

            public Principal getUserPrincipal() {
                return principal;
            }
        };
    }

    private static void skipRequestBody(HttpServletRequest request) throws IOException {
        try (ServletInputStream inputStream = request.getInputStream();){
            ByteStreams.copy((InputStream)inputStream, (OutputStream)ByteStreams.nullOutputStream());
        }
    }

    public static class ModifiedHttpServletRequest
    extends HttpServletRequestWrapper {
        private final Map<String, String> customHeaders;

        public ModifiedHttpServletRequest(HttpServletRequest request, Map<String, String> headers) {
            super(request);
            this.customHeaders = ImmutableMap.copyOf(Objects.requireNonNull(headers, "headers is null"));
        }

        public String getHeader(String name) {
            if (this.customHeaders.containsKey(name)) {
                return this.customHeaders.get(name);
            }
            return super.getHeader(name);
        }

        public Enumeration<String> getHeaderNames() {
            return Collections.enumeration(ImmutableSet.builder().addAll(this.customHeaders.keySet()).addAll(Collections.list(super.getHeaderNames())).build());
        }

        public Enumeration<String> getHeaders(String name) {
            if (this.customHeaders.containsKey(name)) {
                return Collections.enumeration(ImmutableList.of((Object)this.customHeaders.get(name)));
            }
            return super.getHeaders(name);
        }
    }
}

