/*
 * Decompiled with CFR 0.152.
 */
package com.microsoft.azure.spring.autoconfigure.aad;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.microsoft.azure.spring.autoconfigure.aad.AzureADGraphClient;
import com.microsoft.azure.spring.autoconfigure.aad.JacksonObjectMapperFactory;
import com.microsoft.azure.spring.autoconfigure.aad.ServiceEndpoints;
import com.microsoft.azure.spring.autoconfigure.aad.UserGroup;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import com.nimbusds.jwt.proc.JWTClaimsSetVerifier;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;

public class UserPrincipal {
    private static final Logger LOG = LoggerFactory.getLogger(UserPrincipal.class);
    private ServiceEndpoints serviceEndpoints;
    private JWKSet jwsKeySet;
    private JWSObject jwsObject;
    private JWTClaimsSet jwtClaimsSet;
    private List<UserGroup> userGroups;

    public UserPrincipal() {
        this.jwsObject = null;
        this.jwtClaimsSet = null;
        this.userGroups = null;
        this.serviceEndpoints = new ServiceEndpoints();
    }

    public UserPrincipal(String idToken, ServiceEndpoints serviceEndpoints) throws MalformedURLException, ParseException, BadJOSEException, JOSEException {
        this.serviceEndpoints = serviceEndpoints;
        this.jwsKeySet = this.loadAadPublicKeys();
        ConfigurableJWTProcessor<SecurityContext> validator = this.getAadJwtTokenValidator();
        this.jwtClaimsSet = validator.process(idToken, null);
        JWTClaimsSetVerifier verifier = validator.getJWTClaimsSetVerifier();
        verifier.verify(this.jwtClaimsSet, null);
        this.jwsObject = JWSObject.parse((String)idToken);
        this.userGroups = null;
    }

    private JWKSet loadAadPublicKeys() {
        try {
            return JWKSet.load((URL)new URL(this.serviceEndpoints.getAadKeyDiscoveryUri()));
        }
        catch (IOException | ParseException e) {
            LOG.error("Error loading AAD public keys: {}", (Object)e.getMessage());
            return null;
        }
    }

    public String getIssuer() {
        return this.jwtClaimsSet == null ? null : this.jwtClaimsSet.getIssuer();
    }

    public String getSubject() {
        return this.jwtClaimsSet == null ? null : this.jwtClaimsSet.getSubject();
    }

    public Map<String, Object> getClaims() {
        return this.jwtClaimsSet == null ? null : this.jwtClaimsSet.getClaims();
    }

    public Object getClaim() {
        return this.jwtClaimsSet == null ? null : this.jwtClaimsSet.getClaim("tid");
    }

    public String getKid() {
        return this.jwsObject == null ? null : this.jwsObject.getHeader().getKeyID();
    }

    public JWK getJWKByKid(String kid) {
        return this.jwsKeySet == null ? null : this.jwsKeySet.getKeyByKeyId(kid);
    }

    public List<UserGroup> getGroups(String graphApiToken) throws Exception {
        if (this.userGroups == null) {
            this.userGroups = this.loadUserGroups(graphApiToken);
        }
        return this.userGroups;
    }

    public boolean isMemberOf(UserGroup group) {
        return this.userGroups != null && !this.userGroups.isEmpty() && this.userGroups.contains(group);
    }

    public List<GrantedAuthority> getAuthoritiesByUserGroups(List<UserGroup> userGroups, List<String> targetdGroupNames) {
        if (userGroups == null || targetdGroupNames == null || userGroups.isEmpty() || targetdGroupNames.isEmpty()) {
            return Collections.emptyList();
        }
        return userGroups.stream().filter(usergroup -> targetdGroupNames.contains(usergroup.getDisplayName())).map(usergroup -> "ROLE_" + usergroup.getDisplayName()).map(SimpleGrantedAuthority::new).collect(Collectors.toList());
    }

    public Collection<? extends GrantedAuthority> getAuthorities() {
        return SecurityContextHolder.getContext().getAuthentication().getAuthorities();
    }

    public Authentication getAuthentication() {
        return SecurityContextHolder.getContext().getAuthentication();
    }

    private ConfigurableJWTProcessor<SecurityContext> getAadJwtTokenValidator() throws MalformedURLException {
        DefaultJWTProcessor jwtProcessor = new DefaultJWTProcessor();
        RemoteJWKSet keySource = new RemoteJWKSet(new URL(this.serviceEndpoints.getAadKeyDiscoveryUri()));
        JWSAlgorithm expectedJWSAlg = JWSAlgorithm.RS256;
        JWSVerificationKeySelector keySelector = new JWSVerificationKeySelector(expectedJWSAlg, (JWKSource)keySource);
        jwtProcessor.setJWSKeySelector((JWSKeySelector)keySelector);
        jwtProcessor.setJWTClaimsSetVerifier((JWTClaimsSetVerifier)new DefaultJWTClaimsVerifier<SecurityContext>(){

            public void verify(JWTClaimsSet claimsSet, SecurityContext ctx) throws BadJWTException {
                super.verify(claimsSet, ctx);
                String issuer = claimsSet.getIssuer();
                if (issuer == null || !issuer.contains("https://sts.windows.net/") && !issuer.contains("https://sts.chinacloudapi.cn/")) {
                    throw new BadJWTException("Invalid token issuer");
                }
            }
        });
        return jwtProcessor;
    }

    private List<UserGroup> loadUserGroups(String graphApiToken) throws Exception {
        String responseInJson = AzureADGraphClient.getUserMembershipsV1(graphApiToken, this.serviceEndpoints.getAadMembershipRestUri());
        ArrayList<UserGroup> lUserGroups = new ArrayList<UserGroup>();
        ObjectMapper objectMapper = JacksonObjectMapperFactory.getInstance();
        JsonNode rootNode = (JsonNode)objectMapper.readValue(responseInJson, JsonNode.class);
        JsonNode valuesNode = rootNode.get("value");
        int i = 0;
        while (valuesNode != null && valuesNode.get(i) != null) {
            if (valuesNode.get(i).get("objectType").asText().equals("Group")) {
                lUserGroups.add(new UserGroup(valuesNode.get(i).get("objectId").asText(), valuesNode.get(i).get("displayName").asText()));
            }
            ++i;
        }
        return lUserGroups;
    }
}

