/*
 * Decompiled with CFR 0.152.
 */
package com.azure.spring.cloud.autoconfigure.aad.implementation.jwt;

import com.azure.spring.cloud.autoconfigure.aad.implementation.jwt.AadJwtEncoder;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;

public final class AadJwtClientAuthenticationParametersConverter<T extends AbstractOAuth2AuthorizationGrantRequest>
implements Converter<T, MultiValueMap<String, String>> {
    private static final String INVALID_KEY_ERROR_CODE = "invalid_key";
    private static final String INVALID_ALGORITHM_ERROR_CODE = "invalid_algorithm";
    public static final String CLIENT_ASSERTION_TYPE_VALUE = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
    private final Function<ClientRegistration, JWK> jwkResolver;
    private final Map<String, JwsEncoderHolder> jwsEncoders = new ConcurrentHashMap<String, JwsEncoderHolder>();

    public AadJwtClientAuthenticationParametersConverter(Function<ClientRegistration, JWK> jwkResolver) {
        Assert.notNull(jwkResolver, (String)"jwkResolver cannot be null");
        this.jwkResolver = jwkResolver;
    }

    public MultiValueMap<String, String> convert(T authorizationGrantRequest) {
        Assert.notNull(authorizationGrantRequest, (String)"authorizationGrantRequest cannot be null");
        ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
        if (!ClientAuthenticationMethod.PRIVATE_KEY_JWT.equals((Object)clientRegistration.getClientAuthenticationMethod())) {
            return null;
        }
        JWK jwk = this.jwkResolver.apply(clientRegistration);
        if (jwk == null) {
            OAuth2Error oauth2Error = new OAuth2Error(INVALID_KEY_ERROR_CODE, "Failed to resolve JWK signing key for client registration '" + clientRegistration.getRegistrationId() + "'.", null);
            throw new OAuth2AuthorizationException(oauth2Error);
        }
        JwsAlgorithm jwsAlgorithm = AadJwtClientAuthenticationParametersConverter.resolveAlgorithm(jwk);
        if (jwsAlgorithm == null) {
            OAuth2Error oauth2Error = new OAuth2Error(INVALID_ALGORITHM_ERROR_CODE, "Unable to resolve JWS (signing) algorithm from JWK associated to client registration '" + clientRegistration.getRegistrationId() + "'.", null);
            throw new OAuth2AuthorizationException(oauth2Error);
        }
        HashMap<String, Object> jwsHeader = new HashMap<String, Object>();
        jwsHeader.put("typ", "JWT");
        jwsHeader.put("alg", SignatureAlgorithm.RS256.getName());
        jwsHeader.put("x5t", jwk.getX509CertThumbprint().toString());
        HashMap<String, Object> jwtClaimsSet = new HashMap<String, Object>();
        Instant issuedAt = Instant.now();
        Instant expiresAt = issuedAt.plus(Duration.ofSeconds(60L));
        jwtClaimsSet.put("iss", clientRegistration.getClientId());
        jwtClaimsSet.put("sub", clientRegistration.getClientId());
        jwtClaimsSet.put("aud", Collections.singletonList(clientRegistration.getProviderDetails().getTokenUri()));
        jwtClaimsSet.put("jti", UUID.randomUUID().toString());
        jwtClaimsSet.put("iat", issuedAt);
        jwtClaimsSet.put("exp", expiresAt);
        JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(), (clientRegistrationId, currentJwsEncoderHolder) -> {
            if (currentJwsEncoderHolder != null && currentJwsEncoderHolder.getJwk().equals((Object)jwk)) {
                return currentJwsEncoderHolder;
            }
            ImmutableJWKSet jwkSource = new ImmutableJWKSet(new JWKSet(jwk));
            return new JwsEncoderHolder(new AadJwtEncoder((JWKSource<SecurityContext>)jwkSource), jwk);
        });
        AadJwtEncoder jwtEncoder = jwsEncoderHolder.getJwtEncoder();
        Jwt jwt = jwtEncoder.encode(jwsHeader, jwtClaimsSet);
        LinkedMultiValueMap parameters = new LinkedMultiValueMap();
        parameters.set((Object)"client_assertion_type", (Object)CLIENT_ASSERTION_TYPE_VALUE);
        parameters.set((Object)"client_assertion", (Object)jwt.getTokenValue());
        return parameters;
    }

    private static JwsAlgorithm resolveAlgorithm(JWK jwk) {
        SignatureAlgorithm jwsAlgorithm = null;
        if (jwk.getAlgorithm() != null && (jwsAlgorithm = SignatureAlgorithm.from((String)jwk.getAlgorithm().getName())) == null) {
            jwsAlgorithm = MacAlgorithm.from((String)jwk.getAlgorithm().getName());
        }
        if (jwsAlgorithm == null && KeyType.RSA.equals((Object)jwk.getKeyType())) {
            jwsAlgorithm = SignatureAlgorithm.RS256;
        }
        return jwsAlgorithm;
    }

    private static final class JwsEncoderHolder {
        private final AadJwtEncoder jwtEncoder;
        private final JWK jwk;

        private JwsEncoderHolder(AadJwtEncoder jwtEncoder, JWK jwk) {
            this.jwtEncoder = jwtEncoder;
            this.jwk = jwk;
        }

        private AadJwtEncoder getJwtEncoder() {
            return this.jwtEncoder;
        }

        private JWK getJwk() {
            return this.jwk;
        }
    }
}

