/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ratis;

import java.io.IOException;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntSupplier;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.StreamSupport;
import org.apache.ratis.BaseTest;
import org.apache.ratis.client.RaftClient;
import org.apache.ratis.proto.RaftProtos;
import org.apache.ratis.protocol.ClientId;
import org.apache.ratis.protocol.Message;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftGroupId;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.RaftServerConfigKeys;
import org.apache.ratis.server.impl.BlockRequestHandlingInjection;
import org.apache.ratis.server.impl.DelayLocalExecutionInjection;
import org.apache.ratis.server.impl.MiniRaftCluster;
import org.apache.ratis.server.raftlog.LogEntryHeader;
import org.apache.ratis.server.raftlog.LogProtoUtils;
import org.apache.ratis.server.raftlog.RaftLog;
import org.apache.ratis.server.raftlog.RaftLogBase;
import org.apache.ratis.thirdparty.com.google.common.base.Preconditions;
import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
import org.apache.ratis.util.AutoCloseableLock;
import org.apache.ratis.util.CollectionUtils;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.ProtoUtils;
import org.apache.ratis.util.TimeDuration;
import org.junit.Assert;
import org.junit.AssumptionViolatedException;
import org.junit.jupiter.api.Assertions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public interface RaftTestUtil {
    /** Logger shared by the static helpers declared in this interface. */
    public static final Logger LOG = LoggerFactory.getLogger(RaftTestUtil.class);
    /**
     * Orders log entries by comparing read-only {@code ByteBuffer} views of the log data
     * of their state machine log entries.
     */
    public static final Comparator<RaftProtos.LogEntryProto> LOG_ENTRY_PROTO_COMPARATOR = Comparator.comparing(e -> e.getStateMachineLogEntry().getLogData().asReadOnlyByteBuffer());
    /**
     * Orders messages by comparing read-only {@code ByteBuffer} views of their content,
     * i.e. with the same key type as {@link #LOG_ENTRY_PROTO_COMPARATOR}.
     */
    public static final Comparator<SimpleMessage> SIMPLE_MESSAGE_COMPARATOR = Comparator.comparing(m -> m.getContent().asReadOnlyByteBuffer());

    /**
     * Reads the value of a field (of any visibility) from the given object via reflection.
     * <p>
     * The field is looked up in the runtime class of {@code obj} first and then in each
     * superclass in turn, so private fields declared by a parent class are found as well.
     *
     * @param obj the object to read from; must not be {@code null}
     * @param fieldName the name of the field to read
     * @return the current value of the field in {@code obj}
     * @throws IllegalStateException if no class in the hierarchy declares the field,
     *         or if the field cannot be made accessible or read
     */
    public static Object getDeclaredField(Object obj, String fieldName) {
        Class<?> clazz = obj.getClass();
        NoSuchFieldException notFound = null;
        for (Class<?> c = clazz; c != null; c = c.getSuperclass()) {
            try {
                Field f = c.getDeclaredField(fieldName);
                f.setAccessible(true);
                return f.get(obj);
            }
            catch (NoSuchFieldException e) {
                // Remember the failure for the most specific class and keep searching upwards.
                if (notFound == null) {
                    notFound = e;
                }
            }
            catch (Exception e) {
                throw new IllegalStateException("Failed to get '" + fieldName + "' from " + clazz, e);
            }
        }
        throw new IllegalStateException("Failed to get '" + fieldName + "' from " + clazz, notFound);
    }

    /**
     * Waits for a leader, passing {@code null} as the group id to
     * {@link #waitForLeader(MiniRaftCluster, RaftGroupId)}.
     *
     * @param cluster the cluster to query
     * @return the result of the delegated call
     * @throws InterruptedException propagated from the delegated call
     */
    public static RaftServer.Division waitForLeader(MiniRaftCluster cluster) throws InterruptedException {
        final RaftGroupId unspecifiedGroup = null;
        return RaftTestUtil.waitForLeader(cluster, unspecifiedGroup);
    }

    /**
     * Waits for a leader of the given group, expecting that one exists.
     * Delegates to {@link #waitForLeader(MiniRaftCluster, RaftGroupId, boolean)} with
     * {@code expectLeader == true}.
     *
     * @param cluster the cluster to query
     * @param groupId the group id forwarded to the delegated call
     * @return the result of the delegated call
     * @throws InterruptedException propagated from the delegated call
     */
    public static RaftServer.Division waitForLeader(MiniRaftCluster cluster, RaftGroupId groupId) throws InterruptedException {
        final boolean leaderExpected = true;
        return RaftTestUtil.waitForLeader(cluster, groupId, leaderExpected);
    }

    /**
     * Repeatedly asks the cluster for the leader of the given group until the lookup succeeds.
     * <p>
     * The lookup is run through {@code JavaUtils.attempt} with 100 attempts when a leader is
     * expected and 10 otherwise; the sleep time handed to it is 1.5 times
     * {@code cluster.getTimeoutMax()}. A lookup fails (and is retried) when no leader exists
     * or when the leader found is not ready yet. When multiple leaders are reported, the
     * corresponding exception is only recorded, not thrown, during the lookup.
     *
     * @param cluster the cluster to query
     * @param groupId the group whose leader is wanted
     * @param expectLeader whether the caller expects a leader to exist
     * @return the leader when {@code expectLeader} is true; {@code null} when
     *         {@code expectLeader} is false and the lookup returned no leader
     * @throws IllegalStateException if a leader was expected but the lookup returned none,
     *         or if no leader was expected but one was found
     * @throws InterruptedException declared for the retry helper
     */
    public static RaftServer.Division waitForLeader(MiniRaftCluster cluster, RaftGroupId groupId, boolean expectLeader) throws InterruptedException {
        String name = "waitForLeader-" + groupId + "-(expectLeader? " + expectLeader + ")";
        int numAttempts = expectLeader ? 100 : 10;
        TimeDuration sleepTime = cluster.getTimeoutMax().apply(d -> d * 3L >> 1);
        LOG.info(cluster.printServers(groupId));
        // Typed reference so that the recorded exception can be rethrown without casts.
        AtomicReference<IllegalStateException> exception = new AtomicReference<>();
        Runnable handleNoLeaders = () -> {
            throw cluster.newIllegalStateExceptionForNoLeaders(groupId);
        };
        Consumer<List<RaftServer.Division>> handleMultipleLeaders = leaders ->
            exception.set(cluster.newIllegalStateExceptionForMultipleLeaders(groupId, leaders));
        RaftServer.Division leader = JavaUtils.attempt(i -> {
            try {
                RaftServer.Division l = cluster.getLeader(groupId, handleNoLeaders, handleMultipleLeaders);
                if (l != null && !l.getInfo().isLeaderReady()) {
                    throw new IllegalStateException("Leader: " + l.getMemberId() + " not ready");
                }
                return l;
            }
            catch (Exception e) {
                // Pass the text (not the Throwable) so the message stays on one line and
                // braces inside the exception text cannot be mistaken for placeholders.
                LOG.warn("Attempt #{} failed: {}", i, e.toString());
                throw e;
            }
        }, numAttempts, sleepTime, () -> name, null);
        LOG.info(cluster.printServers(groupId));
        if (expectLeader) {
            if (leader != null) {
                return leader;
            }
            IllegalStateException recorded = exception.get();
            // Never throw null: fall back to a descriptive exception if nothing was recorded.
            throw recorded != null ? recorded
                : new IllegalStateException("expectLeader = " + expectLeader + " but no leader found for group " + groupId);
        }
        if (leader == null) {
            return null;
        }
        throw new IllegalStateException("expectLeader = " + expectLeader + " but leader = " + leader);
    }

    /**
     * Waits for a leader, asserts that one was returned, and kills that server.
     *
     * @param cluster the cluster whose leader is killed
     * @return the id of the killed leader
     * @throws InterruptedException propagated from {@link #waitForLeader(MiniRaftCluster)}
     */
    public static RaftPeerId waitAndKillLeader(MiniRaftCluster cluster) throws InterruptedException {
        final RaftServer.Division currentLeader = RaftTestUtil.waitForLeader(cluster);
        Assert.assertNotNull(currentLeader);
        LOG.info("killing leader = " + currentLeader);
        cluster.killServer(currentLeader.getId());
        return currentLeader.getId();
    }

    /**
     * Polls a condition until it holds or a time budget is used up.
     * <p>
     * The condition is evaluated once immediately and then again after every sleep of
     * {@code checkEveryMillis}. Elapsed time is measured with the monotonic
     * {@link System#nanoTime()} so that wall-clock adjustments cannot shorten or extend the
     * wait. A {@code null} result from the supplier is treated as "not yet satisfied".
     *
     * @param check the condition to poll; must not be {@code null}
     * @param checkEveryMillis the sleep between two evaluations, in milliseconds
     * @param waitForMillis the total time budget, in milliseconds; must be at least
     *        {@code checkEveryMillis}
     * @throws NullPointerException if {@code check} is {@code null}
     * @throws IllegalArgumentException if {@code waitForMillis < checkEveryMillis}
     * @throws TimeoutException if the condition still does not hold once the budget is spent
     * @throws InterruptedException if the thread is interrupted while sleeping
     */
    public static void waitFor(Supplier<Boolean> check, int checkEveryMillis, int waitForMillis) throws TimeoutException, InterruptedException {
        Objects.requireNonNull(check, "check == null");
        if (waitForMillis < checkEveryMillis) {
            throw new IllegalArgumentException("waitForMillis (" + waitForMillis + ") < checkEveryMillis (" + checkEveryMillis + ")");
        }
        final long startNanos = System.nanoTime();
        final long budgetNanos = TimeUnit.MILLISECONDS.toNanos(waitForMillis);
        boolean result = Boolean.TRUE.equals(check.get());
        while (!result && System.nanoTime() - startNanos < budgetNanos) {
            Thread.sleep(checkEveryMillis);
            result = Boolean.TRUE.equals(check.get());
        }
        if (!result) {
            throw new TimeoutException("Timed out waiting for condition.");
        }
    }

    /**
     * Checks the index range {@code [0, Long.MAX_VALUE)} of the log for the expected messages
     * by delegating to {@link #logEntriesContains(RaftLog, long, long, SimpleMessage...)}.
     *
     * @param log the log to inspect
     * @param expectedMessages the messages to look for
     * @return the result of the delegated check
     */
    public static boolean logEntriesContains(RaftLog log, SimpleMessage ... expectedMessages) {
        final long fromIndex = 0L;
        final long toIndex = Long.MAX_VALUE;
        return RaftTestUtil.logEntriesContains(log, fromIndex, toIndex, expectedMessages);
    }

    /**
     * Applies {@link #logEntriesNotContains(RaftLog, long, long, SimpleMessage...)} to the
     * index range {@code [0, Long.MAX_VALUE)} of the log.
     *
     * @param log the log to inspect
     * @param expectedMessages the messages forwarded to the delegated check
     * @return the result of the delegated check
     */
    public static boolean logEntriesNotContains(RaftLog log, SimpleMessage ... expectedMessages) {
        final long fromIndex = 0L;
        final long toIndex = Long.MAX_VALUE;
        return RaftTestUtil.logEntriesNotContains(log, fromIndex, toIndex, expectedMessages);
    }

    /**
     * Tells whether the expected messages occur, in the given order, among the entries that
     * {@code log.getEntries(startIndex, endIndex)} returns. Other entries may be interleaved:
     * each entry is compared against the next unmatched message only, and a match advances
     * to the following message.
     *
     * @param log the log to inspect
     * @param startIndex first argument passed to {@code log.getEntries}
     * @param endIndex second argument passed to {@code log.getEntries}
     * @param expectedMessages the messages to find, in order
     * @return {@code true} if every expected message was matched
     * @throws RuntimeException wrapping the {@link IOException} if an entry cannot be read
     */
    public static boolean logEntriesContains(RaftLog log, long startIndex, long endIndex, SimpleMessage ... expectedMessages) {
        final LogEntryHeader[] headers = log.getEntries(startIndex, endIndex);
        int matched = 0;
        int position = 0;
        while (position < headers.length && matched < expectedMessages.length) {
            final byte[] wanted = expectedMessages[matched].getContent().toByteArray();
            final byte[] actual;
            try {
                actual = log.get(headers[position].getIndex()).getStateMachineLogEntry().getLogData().toByteArray();
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
            if (Arrays.equals(wanted, actual)) {
                matched++;
            }
            position++;
        }
        return matched == expectedMessages.length;
    }

    /**
     * Tells whether none of the given messages occurs among the entries that
     * {@code log.getEntries(startIndex, endIndex)} returns.
     * <p>
     * Every entry in the range is compared against every message. The previous
     * implementation only compared the i-th message with the i-th entry, so a message
     * stored at any other position was wrongly reported as absent.
     *
     * @param log the log to inspect
     * @param startIndex first argument passed to {@code log.getEntries}
     * @param endIndex second argument passed to {@code log.getEntries}
     * @param expectedMessages the messages that must not appear in the range
     * @return {@code true} if no entry in the range carries the content of any given message
     * @throws RuntimeException wrapping the {@link IOException} if an entry cannot be read
     */
    public static boolean logEntriesNotContains(RaftLog log, long startIndex, long endIndex, SimpleMessage ... expectedMessages) {
        LogEntryHeader[] termIndices = log.getEntries(startIndex, endIndex);
        if (expectedMessages.length == 0) {
            // Nothing to look for, so there is no need to read any entry.
            return true;
        }
        for (LogEntryHeader header : termIndices) {
            final ByteString logData;
            try {
                logData = log.get(header.getIndex()).getStateMachineLogEntry().getLogData();
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
            for (SimpleMessage message : expectedMessages) {
                if (message.getContent().equals(logData)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Asserts that the predicate accepts every log entry whose log data equals the content
     * of one of the expected messages. Entries in {@code [0, Long.MAX_VALUE)} are examined;
     * entries that match no message are ignored.
     * <p>
     * Each entry is read from the log once. A read failure is reported as an
     * {@link AssertionError} (as {@link #getLogEntryProtos(RaftLog)} does) instead of being
     * printed and skipped, so an unreadable log can no longer make the check pass silently.
     *
     * @param log the log to inspect
     * @param expectedMessages the messages whose entries are to be checked
     * @param predicate the condition each matching entry must satisfy
     * @throws AssertionError if the predicate rejects a matching entry or an entry cannot be read
     */
    public static void checkLogEntries(RaftLog log, SimpleMessage[] expectedMessages, Predicate<RaftProtos.LogEntryProto> predicate) {
        LogEntryHeader[] termIndices = log.getEntries(0L, Long.MAX_VALUE);
        if (expectedMessages.length == 0) {
            return;
        }
        for (LogEntryHeader header : termIndices) {
            final RaftProtos.LogEntryProto entry;
            try {
                entry = log.get(header.getIndex());
            }
            catch (IOException exception) {
                throw new AssertionError("Failed to get log at " + header, exception);
            }
            final ByteString logData = entry.getStateMachineLogEntry().getLogData();
            for (SimpleMessage message : expectedMessages) {
                if (message.getContent().equals(logData)) {
                    Assert.assertTrue("Predicate rejected the log entry at " + header + " for message " + message, predicate.test(entry));
                }
            }
        }
    }

    /**
     * Runs {@link #assertLogEntries(MiniRaftCluster, SimpleMessage)} for each message,
     * in array order.
     *
     * @param cluster the cluster to check
     * @param expectedMessages the messages to check one by one
     */
    public static void assertLogEntries(MiniRaftCluster cluster, SimpleMessage[] expectedMessages) {
        Arrays.stream(expectedMessages).forEach(message -> RaftTestUtil.assertLogEntries(cluster, message));
    }

    /**
     * Asserts that the logs of more than half of the servers contain the message.
     * The count is taken over {@code cluster.getServerAliveStream()}, while the size it is
     * compared with is {@code cluster.getNumServers()}.
     *
     * @param cluster the cluster to check
     * @param expectedMessage the message to look for
     * @throws AssertionError if twice the count does not exceed the number of servers
     */
    public static void assertLogEntries(MiniRaftCluster cluster, SimpleMessage expectedMessage) {
        final int totalServers = cluster.getNumServers();
        final long holders = cluster.getServerAliveStream()
            .map(RaftServer.Division::getRaftLog)
            .filter(raftLog -> RaftTestUtil.logEntriesContains(raftLog, expectedMessage))
            .count();
        final boolean inMajority = 2L * holders > (long)totalServers;
        if (!inMajority) {
            throw new AssertionError("Not in majority: size=" + totalServers + " but count=" + holders);
        }
    }

    /**
     * Repeatedly asserts, through {@code JavaUtils.attempt}, that the server's log holds
     * exactly the expected state machine entries; {@code TimeDuration.ONE_SECOND} is passed
     * as the sleep time.
     * <p>
     * Entry descriptions are discarded while the attempt number is below
     * {@code numAttempts} and printed to {@code System.out} otherwise.
     * NOTE(review): this presumably prints only on the final attempt — confirm how
     * {@code JavaUtils.attempt} numbers its attempts.
     *
     * @param server the server whose log is checked
     * @param expectedTerm the minimum term expected for the entries
     * @param expectedMessages the messages expected in the log, in order
     * @param numAttempts the number of attempts passed to {@code JavaUtils.attempt}
     * @param log the logger passed to {@code JavaUtils.attempt}
     * @throws Exception propagated from {@code JavaUtils.attempt}
     */
    public static void assertLogEntries(RaftServer.Division server, long expectedTerm, SimpleMessage[] expectedMessages, int numAttempts, Logger log) throws Exception {
        String name = server.getId() + " assertLogEntries";
        // Fully parameterized so that no raw Consumer or unchecked cast is needed below.
        Function<Integer, Consumer<String>> print = i -> i < numAttempts ? s -> {} : System.out::println;
        JavaUtils.attempt(i -> RaftTestUtil.assertLogEntries(server.getRaftLog(), expectedTerm, expectedMessages, print.apply(i)), numAttempts, TimeDuration.ONE_SECOND, () -> name, log);
    }

    /**
     * Returns a view that maps each header from index 0 up to and including the last entry's
     * index to the corresponding log entry, read via {@code log.get}.
     * <p>
     * NOTE(review): {@code log.getLastEntryTermIndex()} is dereferenced without a null check;
     * if it can return {@code null} (e.g. for an empty log) this method throws a
     * {@link NullPointerException} — confirm against the {@code RaftLog} contract.
     *
     * @param log the log to read
     * @return the log entries as an {@link Iterable}
     * @throws AssertionError (from the mapping function) if an entry cannot be read
     */
    public static Iterable<RaftProtos.LogEntryProto> getLogEntryProtos(RaftLog log) {
        // No cast to Object[] here: the element type must stay visible so that
        // ti.getIndex() resolves inside the mapping function.
        return CollectionUtils.as(log.getEntries(0L, log.getLastEntryTermIndex().getIndex() + 1L), ti -> {
            try {
                return log.get(ti.getIndex());
            }
            catch (IOException exception) {
                throw new AssertionError("Failed to get log at " + ti, exception);
            }
        });
    }

    /**
     * Finds, for each message, a state machine log entry with the same content.
     * <p>
     * Both sides are sorted by content and then merged. Note that {@code messages} is sorted
     * in place: the keys of the returned map are positions in the array after sorting.
     * An empty map is returned without any assertion when there are no messages or no
     * state machine entries.
     *
     * @param log the log to search
     * @param messages the messages to look up; reordered by this call
     * @return a map from message position (after sorting) to the matching log entry
     * @throws AssertionError if some message was not matched
     */
    public static Map<Integer, RaftProtos.LogEntryProto> getStateMachineLogEntries(RaftLog log, SimpleMessage[] messages) {
        if (messages.length == 0) {
            return Collections.emptyMap();
        }
        final List<RaftProtos.LogEntryProto> sortedEntries = RaftTestUtil.getStateMachineLogEntries(log, s -> {});
        if (sortedEntries.isEmpty()) {
            return Collections.emptyMap();
        }
        sortedEntries.sort(LOG_ENTRY_PROTO_COMPARATOR);
        Arrays.sort(messages, SIMPLE_MESSAGE_COMPARATOR);

        final Map<Integer, RaftProtos.LogEntryProto> matches = new HashMap<>();
        for (int entryPos = 0, messagePos = 0; entryPos < sortedEntries.size() && messagePos < messages.length; ) {
            final RaftProtos.LogEntryProto candidate = sortedEntries.get(entryPos);
            final int order = messages[messagePos].getContent().asReadOnlyByteBuffer()
                .compareTo(candidate.getStateMachineLogEntry().getLogData().asReadOnlyByteBuffer());
            if (order < 0) {
                // The message sorts before the candidate entry: skip the message.
                messagePos++;
            } else if (order > 0) {
                // The candidate entry sorts before the message: skip the entry.
                entryPos++;
            } else {
                matches.put(messagePos, candidate);
                messagePos++;
                entryPos++;
            }
        }
        Assertions.assertEquals(messages.length, matches.size());
        return matches;
    }

    /**
     * Collects the state machine log entries of the log in iteration order.
     * Each entry's description is handed to {@code print}: numbered with its position in the
     * result when collected, or prefixed with "Ignoring " for configuration and metadata
     * entries.
     *
     * @param log the log to read
     * @param print receives one description line per entry
     * @return the state machine log entries
     * @throws AssertionError if an entry is none of the three kinds above
     */
    public static List<RaftProtos.LogEntryProto> getStateMachineLogEntries(RaftLog log, Consumer<String> print) {
        final List<RaftProtos.LogEntryProto> collected = new ArrayList<>();
        for (RaftProtos.LogEntryProto proto : RaftTestUtil.getLogEntryProtos(log)) {
            final String description = LogProtoUtils.toLogEntryString(proto);
            if (proto.hasStateMachineLogEntry()) {
                print.accept(collected.size() + ") " + description);
                collected.add(proto);
            } else if (proto.hasConfigurationEntry() || proto.hasMetadataEntry()) {
                print.accept("Ignoring " + description);
            } else {
                throw new AssertionError("Unexpected LogEntryBodyCase " + proto.getLogEntryBodyCase() + " at " + description);
            }
        }
        return collected;
    }

    /**
     * Asserts that the state machine entries of the log match the expected messages, and
     * attaches the list of entries to any failure.
     * <p>
     * The JUnit assertions used by
     * {@link #assertLogEntries(List, long, SimpleMessage...)} fail with
     * {@link AssertionError}, which is not an {@link Exception}; it is therefore caught
     * explicitly so that the "entries: ..." context is present on assertion failures too.
     *
     * @param log the log to read
     * @param expectedTerm the minimum term expected for the entries
     * @param expectedMessages the messages expected in the log, in order
     * @param print receives one description line per log entry
     * @return always {@code null}
     * @throws AssertionError wrapping the original failure, with the entries in its message
     */
    public static Void assertLogEntries(RaftLog log, long expectedTerm, SimpleMessage[] expectedMessages, Consumer<String> print) {
        List<RaftProtos.LogEntryProto> entries = RaftTestUtil.getStateMachineLogEntries(log, print);
        try {
            RaftTestUtil.assertLogEntries(entries, expectedTerm, expectedMessages);
        }
        catch (Exception | AssertionError t) {
            throw new AssertionError("entries: " + entries, t);
        }
        return null;
    }

    /**
     * Asserts that the entries correspond one-to-one to the expected messages:
     * the sizes are equal, terms never drop below {@code expectedTerm} and never decrease,
     * indices are strictly increasing starting above 0, and each entry's log data equals
     * the content of the message at the same position.
     *
     * @param entries the log entries to check
     * @param expectedTerm the minimum term of the first entry
     * @param expectedMessages the expected messages, in order
     */
    public static void assertLogEntries(List<RaftProtos.LogEntryProto> entries, long expectedTerm, SimpleMessage ... expectedMessages) {
        Assert.assertEquals((long)expectedMessages.length, (long)entries.size());
        long minTerm = expectedTerm;
        long previousIndex = 0L;
        int position = 0;
        for (SimpleMessage expected : expectedMessages) {
            final RaftProtos.LogEntryProto entry = entries.get(position++);

            final long term = entry.getTerm();
            Assert.assertTrue(term >= minTerm);
            minTerm = Math.max(minTerm, term);

            final long index = entry.getIndex();
            Assert.assertTrue(index > previousIndex);
            previousIndex = index;

            Assert.assertEquals(expected.getContent(), entry.getStateMachineLogEntry().getLogData());
        }
    }

    /**
     * Keeps the calling thread waiting for as long as the supplier reports {@code true},
     * sleeping for {@code RaftServerConfigKeys.Rpc.TIMEOUT_MAX_DEFAULT} between checks.
     *
     * @param isBlocked tells whether the caller should keep waiting
     * @throws InterruptedException declared for the sleep between checks
     */
    public static void block(BooleanSupplier isBlocked) throws InterruptedException {
        for (;;) {
            if (!isBlocked.getAsBoolean()) {
                return;
            }
            RaftServerConfigKeys.Rpc.TIMEOUT_MAX_DEFAULT.sleep();
        }
    }

    /**
     * Sleeps for the number of milliseconds given by the supplier, which is queried exactly
     * once. Zero and negative values mean no delay at all.
     *
     * @param getDelayMs supplies the delay in milliseconds
     * @throws InterruptedException if the thread is interrupted while sleeping
     */
    public static void delay(IntSupplier getDelayMs) throws InterruptedException {
        final int millis = getDelayMs.getAsInt();
        if (millis <= 0) {
            return;
        }
        Thread.sleep(millis);
    }

    /**
     * Changes the leader, reporting a failure to do so as an
     * {@link AssumptionViolatedException}. Delegates to
     * {@link #changeLeader(MiniRaftCluster, RaftPeerId, Function)}.
     *
     * @param cluster the cluster whose leader is changed
     * @param oldLeader the id of the current leader
     * @return the id of the new leader
     * @throws Exception propagated from the delegated call
     */
    public static RaftPeerId changeLeader(MiniRaftCluster cluster, RaftPeerId oldLeader) throws Exception {
        final Function<String, Exception> assumptionFailure = AssumptionViolatedException::new;
        return RaftTestUtil.changeLeader(cluster, oldLeader, assumptionFailure);
    }

    /**
     * Forces a leader change by blocking the requests from the old leader until a different
     * leader is observed.
     * <p>
     * The check is run through {@code JavaUtils.attemptRepeatedly} with 20 attempts and
     * {@code BaseTest.HUNDRED_MILLIS} as the sleep time. Request blocking for the old leader
     * is switched off again in all cases, whether or not the change succeeded.
     *
     * @param cluster the cluster whose leader is changed
     * @param oldLeader the id of the current leader
     * @param constructor builds the exception thrown while the leader is still the old one
     * @return the id of the new leader
     * @throws Exception propagated from {@code JavaUtils.attemptRepeatedly}
     */
    public static RaftPeerId changeLeader(MiniRaftCluster cluster, RaftPeerId oldLeader, Function<String, Exception> constructor) throws Exception {
        // Must be evaluated directly in this method: it inspects the current call stack.
        final String name = JavaUtils.getCallerStackTraceElement().getMethodName() + "-changeLeader";
        cluster.setBlockRequestsFrom(oldLeader.toString(), true);
        try {
            return JavaUtils.attemptRepeatedly(() -> {
                final RaftPeerId newLeader = RaftTestUtil.waitForLeader(cluster).getId();
                if (!newLeader.equals(oldLeader)) {
                    LOG.info("Changed leader from " + oldLeader + " to " + newLeader);
                    return newLeader;
                }
                throw constructor.apply("Failed to change leader: newLeader == oldLeader == " + oldLeader);
            }, 20, BaseTest.HUNDRED_MILLIS, name, LOG);
        }
        finally {
            cluster.setBlockRequestsFrom(oldLeader.toString(), false);
        }
    }

    /**
     * Blocks (when {@code delayMs > 0}) or unblocks (otherwise) the replier of the leader and
     * sets or removes the execution delay of every other server accordingly, then sleeps
     * for three times {@code maxTimeout}.
     *
     * @param servers the servers of the cluster
     * @param injection the delay injection to configure for the non-leader servers
     * @param leaderId the id of the leader, as a string
     * @param delayMs the delay in milliseconds; a non-positive value undoes the blocking
     * @param maxTimeout the timeout that determines the final sleep
     * @throws InterruptedException if the thread is interrupted during the final sleep
     */
    public static void blockQueueAndSetDelay(Iterable<RaftServer> servers, DelayLocalExecutionInjection injection, String leaderId, int delayMs, TimeDuration maxTimeout) throws InterruptedException {
        final boolean shouldBlock = delayMs > 0;
        final String action = shouldBlock ? "Block" : "Unblock";
        LOG.debug("{} requests sent to leader {} and set {}ms delay for the others", action, leaderId, delayMs);

        final BlockRequestHandlingInjection replierInjection = BlockRequestHandlingInjection.getInstance();
        if (shouldBlock) {
            replierInjection.blockReplier(leaderId);
        } else {
            replierInjection.unblockReplier(leaderId);
        }

        StreamSupport.stream(servers.spliterator(), false)
            .map(server -> server.getId().toString())
            .filter(peerId -> !peerId.equals(leaderId))
            .forEach(peerId -> {
                if (shouldBlock) {
                    injection.setDelayMs(peerId, delayMs);
                } else {
                    injection.removeDelay(peerId);
                }
            });

        final long settleMillis = 3L * maxTimeout.toLong(TimeUnit.MILLISECONDS);
        Thread.sleep(settleMillis);
    }

    /**
     * Isolates a peer: blocks its replier and blocks the requests sent from it.
     * <p>
     * This is best effort: a failure does not propagate to the caller. It is reported through
     * {@link #LOG}, with the stack trace, rather than printed to standard error.
     *
     * @param cluster the cluster the peer belongs to
     * @param id the id of the peer to isolate
     */
    public static void isolate(MiniRaftCluster cluster, RaftPeerId id) {
        try {
            BlockRequestHandlingInjection.getInstance().blockReplier(id.toString());
            cluster.setBlockRequestsFrom(id.toString(), true);
        }
        catch (Exception e) {
            // The trailing Throwable is logged by SLF4J together with its stack trace.
            LOG.warn("Failed to isolate {}", id, e);
        }
    }

    /**
     * Reverses {@link #isolate(MiniRaftCluster, RaftPeerId)}: unblocks the replier of the
     * peer and stops blocking the requests sent from it.
     *
     * @param cluster the cluster the peer belongs to
     * @param id the id of the peer
     */
    public static void deIsolate(MiniRaftCluster cluster, RaftPeerId id) {
        final BlockRequestHandlingInjection replierInjection = BlockRequestHandlingInjection.getInstance();
        final String peer = id.toString();
        replierInjection.unblockReplier(peer);
        cluster.setBlockRequestsFrom(peer, false);
    }

    /**
     * Starts a new thread that creates a client for the given leader and sends the messages
     * one after another through {@code client.io().send}.
     * <p>
     * A failure inside the thread ends the sending; it is reported through {@link #LOG},
     * with the stack trace, rather than printed to standard error.
     *
     * @param cluster the cluster used to create the client
     * @param leaderId the id passed to {@code cluster.createClient}
     * @param messages the messages to send, in order
     * @return the thread, already started
     */
    public static Thread sendMessageInNewThread(MiniRaftCluster cluster, RaftPeerId leaderId, SimpleMessage ... messages) {
        Thread t = new Thread(() -> {
            try (RaftClient client = cluster.createClient(leaderId)) {
                for (SimpleMessage message : messages) {
                    client.io().send(message);
                }
            }
            catch (Exception e) {
                LOG.error("Failed to send messages to " + leaderId, e);
            }
        });
        t.start();
        return t;
    }

    /**
     * Asserts that two logs are the same: equal last term-index, and equal entries at every
     * index from 0 up to and including the last index.
     * <p>
     * The previous loop stopped one short of the last index, so the content of the last
     * entry was never compared.
     *
     * @param expected the reference log
     * @param computed the log under test
     * @throws Exception propagated from reading the log entries
     */
    public static void assertSameLog(RaftLog expected, RaftLog computed) throws Exception {
        Assert.assertEquals(expected.getLastEntryTermIndex(), computed.getLastEntryTermIndex());
        long lastIndex = expected.getNextIndex() - 1L;
        Assert.assertEquals(expected.getLastEntryTermIndex().getIndex(), lastIndex);
        for (long i = 0L; i <= lastIndex; ++i) {
            Assert.assertEquals("log entry at index " + i, expected.get(i), computed.get(i));
        }
    }

    /**
     * Counts the log entries per body case, reading indices from 0 while they are below
     * {@code raftLog.getNextIndex()} (re-evaluated on every iteration).
     *
     * @param raftLog the log to read
     * @return a map from body case to the number of entries of that case; cases that do not
     *         occur have no mapping
     * @throws Exception propagated from reading the log entries
     */
    public static EnumMap<RaftProtos.LogEntryProto.LogEntryBodyCase, AtomicLong> countEntries(RaftLog raftLog) throws Exception {
        final EnumMap<RaftProtos.LogEntryProto.LogEntryBodyCase, AtomicLong> tally = new EnumMap<>(RaftProtos.LogEntryProto.LogEntryBodyCase.class);
        long index = 0L;
        while (index < raftLog.getNextIndex()) {
            final RaftProtos.LogEntryProto.LogEntryBodyCase bodyCase = raftLog.get(index).getLogEntryBodyCase();
            final AtomicLong counter = tally.computeIfAbsent(bodyCase, unused -> new AtomicLong());
            counter.incrementAndGet();
            index++;
        }
        return tally;
    }

    /**
     * Searches the log backwards, from {@code getNextIndex() - 1} down to index 0, for the
     * first entry with the given body case. The search runs while holding the read lock
     * obtained from the log, which must be a {@code RaftLogBase}.
     *
     * @param targetCase the body case to look for
     * @param raftLog the log to search
     * @return the entry found, or {@code null} if there is none
     * @throws Exception propagated from reading the log entries or releasing the lock
     */
    public static RaftProtos.LogEntryProto getLastEntry(RaftProtos.LogEntryProto.LogEntryBodyCase targetCase, RaftLog raftLog) throws Exception {
        final RaftLogBase lockable = (RaftLogBase)raftLog;
        try (AutoCloseableLock readLock = lockable.readLock()) {
            long index = raftLog.getNextIndex() - 1L;
            while (index >= 0L) {
                final RaftProtos.LogEntryProto candidate = raftLog.get(index);
                if (candidate.getLogEntryBodyCase() == targetCase) {
                    return candidate;
                }
                index--;
            }
        }
        return null;
    }

    /**
     * Waits up to 10 seconds for the reply and asserts that it is a success reply.
     *
     * @param reply the pending reply
     * @throws Exception if waiting for the reply fails, e.g. on timeout
     */
    public static void assertSuccessReply(CompletableFuture<RaftClientReply> reply) throws Exception {
        final RaftClientReply received = reply.get(10L, TimeUnit.SECONDS);
        RaftTestUtil.assertSuccessReply(received);
    }

    /**
     * Asserts that the reply is not {@code null} and reports success.
     *
     * @param reply the reply to check
     */
    public static void assertSuccessReply(RaftClientReply reply) {
        Assert.assertNotNull("reply == null", reply);
        final String failureMessage = "reply is not success: " + reply;
        Assert.assertTrue(failureMessage, reply.isSuccess());
    }

    /**
     * A named operation together with the state machine log entry built from it.
     * Equality, hash code and string form are based on the operation name only.
     */
    public static class SimpleOperation {
        /** Client id shared by all operations created in this JVM. */
        private static final ClientId CLIENT_ID = ClientId.randomId();
        /** Source of call ids; incremented for every operation created through the public constructors. */
        private static final AtomicLong CALL_ID = new AtomicLong();
        private final String op;
        private final RaftProtos.StateMachineLogEntryProto smLogEntryProto;

        /**
         * Creates an operation without state machine data.
         *
         * @param op the operation name; must not be {@code null}
         */
        public SimpleOperation(String op) {
            this(op, false);
        }

        /**
         * Creates an operation using the shared client id and the next call id.
         *
         * @param op the operation name; must not be {@code null}
         * @param hasStateMachineData whether the bytes of the name are also passed as the
         *        last argument when building the log entry
         */
        public SimpleOperation(String op, boolean hasStateMachineData) {
            this(CLIENT_ID, CALL_ID.incrementAndGet(), op, hasStateMachineData);
        }

        private SimpleOperation(ClientId clientId, long callId, String op, boolean hasStateMachineData) {
            this.op = Objects.requireNonNull(op);
            final ByteString payload = ProtoUtils.toByteString(op);
            final ByteString stateMachineData = hasStateMachineData ? payload : null;
            this.smLogEntryProto = LogProtoUtils.toStateMachineLogEntryProto(clientId, callId, RaftProtos.StateMachineLogEntryProto.Type.WRITE, payload, stateMachineData);
        }

        @Override
        public String toString() {
            return this.op;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof SimpleOperation)) {
                return false;
            }
            final SimpleOperation that = (SimpleOperation)obj;
            return that.op.equals(this.op);
        }

        @Override
        public int hashCode() {
            return this.op.hashCode();
        }

        /** @return the state machine log entry built for this operation */
        public RaftProtos.StateMachineLogEntryProto getLogEntryContent() {
            return this.smLogEntryProto;
        }
    }

    /**
     * A message identified by a string id and carrying a byte string as content.
     * Equality, hash code and string form are based on the id only.
     */
    public static class SimpleMessage
    implements Message {
        final String messageId;
        final ByteString bytes;

        /**
         * Creates messages with the ids "m0", "m1", ...
         *
         * @param numMessages the number of messages to create
         * @return the new messages
         */
        public static SimpleMessage[] create(int numMessages) {
            return SimpleMessage.create(numMessages, "m");
        }

        /**
         * Creates messages whose ids are the prefix followed by their position in the array.
         *
         * @param numMessages the number of messages to create
         * @param prefix the prefix of every id
         * @return the new messages
         */
        public static SimpleMessage[] create(int numMessages, String prefix) {
            final SimpleMessage[] created = new SimpleMessage[numMessages];
            Arrays.setAll(created, position -> new SimpleMessage(prefix + position));
            return created;
        }

        /**
         * Creates a message whose content is the id converted by {@code ProtoUtils.toByteString}.
         *
         * @param messageId the id of the message
         */
        public SimpleMessage(String messageId) {
            this(messageId, ProtoUtils.toByteString(messageId));
        }

        /**
         * Creates a message with an explicit content.
         *
         * @param messageId the id of the message
         * @param bytes the content of the message
         */
        public SimpleMessage(String messageId, ByteString bytes) {
            this.messageId = messageId;
            this.bytes = bytes;
        }

        @Override
        public String toString() {
            return this.messageId;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            // instanceof is false for null, so no separate null check is needed.
            if (!(obj instanceof SimpleMessage)) {
                return false;
            }
            final SimpleMessage other = (SimpleMessage)obj;
            return this.messageId.equals(other.messageId);
        }

        @Override
        public int hashCode() {
            return this.messageId.hashCode();
        }

        /** @return the content given at construction */
        public ByteString getContent() {
            return this.bytes;
        }
    }
}

