package com.mulesoft.connectors.http.commons.connection.provider;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.mulesoft.connectors.http.commons.connection.ConnectorHttpConnection;
import com.mulesoft.connectors.http.commons.connection.provider.param.ConnectionParameterGroup;
import com.mulesoft.connectors.http.commons.connection.provider.param.jwt.AccessTokenRequestParameterGroup;
import com.mulesoft.connectors.http.commons.connection.provider.param.jwt.JwtClaimsParameterGroup;
import com.mulesoft.connectors.http.commons.connection.provider.param.jwt.JwtHeadersParameterGroup;
import com.mulesoft.extensions.request.builder.RequestBuilder;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.io.JacksonSerializer;
import org.apache.commons.lang3.RandomUtils;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.openssl.PEMReader;
import org.mule.runtime.api.connection.CachedConnectionProvider;
import org.mule.runtime.api.connection.ConnectionException;
import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.extension.api.annotation.param.Parameter;
import org.mule.runtime.extension.api.annotation.param.ParameterGroup;
import org.mule.runtime.extension.api.annotation.param.display.Path;
import org.mule.runtime.extension.api.annotation.param.display.Placement;
import org.mule.runtime.extension.api.annotation.param.display.Summary;
import org.mule.runtime.http.api.client.HttpClient;
import org.mule.runtime.http.api.client.auth.HttpAuthentication;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.security.Security;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.TimeoutException;

import static java.lang.String.format;
import static java.util.stream.Collectors.joining;
import static org.mule.runtime.api.meta.model.display.PathModel.Type.FILE;
import static org.mule.runtime.extension.api.annotation.param.ParameterGroup.CONNECTION;

/**
 * Base connection provider that authenticates by signing a JWT assertion with a PEM-encoded private key,
 * POSTing it (together with the configured token-request parameters) to a token endpoint, and sending the
 * returned {@code access_token} as a {@code Bearer} {@code Authorization} header.
 *
 * <p>Subclasses supply the JWT header values, the JWT claims and the token-request parameters.
 *
 * @param <CONNECTION> the connection type produced by this provider
 */
public abstract class AbstractJWTHttpConnectionProvider<CONNECTION extends ConnectorHttpConnection> extends AbstractHttpConnectionProvider<CONNECTION, ConnectionParameterGroup> implements CachedConnectionProvider<CONNECTION> {

    /** Default lifetime, in seconds, of the assertion when no explicit {@code exp} claim is configured. */
    private static final long DEFAULT_EXPIRATION_SECONDS = 60;

    /** Shared, stateless JSON reader used to parse the token endpoint response. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    @ParameterGroup(name = CONNECTION)
    @Placement(order = 1)
    private ConnectionParameterGroup connectionParams;

    /**
     * The path to a PEM file containing the key pair whose private key is used to sign the JWT claims.
     */
    @Parameter
    @Placement(order = 1)
    @Summary("The path to the keystore used to sign the JWT claims.")
    @Path(type = FILE)
    private String keystorePath;

    /**
     * Builds and signs a JWT assertion, exchanges it for an access token, and returns the resulting
     * {@code Authorization: Bearer <token>} header.
     *
     * @param httpClient client used to call the token endpoint
     * @return a multimap holding the single {@code Authorization} header
     * @throws ConnectionException if the key file is missing or invalid, the signing algorithm is unsupported,
     *                             the token request fails or times out, or the response lacks an access token
     */
    @Override
    protected MultiMap<String, String> getAuthorizationHeaders(HttpClient httpClient) throws ConnectionException {
        ensureBouncyCastleRegistered();
        JwtHeadersParameterGroup headersGroup = getHeadersGroup();
        AccessTokenRequestParameterGroup tokenRequest = getAccessTokenRequestParameters();
        try {
            Map<String, String> requestBody = new HashMap<>();
            Optional.ofNullable(tokenRequest.getParameters()).ifPresent(requestBody::putAll);
            requestBody.put("assertion", Jwts.builder()
                    .signWith(readPrivateKey(), resolveSignatureAlgorithm(headersGroup.getAlg()))
                    .setHeaderParams(buildJwtHeaders(headersGroup))
                    .addClaims(buildJwtClaims(getClaimsGroup()))
                    .serializeToJsonWith(new JacksonSerializer())
                    .compact());
            Map<?, ?> response = MAPPER.readValue(RequestBuilder.post(httpClient, tokenRequest.getUrl())
                    .entity(encodeFormBody(requestBody))
                    .execute(), Map.class);
            MultiMap<String, String> result = new MultiMap<>();
            result.put("Authorization", format("Bearer %s", extractAccessToken(response)));
            return result;
        } catch (FileNotFoundException e) {
            throw new ConnectionException("Unable to find keystore file.", e);
        } catch (TimeoutException e) {
            throw new ConnectionException("Timeout.", e);
        } catch (IOException e) {
            throw new ConnectionException("An error occurred while retrieving the access token.", e);
        }
    }

    /**
     * Returns {@code null}: this provider authenticates through the header built by
     * {@link #getAuthorizationHeaders(HttpClient)}. NOTE(review): presumably the base class treats
     * {@code null} as "no HTTP-level authentication" — confirm.
     */
    @Override
    protected HttpAuthentication getAuthentication(HttpClient httpClient) throws ConnectionException {
        return null;
    }

    @Override
    public ConnectionParameterGroup getConnectionParams() {
        return connectionParams;
    }

    /** @return the JWT header values ({@code alg}, {@code typ}, {@code cty} and custom headers) */
    protected abstract JwtHeadersParameterGroup getHeadersGroup();

    /** @return the JWT claim values (registered and custom claims) */
    protected abstract JwtClaimsParameterGroup getClaimsGroup();

    /** @return the token endpoint URL and the extra form parameters sent alongside the assertion */
    protected abstract AccessTokenRequestParameterGroup getAccessTokenRequestParameters();

    /** Registers the BouncyCastle security provider once; needed to parse the PEM key file. */
    private static void ensureBouncyCastleRegistered() {
        if (Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) == null) {
            Security.addProvider(new BouncyCastleProvider());
        }
    }

    /** Collects the configured JWT header parameters; absent values are omitted. */
    private static Map<String, Object> buildJwtHeaders(JwtHeadersParameterGroup headersGroup) {
        Map<String, Object> headers = new HashMap<>();
        Optional.ofNullable(headersGroup.getAlg()).ifPresent(alg -> headers.put("alg", alg));
        Optional.ofNullable(headersGroup.getTyp()).ifPresent(typ -> headers.put("typ", typ));
        Optional.ofNullable(headersGroup.getCty()).ifPresent(cty -> headers.put("cty", cty));
        Optional.ofNullable(headersGroup.getCustomHeaders()).ifPresent(headers::putAll);
        return headers;
    }

    /**
     * Collects the configured JWT claims. {@code exp} defaults to now + 60 seconds and {@code jti} defaults
     * to a random UUID; other absent claims are omitted. Numeric date claims are parsed as {@code long}
     * epoch seconds so values beyond 2038 do not overflow.
     */
    private static Map<String, Object> buildJwtClaims(JwtClaimsParameterGroup claimsGroup) {
        Map<String, Object> claims = new HashMap<>();
        Optional.ofNullable(claimsGroup.getIss()).ifPresent(iss -> claims.put("iss", iss));
        Optional.ofNullable(claimsGroup.getAud()).ifPresent(aud -> claims.put("aud", aud));
        claims.put("exp", Optional.ofNullable(claimsGroup.getExp())
                .map(Long::valueOf)
                .orElseGet(() -> Instant.now().getEpochSecond() + DEFAULT_EXPIRATION_SECONDS));
        Optional.ofNullable(claimsGroup.getIat()).map(Long::valueOf).ifPresent(iat -> claims.put("iat", iat));
        // UUID.randomUUID() is backed by SecureRandom and yields a printable, collision-resistant JWT ID.
        claims.put("jti", Optional.ofNullable(claimsGroup.getJti()).orElseGet(() -> UUID.randomUUID().toString()));
        Optional.ofNullable(claimsGroup.getNbf()).map(Long::valueOf).ifPresent(nbf -> claims.put("nbf", nbf));
        Optional.ofNullable(claimsGroup.getSub()).ifPresent(sub -> claims.put("sub", sub));
        Optional.ofNullable(claimsGroup.getCustomClaims()).ifPresent(claims::putAll);
        return claims;
    }

    /**
     * Reads the private key from the PEM file at {@link #keystorePath}, closing the file afterwards.
     *
     * @throws FileNotFoundException if the file does not exist
     * @throws IOException           if the file cannot be read or parsed
     * @throws ConnectionException   if the file does not contain a key pair
     */
    private PrivateKey readPrivateKey() throws IOException, ConnectionException {
        try (PEMReader pemReader = new PEMReader(new FileReader(new File(keystorePath)))) {
            Object pemObject = pemReader.readObject();
            if (!(pemObject instanceof KeyPair)) {
                throw new ConnectionException(format("Keystore file '%s' does not contain a PEM-encoded key pair.", keystorePath));
            }
            return ((KeyPair) pemObject).getPrivate();
        }
    }

    /** Maps the {@code alg} header value (e.g. {@code RS256}) to a jjwt {@link SignatureAlgorithm}. */
    private static SignatureAlgorithm resolveSignatureAlgorithm(String alg) throws ConnectionException {
        if (alg == null) {
            throw new ConnectionException("The JWT 'alg' header is required to sign the assertion.");
        }
        try {
            return SignatureAlgorithm.valueOf(alg);
        } catch (IllegalArgumentException e) {
            throw new ConnectionException(format("Unsupported JWT signing algorithm '%s'.", alg), e);
        }
    }

    /** Serializes parameters as an {@code application/x-www-form-urlencoded} body with encoded keys and values. */
    private static String encodeFormBody(Map<String, String> parameters) {
        return parameters.entrySet().stream()
                .map(entry -> format("%s=%s", urlEncode(entry.getKey()), urlEncode(entry.getValue())))
                .collect(joining("&"));
    }

    /** URL-encodes {@code value} as UTF-8; {@code null} is encoded as the empty string. */
    private static String urlEncode(String value) {
        if (value == null) {
            return "";
        }
        try {
            return URLEncoder.encode(value, StandardCharsets.UTF_8.name());
        } catch (UnsupportedEncodingException e) {
            // Every Java platform is required to support UTF-8, so this cannot happen.
            throw new IllegalStateException("UTF-8 encoding is not supported", e);
        }
    }

    /** Extracts the {@code access_token} field from the parsed token endpoint response. */
    private static String extractAccessToken(Map<?, ?> response) throws ConnectionException {
        Object accessToken = response == null ? null : response.get("access_token");
        if (accessToken == null) {
            throw new ConnectionException("The access token response did not contain an 'access_token' field.");
        }
        return accessToken.toString();
    }
}