/*
 * Decompiled with CFR 0.152.
 */
package org.axonframework.messaging.unitofwork;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.axonframework.messaging.GenericMessage;
import org.axonframework.messaging.GenericResultMessage;
import org.axonframework.messaging.Message;
import org.axonframework.messaging.ResultMessage;
import org.axonframework.messaging.unitofwork.BatchingUnitOfWork;
import org.axonframework.messaging.unitofwork.ExecutionResult;
import org.axonframework.messaging.unitofwork.UnitOfWork;
import org.axonframework.utils.MockException;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link BatchingUnitOfWork}.
 * <p>
 * Each test runs a task over a batch of three messages, records the phase callbacks that the unit of work
 * invokes (see {@link #registerListeners(UnitOfWork)}), and then checks both the order of those callbacks and
 * the {@link ExecutionResult} stored per message.
 */
class BatchingUnitOfWorkTest {

    /** Phase callbacks recorded by the listeners, in invocation order. Reset before every test. */
    private List<PhaseTransition> transitions;
    /** The unit of work under test; created by each test with its own batch of messages. */
    private BatchingUnitOfWork<?> subject;

    @BeforeEach
    void setUp() {
        transitions = new ArrayList<>();
    }

    /**
     * Runs a task that succeeds for every message and expects the PREPARE_COMMIT, COMMIT, AFTER_COMMIT and
     * CLEANUP callbacks for all messages, plus a {@link #resultFor(Message)} result per message.
     */
    @Test
    void executeTask() throws Exception {
        List<Message<?>> messages = Arrays.asList(toMessage(0), toMessage(1), toMessage(2));
        subject = new BatchingUnitOfWork<>(messages);
        subject.executeWithResult(() -> {
            registerListeners(subject);
            return resultFor(subject.getMessage());
        });

        validatePhaseTransitions(Arrays.asList(UnitOfWork.Phase.PREPARE_COMMIT,
                                               UnitOfWork.Phase.COMMIT,
                                               UnitOfWork.Phase.AFTER_COMMIT,
                                               UnitOfWork.Phase.CLEANUP),
                                 messages);

        Map<Message<?>, ExecutionResult> expectedResults = new HashMap<>();
        for (Message<?> message : messages) {
            expectedResults.put(message,
                                new ExecutionResult(GenericResultMessage.asResultMessage(resultFor(message))));
        }
        assertExecutionResults(expectedResults, subject.getExecutionResults());
    }

    /**
     * Runs a task that throws while handling the message with payload {@code 1}. Expects ROLLBACK and CLEANUP
     * callbacks for the first two messages only (listeners are registered per handled message), and expects
     * every message in the batch to carry the thrown exception as its execution result.
     */
    @Test
    void rollback() {
        List<Message<?>> messages = Arrays.asList(toMessage(0), toMessage(1), toMessage(2));
        subject = new BatchingUnitOfWork<>(messages);
        MockException failure = new MockException();
        try {
            subject.executeWithResult(() -> {
                registerListeners(subject);
                if (subject.getMessage().getPayload().equals(1)) {
                    throw failure;
                }
                return resultFor(subject.getMessage());
            });
        } catch (Exception ignored) {
            // Deliberately ignored: this test asserts on the recorded phase transitions and the stored
            // execution results, not on whether the call itself propagates the failure.
        }

        validatePhaseTransitions(Arrays.asList(UnitOfWork.Phase.ROLLBACK, UnitOfWork.Phase.CLEANUP),
                                 messages.subList(0, 2));

        Map<Message<?>, ExecutionResult> expectedResults = new HashMap<>();
        for (Message<?> message : messages) {
            expectedResults.put(message, new ExecutionResult(GenericResultMessage.asResultMessage(failure)));
        }
        assertExecutionResults(expectedResults, subject.getExecutionResults());
    }

    /**
     * Runs a task with a rollback configuration that never rolls back ({@code e -> false}). While handling the
     * message with payload {@code 2} the task registers a PREPARE_COMMIT handler that throws and then throws
     * itself. Expects PREPARE_COMMIT, ROLLBACK and CLEANUP callbacks for all messages, the commit exception as
     * the result of the first two messages, the task exception as the result of the third, the commit exception
     * to be suppressed on the task exception, and both non-throwing cleanup handlers to have run despite the
     * throwing one registered between them.
     */
    @Test
    void suppressedExceptionOnRollback() {
        List<Message<?>> messages = Arrays.asList(toMessage(0), toMessage(1), toMessage(2));
        AtomicInteger cleanupCounter = new AtomicInteger();
        subject = new BatchingUnitOfWork<>(messages);
        MockException taskException = new MockException("task exception");
        MockException commitException = new MockException("commit exception");
        MockException cleanupException = new MockException("cleanup exception");
        subject.onCleanup(u -> cleanupCounter.incrementAndGet());
        subject.onCleanup(u -> {
            throw cleanupException;
        });
        subject.onCleanup(u -> cleanupCounter.incrementAndGet());
        try {
            subject.executeWithResult(() -> {
                registerListeners(subject);
                if (subject.getMessage().getPayload().equals(2)) {
                    subject.addHandler(UnitOfWork.Phase.PREPARE_COMMIT, u -> {
                        throw commitException;
                    });
                    throw taskException;
                }
                return resultFor(subject.getMessage());
            }, e -> false);
        } catch (Exception ignored) {
            // Deliberately ignored: the outcome is verified through the recorded transitions, the stored
            // execution results and the suppressed exception below.
        }

        validatePhaseTransitions(Arrays.asList(UnitOfWork.Phase.PREPARE_COMMIT,
                                               UnitOfWork.Phase.ROLLBACK,
                                               UnitOfWork.Phase.CLEANUP),
                                 messages);

        Map<Message<?>, ExecutionResult> expectedResults = new HashMap<>();
        expectedResults.put(messages.get(0),
                            new ExecutionResult(GenericResultMessage.asResultMessage(commitException)));
        expectedResults.put(messages.get(1),
                            new ExecutionResult(GenericResultMessage.asResultMessage(commitException)));
        expectedResults.put(messages.get(2),
                            new ExecutionResult(GenericResultMessage.asResultMessage(taskException)));
        assertExecutionResults(expectedResults, subject.getExecutionResults());
        Assertions.assertSame(commitException, taskException.getSuppressed()[0]);
        Assertions.assertEquals(2, cleanupCounter.get());
    }

    /**
     * Registers one listener per phase callback on the given unit of work. Each listener appends a
     * {@link PhaseTransition} holding the unit of work's current message and the phase to {@link #transitions}.
     *
     * @param unitOfWork the unit of work to observe
     */
    private void registerListeners(UnitOfWork<?> unitOfWork) {
        unitOfWork.onPrepareCommit(
                u -> transitions.add(new PhaseTransition(u.getMessage(), UnitOfWork.Phase.PREPARE_COMMIT)));
        unitOfWork.onCommit(
                u -> transitions.add(new PhaseTransition(u.getMessage(), UnitOfWork.Phase.COMMIT)));
        unitOfWork.afterCommit(
                u -> transitions.add(new PhaseTransition(u.getMessage(), UnitOfWork.Phase.AFTER_COMMIT)));
        unitOfWork.onRollback(
                u -> transitions.add(new PhaseTransition(u.getMessage(), UnitOfWork.Phase.ROLLBACK)));
        unitOfWork.onCleanup(
                u -> transitions.add(new PhaseTransition(u.getMessage(), UnitOfWork.Phase.CLEANUP)));
    }

    /**
     * Wraps the given payload in a {@link GenericMessage}.
     *
     * @param payload the payload of the message
     * @return a new message carrying the payload
     */
    private static Message<?> toMessage(Object payload) {
        return new GenericMessage<>(payload);
    }

    /**
     * Builds the result the test tasks return for a message.
     *
     * @param message the message being handled
     * @return the string {@code "Result for: "} followed by the message's payload
     */
    public static Object resultFor(Message<?> message) {
        return "Result for: " + message.getPayload();
    }

    /**
     * Asserts that the recorded transitions start with the expected sequence: for each phase in order, one
     * transition per message. Messages are expected in reverse order for phases that report
     * {@code isReverseCallbackOrder()}. Transitions recorded beyond the expected ones are not checked.
     *
     * @param phases   the phases expected, in order
     * @param messages the messages expected to pass through each phase
     */
    private void validatePhaseTransitions(List<UnitOfWork.Phase> phases, List<Message<?>> messages) {
        Iterator<PhaseTransition> recorded = transitions.iterator();
        for (UnitOfWork.Phase phase : phases) {
            Iterator<Message<?>> expectedOrder = phase.isReverseCallbackOrder()
                    ? new LinkedList<Message<?>>(messages).descendingIterator()
                    : messages.iterator();
            while (expectedOrder.hasNext()) {
                PhaseTransition expected = new PhaseTransition(expectedOrder.next(), phase);
                Assertions.assertTrue(recorded.hasNext());
                Assertions.assertEquals(expected, recorded.next());
            }
        }
    }

    /**
     * Asserts that two execution-result maps describe the same outcome. The key sets must be equal. The
     * payloads of non-exceptional results and the exceptions of exceptional results must match in count and
     * every actual value must occur among the expected ones; for metadata only the containment is checked.
     * Values are compared irrespective of which message they belong to.
     *
     * @param expected the expected results per message
     * @param actual   the results reported by the unit of work
     */
    private void assertExecutionResults(Map<Message<?>, ExecutionResult> expected,
                                        Map<Message<?>, ExecutionResult> actual) {
        Assertions.assertEquals(expected.keySet(), actual.keySet());

        List<ResultMessage<?>> expectedMessages = resultMessagesOf(expected);
        List<ResultMessage<?>> actualMessages = resultMessagesOf(actual);

        List<Object> expectedPayloads = successfulPayloadsOf(expectedMessages);
        List<Object> actualPayloads = successfulPayloadsOf(actualMessages);
        List<Throwable> expectedExceptions = exceptionsOf(expectedMessages);
        List<Throwable> actualExceptions = exceptionsOf(actualMessages);
        List<?> expectedMetaData = expectedMessages.stream()
                                                   .map(Message::getMetaData)
                                                   .collect(Collectors.toList());
        List<?> actualMetaData = actualMessages.stream()
                                               .map(Message::getMetaData)
                                               .collect(Collectors.toList());

        Assertions.assertEquals(expectedPayloads.size(), actualPayloads.size());
        Assertions.assertTrue(expectedPayloads.containsAll(actualPayloads));
        Assertions.assertEquals(expectedExceptions.size(), actualExceptions.size());
        Assertions.assertTrue(expectedExceptions.containsAll(actualExceptions));
        Assertions.assertTrue(expectedMetaData.containsAll(actualMetaData));
    }

    /** Extracts the result message of every execution result in the given map. */
    private static List<ResultMessage<?>> resultMessagesOf(Map<Message<?>, ExecutionResult> results) {
        List<ResultMessage<?>> resultMessages = new ArrayList<>();
        for (ExecutionResult executionResult : results.values()) {
            resultMessages.add(executionResult.getResult());
        }
        return resultMessages;
    }

    /** Collects the payloads of all result messages that are not exceptional. */
    private static List<Object> successfulPayloadsOf(List<ResultMessage<?>> resultMessages) {
        List<Object> payloads = new ArrayList<>();
        for (ResultMessage<?> resultMessage : resultMessages) {
            if (!resultMessage.isExceptional()) {
                payloads.add(resultMessage.getPayload());
            }
        }
        return payloads;
    }

    /** Collects the exceptions of all result messages that are exceptional. */
    private static List<Throwable> exceptionsOf(List<ResultMessage<?>> resultMessages) {
        List<Throwable> exceptions = new ArrayList<>();
        for (ResultMessage<?> resultMessage : resultMessages) {
            if (resultMessage.isExceptional()) {
                exceptions.add(resultMessage.exceptionResult());
            }
        }
        return exceptions;
    }

    /**
     * Value object pairing a message with the phase whose callback was invoked for it. Two transitions are
     * equal when both the phase and the message are equal.
     */
    private static class PhaseTransition {

        private final UnitOfWork.Phase phase;
        private final Message<?> message;

        public PhaseTransition(Message<?> message, UnitOfWork.Phase phase) {
            this.message = message;
            this.phase = phase;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            PhaseTransition that = (PhaseTransition) o;
            return phase == that.phase && Objects.equals(message, that.message);
        }

        @Override
        public int hashCode() {
            return Objects.hash(phase, message);
        }

        @Override
        public String toString() {
            return phase + " -> " + message.getPayload();
        }
    }
}

