/*
 * (c) 2003-2020 MuleSoft, Inc. This software is protected under international copyright
 * law. All use of this software is subject to MuleSoft's Master Subscription Agreement
 * (or other master license agreement) separately entered into in writing between you and
 * MuleSoft. If such an agreement is not in place, you may not use the software.
 */
package com.mulesoft.modules.oauth2.provider.api.ratelimit;

import static com.mulesoft.modules.oauth2.provider.api.ratelimit.RateLimiter.Outcome.SUCCESS;
import static java.time.Instant.now;
import static org.mule.runtime.api.meta.ExpressionSupport.NOT_SUPPORTED;
import org.mule.runtime.api.util.LazyValue;
import org.mule.runtime.extension.api.annotation.Expression;
import org.mule.runtime.extension.api.annotation.param.Optional;
import org.mule.runtime.extension.api.annotation.param.Parameter;

import com.mulesoft.modules.oauth2.provider.internal.ratelimit.RateLimitExceededException;

import java.time.Instant;
import java.time.Duration;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;

import org.apache.commons.lang3.tuple.Pair;

/**
 * Rate limiter that resets according to a defined time period.
 * <p>
 * If the failure count is reached, the rate limiter rejects operations until the period is completed, and resets afterwards.
 *
 * @since 1.0.0
 */
public class PeriodRateLimiter implements RateLimiter {

  /**
   * Length of the rate-limiting period, expressed in {@link #durationTimeUnit}.
   */
  @Parameter
  @Expression(NOT_SUPPORTED)
  @Optional(defaultValue = "600")
  private int duration;

  /**
   * Time unit in which {@link #duration} is expressed.
   */
  @Parameter
  @Expression(NOT_SUPPORTED)
  @Optional(defaultValue = "SECONDS")
  private TimeUnit durationTimeUnit;

  /**
   * Number of failures within a period after which further operations are rejected until the period ends.
   */
  @Parameter
  @Expression(NOT_SUPPORTED)
  @Optional(defaultValue = "5")
  private int maximumFailureCount;

  // Evaluated lazily so it reflects the configured duration/unit rather than the field values present at construction.
  private LazyValue<Duration> resetPeriod = newResetPeriod();

  /**
   * Failure tracking per operation/context key.
   * <p>
   * Left: the instant at which the current period ends. Right: the number of failures recorded within that period.
   * Expired entries are only removed when the same key is checked again or a success is recorded for it.
   */
  private final ConcurrentMap<String, Pair<Instant, Integer>> cache = new ConcurrentHashMap<>();


  /**
   * Rejects the operation if {@link #maximumFailureCount} failures have already been recorded for it within the current
   * period. If the tracked period has ended, its history is discarded and the operation is allowed.
   *
   * @param operation the operation being attempted
   * @param context the context that, together with the operation, identifies the tracked entry
   * @throws RateLimitExceededException if the maximum failure count was reached within the current period
   */
  @Override
  public void checkOperationAuthorized(final Operation operation, final String context)
      throws RateLimitExceededException {
    final String cacheKey = getCacheKey(operation, context);

    final Pair<Instant, Integer> entry = cache.get(cacheKey);

    if (entry == null) {
      return;
    }

    if (isExpired(entry)) {
      // Conditional remove: drop only the entry inspected here, never one recorded concurrently in the meantime.
      cache.remove(cacheKey, entry);
      return;
    }

    if (entry.getRight() < maximumFailureCount) {
      return;
    }

    throw new RateLimitExceededException("Maximum of " + maximumFailureCount + " failed attempts reached");
  }

  /**
   * Records the outcome of an operation.
   * <p>
   * A {@code SUCCESS} clears the failure history for the operation/context. Any other outcome counts as a failure. The first
   * failure, or the first one after the previous period has ended, starts a new period.
   *
   * @param operation the operation that was attempted
   * @param context the context that, together with the operation, identifies the tracked entry
   * @param outcome the outcome of the attempt
   */
  @Override
  public void recordOperationOutcome(final Operation operation, final String context, final Outcome outcome) {
    final String cacheKey = getCacheKey(operation, context);

    if (outcome == SUCCESS) {
      cache.remove(cacheKey);
      return;
    }

    // compute() makes the read-modify-write atomic, so concurrent failures are never lost.
    cache.compute(cacheKey, (key, current) -> current == null || isExpired(current)
        ? Pair.of(now().plus(resetPeriod.get()), 1)
        : Pair.of(current.getLeft(), current.getRight() + 1));
  }

  /**
   * Returns whether the period tracked by {@code entry} has already ended.
   */
  private static boolean isExpired(final Pair<Instant, Integer> entry) {
    return entry.getLeft().isBefore(now());
  }

  /**
   * Creates a lazily evaluated period length from the current {@link #duration} and {@link #durationTimeUnit}.
   * <p>
   * Converting through nanoseconds has two effects. Sub-second units are not truncated to zero, and large durations
   * saturate at {@link Long#MAX_VALUE} nanoseconds (about 292 years) instead of overflowing an {@code int}.
   */
  private LazyValue<Duration> newResetPeriod() {
    return new LazyValue<>(() -> Duration.ofNanos(durationTimeUnit.toNanos(duration)));
  }

  /**
   * Builds the tracking key for an operation/context pair.
   */
  private String getCacheKey(final Operation operation, final String context) {
    return operation.toString() + "|" + context;
  }

  public int getMaximumFailureCount() {
    return maximumFailureCount;
  }

  public void setMaximumFailureCount(final int maximumFailureCount) {
    this.maximumFailureCount = maximumFailureCount;
  }

  /**
   * Sets the period length.
   * <p>
   * Entries already being tracked keep the period end that was computed when they were created.
   *
   * @param duration the period length, in {@code timeUnit}
   * @param timeUnit the unit of {@code duration}
   */
  public void setDuration(int duration, TimeUnit timeUnit) {
    this.duration = duration;
    this.durationTimeUnit = timeUnit;
    resetPeriod = newResetPeriod();
  }

}
