/*
 * Decompiled with CFR 0.152.
 */
package com.linecorp.armeria.server.cors;

import com.linecorp.armeria.common.FilteredHttpResponse;
import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpHeaders;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpObject;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.HttpStatusClass;
import com.linecorp.armeria.internal.shaded.guava.base.Ascii;
import com.linecorp.armeria.internal.shaded.guava.base.Joiner;
import com.linecorp.armeria.server.Service;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.SimpleDecoratingService;
import com.linecorp.armeria.server.cors.CorsConfig;
import io.netty.util.AsciiString;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Decorates an HTTP {@link Service} with Cross-Origin Resource Sharing (CORS) handling driven by a
 * {@link CorsConfig}.
 *
 * <p>When the configuration is enabled:
 * <ul>
 *   <li>an {@code OPTIONS} request carrying both an {@code Origin} and an
 *       {@code Access-Control-Request-Method} header is answered here as a preflight request,
 *       without invoking the delegate;</li>
 *   <li>if short-circuiting is configured, a request whose {@code Origin} is not permitted is
 *       answered with {@code 403 Forbidden}, without invoking the delegate;</li>
 *   <li>otherwise the delegate's final (non-informational) response headers are decorated with the
 *       CORS response headers that apply to the request's origin.</li>
 * </ul>
 */
public final class CorsService
extends SimpleDecoratingService<HttpRequest, HttpResponse> {
    private static final Logger logger = LoggerFactory.getLogger(CorsService.class);

    /** Wildcard value for {@code Access-Control-Allow-Origin}. */
    private static final String ANY_ORIGIN = "*";
    /** The literal {@code "null"} origin value. */
    private static final String NULL_ORIGIN = "null";
    /** Separator used when a header value lists several items. */
    private static final String DELIMITER = ",";
    private static final Joiner HEADER_JOINER = Joiner.on(DELIMITER);

    private final CorsConfig config;

    /**
     * Creates a new instance that decorates {@code delegate} using {@code config}.
     *
     * @param delegate the service to decorate
     * @param config the CORS configuration to apply
     * @throws NullPointerException if {@code config} is {@code null}
     */
    public CorsService(Service<HttpRequest, HttpResponse> delegate, CorsConfig config) {
        super(delegate);
        this.config = Objects.requireNonNull(config, "config");
    }

    /**
     * Returns the {@link CorsConfig} this service was created with.
     */
    public CorsConfig config() {
        return config;
    }

    @Override
    public HttpResponse serve(ServiceRequestContext ctx, final HttpRequest req) throws Exception {
        if (config.isEnabled()) {
            if (isCorsPreflightRequest(req)) {
                return handleCorsPreflight(req);
            }
            if (config.isShortCircuit() && !validateCorsOrigin(req)) {
                return forbidden();
            }
        }
        return new FilteredHttpResponse(delegate().serve(ctx, req)) {

            @Override
            protected HttpObject filter(HttpObject obj) {
                if (!(obj instanceof HttpHeaders)) {
                    return obj;
                }
                final HttpHeaders headers = (HttpHeaders) obj;
                final HttpStatus status = headers.status();
                // Only headers with a final (non-1xx) status are decorated; status-less headers
                // (presumably trailers) and informational headers pass through unchanged.
                if (status == null || status.codeClass() == HttpStatusClass.INFORMATIONAL) {
                    return headers;
                }
                CorsService.this.setCorsResponseHeaders(req, headers);
                return headers;
            }
        };
    }

    /**
     * Returns {@code true} if {@code request} is an {@code OPTIONS} request that carries both an
     * {@code Origin} and an {@code Access-Control-Request-Method} header.
     */
    private static boolean isCorsPreflightRequest(HttpRequest request) {
        return request.method() == HttpMethod.OPTIONS &&
               request.headers().contains(HttpHeaderNames.ORIGIN) &&
               request.headers().contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD);
    }

    /**
     * Builds the {@code 200 OK} response for a preflight request. The CORS headers are added only
     * when the request's origin is accepted; otherwise a bare {@code 200 OK} is returned.
     */
    private HttpResponse handleCorsPreflight(HttpRequest req) {
        final HttpHeaders headers = HttpHeaders.of(HttpStatus.OK);
        if (setCorsOrigin(req, headers)) {
            setCorsAllowMethods(headers);
            setCorsAllowHeaders(headers);
            setCorsAllowCredentials(headers);
            setCorsMaxAge(headers);
            setPreflightHeaders(headers);
        }
        return HttpResponse.of(headers);
    }

    /**
     * Appends the extra preflight response headers configured in {@link CorsConfig}.
     */
    private void setPreflightHeaders(HttpHeaders headers) {
        for (Map.Entry<AsciiString, String> entry : config.preflightResponseHeaders()) {
            headers.add(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Adds the CORS headers for an actual (non-preflight) response, if the request's origin is
     * accepted.
     */
    private void setCorsResponseHeaders(HttpRequest req, HttpHeaders headers) {
        if (setCorsOrigin(req, headers)) {
            setCorsAllowCredentials(headers);
            setCorsAllowHeaders(headers);
            setCorsExposeHeaders(headers);
        }
    }

    private static HttpResponse forbidden() {
        return HttpResponse.of(HttpStatus.FORBIDDEN);
    }

    /**
     * Returns {@code true} if the request's origin is permitted. A request without an
     * {@code Origin} header is always accepted.
     */
    private boolean validateCorsOrigin(HttpRequest request) {
        if (config.isAnyOriginSupported()) {
            return true;
        }
        final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
        return origin == null ||
               NULL_ORIGIN.equals(origin) && config.isNullOriginAllowed() ||
               config.origins().contains(Ascii.toLowerCase(origin));
    }

    /**
     * Sets {@code Access-Control-Allow-Origin} (and {@code Vary: origin} where the value depends on
     * the request) according to the request's {@code Origin} header.
     *
     * @return {@code true} if an {@code Access-Control-Allow-Origin} header was set, {@code false}
     *         if CORS is disabled, the request has no {@code Origin}, or the origin is not accepted
     */
    private boolean setCorsOrigin(HttpRequest request, HttpHeaders headers) {
        if (!config.isEnabled()) {
            return false;
        }
        final String origin = request.headers().get(HttpHeaderNames.ORIGIN);
        if (origin == null) {
            return false;
        }
        if (NULL_ORIGIN.equals(origin) && config.isNullOriginAllowed()) {
            setCorsNullOrigin(headers);
            return true;
        }
        if (config.isAnyOriginSupported()) {
            if (config.isCredentialsAllowed()) {
                // With credentials enabled, echo the concrete origin instead of the wildcard
                // (see setCorsAllowCredentials), which makes the response vary by origin.
                echoCorsRequestOrigin(request, headers);
                setCorsVaryHeader(headers);
            } else {
                setCorsAnyOrigin(headers);
            }
            return true;
        }
        if (config.origins().contains(Ascii.toLowerCase(origin))) {
            setCorsOrigin(headers, origin);
            setCorsVaryHeader(headers);
            return true;
        }
        logger.debug("Request origin [{}] was not among the configured origins [{}]",
                     origin, config.origins());
        return false;
    }

    private static void setCorsOrigin(HttpHeaders headers, String origin) {
        headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin);
    }

    /** Copies the request's {@code Origin} value into {@code Access-Control-Allow-Origin}. */
    private static void echoCorsRequestOrigin(HttpRequest request, HttpHeaders headers) {
        setCorsOrigin(headers, request.headers().get(HttpHeaderNames.ORIGIN));
    }

    /**
     * Appends {@code origin} to the {@code Vary} header. {@code add} is used rather than {@code set}
     * so that a {@code Vary} value already produced by the delegate is not discarded.
     */
    private static void setCorsVaryHeader(HttpHeaders headers) {
        headers.add(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN.toString());
    }

    private static void setCorsAnyOrigin(HttpHeaders headers) {
        setCorsOrigin(headers, ANY_ORIGIN);
    }

    private static void setCorsNullOrigin(HttpHeaders headers) {
        setCorsOrigin(headers, NULL_ORIGIN);
    }

    /**
     * Sets {@code Access-Control-Allow-Credentials: true} when credentials are allowed, unless the
     * wildcard origin was sent. The CORS protocol does not permit credentials together with
     * {@code Access-Control-Allow-Origin: *}.
     */
    private void setCorsAllowCredentials(HttpHeaders headers) {
        // Null-safe comparison: never dereference the header value directly.
        if (config.isCredentialsAllowed() &&
            !ANY_ORIGIN.equals(headers.get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN))) {
            headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
        }
    }

    /** Sets {@code Access-Control-Expose-Headers} if any exposed headers are configured. */
    private void setCorsExposeHeaders(HttpHeaders headers) {
        final Set<AsciiString> exposedHeaders = config.exposedHeaders();
        if (exposedHeaders.isEmpty()) {
            return;
        }
        headers.set(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS, HEADER_JOINER.join(exposedHeaders));
    }

    /** Sets {@code Access-Control-Allow-Methods} to the comma-joined allowed method names. */
    private void setCorsAllowMethods(HttpHeaders headers) {
        final String methods = config.allowedRequestMethods()
                                     .stream()
                                     .map(Enum::name)
                                     .collect(Collectors.joining(DELIMITER));
        headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, methods);
    }

    /** Sets {@code Access-Control-Allow-Headers} if any allowed request headers are configured. */
    private void setCorsAllowHeaders(HttpHeaders headers) {
        final Set<AsciiString> allowedHeaders = config.allowedRequestHeaders();
        if (allowedHeaders.isEmpty()) {
            return;
        }
        headers.set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, HEADER_JOINER.join(allowedHeaders));
    }

    /** Sets {@code Access-Control-Max-Age} from the configured max age. */
    private void setCorsMaxAge(HttpHeaders headers) {
        headers.setLong(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge());
    }
}

