/*
 * Decompiled with CFR 0.152.
 */
package io.spiffe.workloadapi;

import io.spiffe.bundle.jwtbundle.JwtBundle;
import io.spiffe.bundle.jwtbundle.JwtBundleSet;
import io.spiffe.exception.BundleNotFoundException;
import io.spiffe.exception.JwtSourceException;
import io.spiffe.exception.JwtSvidException;
import io.spiffe.exception.SocketEndpointAddressException;
import io.spiffe.exception.WatcherException;
import io.spiffe.spiffeid.SpiffeId;
import io.spiffe.spiffeid.TrustDomain;
import io.spiffe.svid.jwtsvid.JwtSvid;
import io.spiffe.workloadapi.DefaultWorkloadApiClient;
import io.spiffe.workloadapi.JwtSource;
import io.spiffe.workloadapi.JwtSourceOptions;
import io.spiffe.workloadapi.Watcher;
import io.spiffe.workloadapi.WorkloadApiClient;
import io.spiffe.workloadapi.internal.ThreadUtils;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import java.util.logging.Logger;
import lombok.Generated;
import lombok.NonNull;
import org.apache.commons.lang3.tuple.ImmutablePair;

/**
 * A {@link JwtSource} backed by the SPIFFE Workload API.
 * <p>
 * JWT bundles are kept current by a watcher registered on the {@link WorkloadApiClient} at creation time.
 * JWT-SVIDs are fetched on demand and cached per (subject, audience set); a cached entry is re-fetched once
 * its first token has passed half of its lifetime.
 * <p>
 * After {@link #close()} has been called, every fetch/lookup method throws {@link IllegalStateException}.
 */
public class CachedJwtSource implements JwtSource {
    @Generated
    private static final Logger log = Logger.getLogger(CachedJwtSource.class.getName());

    /** System property that may hold the default init timeout as an ISO-8601 duration (e.g. {@code PT30S}). */
    static final String TIMEOUT_SYSTEM_PROPERTY = "spiffe.newJwtSource.timeout";

    /**
     * Init timeout used when none is configured. Read from {@link #TIMEOUT_SYSTEM_PROPERTY}, defaulting to
     * {@code PT0S}; a zero timeout makes initialization wait for the first bundle update with no deadline.
     */
    static final Duration DEFAULT_TIMEOUT = Duration.parse(System.getProperty(TIMEOUT_SYSTEM_PROPERTY, "PT0S"));

    /** Cached JWT-SVIDs keyed by (subject, or {@code null} when no subject was requested; audience set). */
    private final Map<ImmutablePair<SpiffeId, Set<String>>, List<JwtSvid>> jwtSvids = new ConcurrentHashMap<>();

    /**
     * Latest bundle set delivered by the watcher. Volatile so that updates pushed by the watcher are visible
     * to readers such as {@link #getBundleForTrustDomain}, which do not take any lock.
     */
    private volatile JwtBundleSet bundles;

    private final WorkloadApiClient workloadApiClient;

    private volatile boolean closed;

    /** Clock used for SVID freshness checks; replaceable via {@link #setClock(Clock)}. */
    private Clock clock = Clock.systemDefaultZone();

    private CachedJwtSource(WorkloadApiClient workloadApiClient) {
        this.workloadApiClient = workloadApiClient;
    }

    /**
     * Creates a source using default options and {@link #DEFAULT_TIMEOUT} as the init timeout.
     *
     * @return a source whose JWT bundles have been received at least once
     * @throws JwtSourceException if the source cannot be initialized
     * @throws SocketEndpointAddressException if the Workload API socket address is invalid
     */
    public static JwtSource newSource() throws JwtSourceException, SocketEndpointAddressException {
        JwtSourceOptions options = JwtSourceOptions.builder().initTimeout(DEFAULT_TIMEOUT).build();
        return newSource(options);
    }

    /**
     * Creates a source from the given options and blocks until the first JWT bundle update arrives,
     * or until the init timeout elapses.
     * <p>
     * Missing options are filled in on the passed object itself. A Workload API client is created from
     * the socket path when none is set, and the init timeout defaults to {@link #DEFAULT_TIMEOUT}.
     *
     * @param options source options; must not be {@code null}
     * @return an initialized source
     * @throws JwtSourceException if initialization fails or times out; the source is closed first
     * @throws SocketEndpointAddressException if a client has to be created and the socket address is invalid
     */
    public static JwtSource newSource(@NonNull JwtSourceOptions options) throws SocketEndpointAddressException, JwtSourceException {
        if (options == null) {
            throw new NullPointerException("options is marked non-null but is null");
        }
        if (options.getWorkloadApiClient() == null) {
            options.setWorkloadApiClient(createClient(options));
        }
        if (options.getInitTimeout() == null) {
            options.setInitTimeout(DEFAULT_TIMEOUT);
        }
        CachedJwtSource jwtSource = new CachedJwtSource(options.getWorkloadApiClient());
        try {
            jwtSource.init(options.getInitTimeout());
        } catch (Exception e) {
            jwtSource.close();
            throw new JwtSourceException("Error creating JWT source", e);
        }
        return jwtSource;
    }

    @Override
    public JwtSvid fetchJwtSvid(String audience, String... extraAudiences) throws JwtSvidException {
        ensureOpen("JWT SVID source is closed");
        return getJwtSvids(audience, extraAudiences).get(0);
    }

    @Override
    public JwtSvid fetchJwtSvid(SpiffeId subject, String audience, String... extraAudiences) throws JwtSvidException {
        ensureOpen("JWT SVID source is closed");
        return getJwtSvids(subject, audience, extraAudiences).get(0);
    }

    @Override
    public List<JwtSvid> fetchJwtSvids(String audience, String... extraAudiences) throws JwtSvidException {
        ensureOpen("JWT SVID source is closed");
        return getJwtSvids(audience, extraAudiences);
    }

    @Override
    public List<JwtSvid> fetchJwtSvids(SpiffeId subject, String audience, String... extraAudiences) throws JwtSvidException {
        ensureOpen("JWT SVID source is closed");
        return getJwtSvids(subject, audience, extraAudiences);
    }

    /**
     * Looks up the JWT bundle for a trust domain in the most recently received bundle set.
     *
     * @throws NullPointerException if {@code trustDomain} is {@code null}
     * @throws IllegalStateException if this source has been closed
     * @throws BundleNotFoundException if the bundle set has no bundle for the trust domain
     */
    @Override
    public JwtBundle getBundleForTrustDomain(@NonNull TrustDomain trustDomain) throws BundleNotFoundException {
        if (trustDomain == null) {
            throw new NullPointerException("trustDomain is marked non-null but is null");
        }
        ensureOpen("JWT bundle source is closed");
        return bundles.getBundleForTrustDomain(trustDomain);
    }

    /**
     * Closes the underlying Workload API client once. Later calls are no-ops.
     */
    @Override
    public void close() {
        if (closed) {
            return;
        }
        synchronized (this) {
            if (!closed) {
                workloadApiClient.close();
                closed = true;
            }
        }
    }

    /** Throws {@link IllegalStateException} with {@code message} if this source has been closed. */
    private void ensureOpen(String message) {
        if (closed) {
            throw new IllegalStateException(message);
        }
    }

    /**
     * Returns the cached SVIDs for (subject, audiences), fetching and caching fresh ones when the entry is
     * missing, empty, or past half its lifetime. Fetches are serialized on this instance, and the cache is
     * re-checked under the lock so that concurrent misses trigger a single fetch.
     */
    private List<JwtSvid> getJwtSvids(SpiffeId subject, String audience, String... extraAudiences) throws JwtSvidException {
        ImmutablePair<SpiffeId, Set<String>> cacheKey = ImmutablePair.of(subject, getAudienceSet(audience, extraAudiences));
        List<JwtSvid> svidList = jwtSvids.get(cacheKey);
        if (isUsable(svidList)) {
            return svidList;
        }
        synchronized (this) {
            svidList = jwtSvids.get(cacheKey);
            if (isUsable(svidList)) {
                return svidList;
            }
            svidList = subject == null
                    ? workloadApiClient.fetchJwtSvids(audience, extraAudiences)
                    : workloadApiClient.fetchJwtSvids(subject, audience, extraAudiences);
            jwtSvids.put(cacheKey, svidList);
            return svidList;
        }
    }

    /** Variant of {@link #getJwtSvids(SpiffeId, String, String...)} that requests no specific subject. */
    private List<JwtSvid> getJwtSvids(String audience, String... extraAudiences) throws JwtSvidException {
        return getJwtSvids((SpiffeId) null, audience, extraAudiences);
    }

    /** True when a cached list exists, is non-empty (guards {@code get(0)}), and its first token is still fresh. */
    private boolean isUsable(List<JwtSvid> svidList) {
        return svidList != null && !svidList.isEmpty() && !isTokenPastHalfLifetime(svidList.get(0));
    }

    /** Builds the order-insensitive audience set used as part of the cache key. */
    private static Set<String> getAudienceSet(String audience, String[] extraAudiences) {
        if (extraAudiences == null || extraAudiences.length == 0) {
            return Collections.singleton(audience);
        }
        Set<String> audiences = new HashSet<String>(Arrays.asList(extraAudiences));
        audiences.add(audience);
        return audiences;
    }

    /** Returns true when {@link #clock} is past the midpoint between the token's issued-at and expiry times. */
    private boolean isTokenPastHalfLifetime(JwtSvid jwtSvid) {
        long expiryMillis = jwtSvid.getExpiry().getTime();
        long issuedAtMillis = jwtSvid.getIssuedAt().getTime();
        Instant halfLife = Instant.ofEpochMilli(expiryMillis - (expiryMillis - issuedAtMillis) / 2L);
        return clock.instant().isAfter(halfLife);
    }

    /**
     * Registers the bundle watcher and waits for its first callback.
     *
     * @param timeout how long to wait; zero waits with no deadline
     * @throws TimeoutException if no callback arrives within {@code timeout}
     * @throws IllegalStateException if the watcher signalled completion without delivering a bundle set
     *         (e.g. its {@code onError} ran on another thread)
     */
    private void init(Duration timeout) throws TimeoutException {
        CountDownLatch done = new CountDownLatch(1);
        setJwtBundlesWatcher(done);
        boolean success;
        if (timeout.isZero()) {
            ThreadUtils.await(done);
            success = true;
        } else {
            // Wait in nanoseconds: getSeconds() would truncate sub-second timeouts (e.g. PT0.5S) to zero.
            success = ThreadUtils.await(done, toNanosSaturated(timeout), TimeUnit.NANOSECONDS);
        }
        if (!success) {
            throw new TimeoutException("Timeout waiting for JWT bundles update");
        }
        if (bundles == null) {
            // onError also counts down the latch; without this check the source would NPE on first lookup.
            throw new IllegalStateException("JWT bundles watcher completed without delivering a JwtBundleSet");
        }
    }

    /** Converts {@code d} to nanoseconds, clamping to the long range instead of overflowing. */
    private static long toNanosSaturated(Duration d) {
        try {
            return d.toNanos();
        } catch (ArithmeticException overflow) {
            return d.isNegative() ? Long.MIN_VALUE : Long.MAX_VALUE;
        }
    }

    /**
     * Subscribes to JWT bundle updates. Every update replaces {@link #bundles}. The latch is released on
     * the first update or on an error; errors are then rethrown as {@link WatcherException}.
     */
    private void setJwtBundlesWatcher(final CountDownLatch done) {
        workloadApiClient.watchJwtBundles(new Watcher<JwtBundleSet>() {

            @Override
            public void onUpdate(JwtBundleSet update) {
                log.log(Level.INFO, "Received JwtBundleSet update");
                setJwtBundleSet(update);
                done.countDown();
            }

            @Override
            public void onError(Throwable error) {
                log.log(Level.SEVERE, "Error in JwtBundleSet watcher", error);
                done.countDown();
                throw new WatcherException("Error fetching JwtBundleSet", error);
            }
        });
    }

    /** Publishes a new bundle set; the volatile write makes it visible to lock-free readers. */
    private void setJwtBundleSet(JwtBundleSet update) {
        bundles = update;
    }

    /** Returns whether {@link #close()} has completed. */
    private boolean isClosed() {
        return closed;
    }

    private static WorkloadApiClient createClient(JwtSourceOptions options) throws SocketEndpointAddressException {
        DefaultWorkloadApiClient.ClientOptions clientOptions = DefaultWorkloadApiClient.ClientOptions.builder()
                .spiffeSocketPath(options.getSpiffeSocketPath())
                .build();
        return DefaultWorkloadApiClient.newClient(clientOptions);
    }

    /** Replaces the clock used for freshness checks (package-private, presumably for tests). */
    void setClock(Clock clock) {
        this.clock = clock;
    }
}

