/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.services.managers;

import java.util.ArrayList;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import org.jboss.logging.Logger;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UsernameLoginFailureModel;
import org.keycloak.services.ClientConnection;

/**
 * Records failed logins per username and temporarily blocks further attempts.
 *
 * <p>Callers enqueue events via {@link #failedLogin}. A single background thread, started with
 * {@link #start()}, drains the queue in batches of up to {@code TRANSACTION_SIZE + 1} events and
 * applies them inside one session transaction per batch.
 *
 * <p>{@link #failure} and {@link #logFailure} are expected to run only on that worker thread. The
 * statistics counters are {@code volatile} for visibility, not atomicity.
 */
public class BruteForceProtector implements Runnable {
    protected static Logger logger = Logger.getLogger(BruteForceProtector.class);
    /** Cleared by {@link #shutdown()} to stop the worker loop. */
    protected volatile boolean run = true;
    /** Gap (in seconds, 43200 = 12h) after which the global {@link #totalTime} statistic resets. */
    protected int maxDeltaTimeSeconds = 43200;
    protected KeycloakSessionFactory factory;
    /** Released once when the worker thread leaves {@link #run()}. */
    protected CountDownLatch shutdownLatch = new CountDownLatch(1);
    protected volatile long failures;
    protected volatile long lastFailure;
    protected volatile long totalTime;
    protected LinkedBlockingQueue<LoginEvent> queue = new LinkedBlockingQueue<LoginEvent>();
    /** Maximum number of extra events drained into a batch after the first one is taken. */
    public static final int TRANSACTION_SIZE = 20;

    public BruteForceProtector(KeycloakSessionFactory factory) {
        this.factory = factory;
    }

    /**
     * Applies one failed login to the persisted per-user failure record and, if warranted,
     * sets a "not before" time that temporarily disables further logins.
     *
     * @param session session whose transaction the caller manages
     * @param event the failed login to record
     */
    public void failure(KeycloakSession session, LoginEvent event) {
        logger.debug("failure");
        RealmModel realm = this.getRealmModel(session, event);
        this.logFailure(event);
        if (realm == null) {
            // The realm may no longer exist by the time the batch runs; dereferencing it would
            // throw and roll back every other event in the same batch.
            logger.debugv("realm not found, ignoring failure: {0}", (Object) event.realmId);
            return;
        }
        UsernameLoginFailureModel user = this.getUserModel(session, event);
        if (user == null) {
            user = realm.addUserLoginFailure(event.username);
        }
        user.setLastIPFailure(event.ip);
        long currentTime = System.currentTimeMillis();
        long last = user.getLastFailure();
        long deltaTime = last > 0L ? currentTime - last : 0L;
        user.setLastFailure(currentTime);
        // Failures that are far enough apart start a fresh count.
        if (deltaTime > 0L && deltaTime > (long) realm.getMaxDeltaTimeSeconds() * 1000L) {
            user.clearFailures();
        }
        user.incrementFailures();
        logger.debugv("new num failures: {0}", (Object) user.getNumFailures());

        // Guard against a zero failure factor, which would otherwise throw ArithmeticException.
        int failureFactor = realm.getFailureFactor();
        int waitSeconds = failureFactor > 0
                ? realm.getWaitIncrementSeconds() * (user.getNumFailures() / failureFactor)
                : 0;
        logger.debugv("waitSeconds: {0}", (Object) waitSeconds);
        logger.debugv("deltaTime: {0}", (Object) deltaTime);
        if (waitSeconds == 0 && last > 0L && deltaTime < realm.getQuickLoginCheckMilliSeconds()) {
            logger.debug("quick login, set min wait seconds");
            waitSeconds = realm.getMinimumQuickLoginWaitSeconds();
        }
        if (waitSeconds > 0) {
            waitSeconds = Math.min(realm.getMaxFailureWaitSeconds(), waitSeconds);
            int notBefore = (int) (currentTime / 1000L) + waitSeconds;
            logger.debugv("set notBefore: {0}", (Object) notBefore);
            user.setFailedLoginNotBefore(notBefore);
        }
    }

    /**
     * Looks up the existing failure record for the event's user.
     *
     * @return the record, or {@code null} if the realm or the record does not exist
     */
    protected UsernameLoginFailureModel getUserModel(KeycloakSession session, LoginEvent event) {
        RealmModel realm = this.getRealmModel(session, event);
        return realm == null ? null : realm.getUserLoginFailure(event.username);
    }

    /**
     * Resolves the realm referenced by the event.
     *
     * @return the realm, or {@code null} if the session does not find it
     */
    protected RealmModel getRealmModel(KeycloakSession session, LoginEvent event) {
        return session.getRealm(event.realmId);
    }

    /** Starts the background worker thread that processes queued events. */
    public void start() {
        new Thread(this, "Brute Force Protector").start();
    }

    /**
     * Signals the worker to stop and waits up to 5 seconds for it to exit.
     *
     * @throws RuntimeException wrapping an {@link InterruptedException} if interrupted while waiting
     */
    public void shutdown() {
        this.run = false;
        try {
            // Wake the worker if it is blocked in poll().
            this.queue.offer(new ShutdownEvent());
            this.shutdownLatch.await(5L, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Worker loop: takes events from the queue in batches and applies them until
     * {@link #shutdown()} clears {@link #run} or the thread is interrupted.
     */
    @Override
    public void run() {
        ArrayList<LoginEvent> events = new ArrayList<LoginEvent>(TRANSACTION_SIZE + 1);
        try {
            while (this.run) {
                LoginEvent first;
                try {
                    first = this.queue.poll(2L, TimeUnit.SECONDS);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
                if (first == null) {
                    continue;
                }
                events.add(first);
                this.queue.drainTo(events, TRANSACTION_SIZE);
                try {
                    this.processBatch(events);
                } catch (Exception e) {
                    logger.error("Failed processing event", e);
                } finally {
                    // Always unblock waiting callers and reset the batch, even on failure.
                    this.releaseWaiters(events);
                    events.clear();
                }
            }
        } finally {
            // Events queued after the loop ended will never be processed; don't make their
            // callers wait out the full timeout.
            ArrayList<LoginEvent> leftover = new ArrayList<LoginEvent>();
            this.queue.drainTo(leftover);
            this.releaseWaiters(leftover);
            // Released exactly once, when the worker actually exits.
            this.shutdownLatch.countDown();
        }
    }

    /** Applies all failed logins in {@code events} inside a single session transaction. */
    private void processBatch(ArrayList<LoginEvent> events) {
        // Sorting by username gives batches a consistent update order (presumably to reduce
        // lock-ordering conflicts between transactions).
        Collections.sort(events);
        KeycloakSession session = this.factory.createSession();
        try {
            session.getTransaction().begin();
            try {
                for (LoginEvent event : events) {
                    if (event instanceof FailedLogin) {
                        this.failure(session, event);
                    }
                }
                session.getTransaction().commit();
            } catch (RuntimeException e) {
                session.getTransaction().rollback();
                throw e;
            }
        } finally {
            session.close();
        }
    }

    /** Counts down the latch of every {@link FailedLogin} in {@code events}. */
    private void releaseWaiters(Iterable<LoginEvent> events) {
        for (LoginEvent event : events) {
            if (event instanceof FailedLogin) {
                ((FailedLogin) event).latch.countDown();
            }
        }
    }

    protected void logSuccess(LoginEvent event) {
        logger.warn("login success for user " + event.username + " from ip " + event.ip);
    }

    /**
     * Logs a failure and updates the global statistics: the failure count, the accumulated time
     * between failures (reset when the gap exceeds {@link #maxDeltaTimeSeconds}), and the time of
     * the last failure.
     */
    protected void logFailure(LoginEvent event) {
        logger.warn("login failure for user " + event.username + " from ip " + event.ip);
        ++this.failures;
        long now = System.currentTimeMillis();
        if (this.lastFailure > 0L) {
            long delta = now - this.lastFailure;
            if (delta > (long) this.maxDeltaTimeSeconds * 1000L) {
                this.totalTime = 0L;
            } else {
                this.totalTime += delta;
            }
        }
        // Previously never assigned, which made the delta branch above unreachable.
        this.lastFailure = now;
    }

    public void successfulLogin(RealmModel realm, String username, ClientConnection clientConnection) {
        logger.info("successful login user: " + username + " from ip " + clientConnection.getRemoteAddr());
    }

    public void invalidUser(RealmModel realm, String username, ClientConnection clientConnection) {
        logger.warn("invalid user: " + username + " from ip " + clientConnection.getRemoteAddr());
    }

    /**
     * Queues a failed login and waits up to 5 seconds for the worker thread to process it.
     * If interrupted, stops waiting and restores the thread's interrupt status.
     */
    public void failedLogin(RealmModel realm, String username, ClientConnection clientConnection) {
        FailedLogin event = new FailedLogin(realm.getId(), username, clientConnection.getRemoteAddr());
        this.queue.offer(event);
        try {
            event.latch.await(5L, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // Best effort wait: give up, but don't swallow the interrupt.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Returns whether the user currently has an active "not before" restriction.
     *
     * @return {@code true} if the current time (seconds) is before the recorded notBefore value
     */
    public boolean isTemporarilyDisabled(RealmModel realm, String username) {
        UsernameLoginFailureModel failure = realm.getUserLoginFailure(username);
        if (failure == null) {
            return false;
        }
        int currTime = (int) (System.currentTimeMillis() / 1000L);
        int notBefore = failure.getFailedLoginNotBefore();
        if (currTime < notBefore) {
            logger.debugv("Current: {0} notBefore: {1}", (Object) currTime, (Object) notBefore);
            return true;
        }
        return false;
    }

    public long getFailures() {
        return this.failures;
    }

    public long getLastFailure() {
        return this.lastFailure;
    }

    /** A failed login; its latch is counted down once the worker has handled it. */
    protected class FailedLogin extends LoginEvent {
        protected final CountDownLatch latch;

        public FailedLogin(String realmId, String username, String ip) {
            super(realmId, username, ip);
            this.latch = new CountDownLatch(1);
        }
    }

    /** Sentinel offered by {@link #shutdown()} to wake the worker; carries no data. */
    protected class ShutdownEvent extends LoginEvent {
        public ShutdownEvent() {
            super(null, null, null);
        }
    }

    protected class SuccessfulLogin extends LoginEvent {
        public SuccessfulLogin(String realmId, String userId, String ip) {
            super(realmId, userId, ip);
        }
    }

    /** Base type for queued events, ordered by username. */
    protected abstract class LoginEvent implements Comparable<LoginEvent> {
        protected final String realmId;
        protected final String username;
        protected final String ip;

        protected LoginEvent(String realmId, String username, String ip) {
            this.realmId = realmId;
            this.username = username;
            this.ip = ip;
        }

        /**
         * Orders by username, with {@code null} usernames (e.g. {@link ShutdownEvent}) first,
         * so sorting a batch never throws NullPointerException.
         */
        @Override
        public int compareTo(LoginEvent o) {
            if (this.username == null) {
                return o.username == null ? 0 : -1;
            }
            if (o.username == null) {
                return 1;
            }
            return this.username.compareTo(o.username);
        }
    }
}

