/*
 * Copyright (c) MuleSoft, Inc.  All rights reserved.  http://www.mulesoft.com
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.tests.internal;

import static org.mule.runtime.api.i18n.I18nMessageFactory.createStaticMessage;
import static org.mule.runtime.api.profiling.type.RuntimeProfilingEventTypes.TX_COMMIT;
import static org.mule.runtime.api.profiling.type.RuntimeProfilingEventTypes.TX_CONTINUE;
import static org.mule.runtime.api.profiling.type.RuntimeProfilingEventTypes.TX_ROLLBACK;
import static org.mule.runtime.api.profiling.type.RuntimeProfilingEventTypes.TX_START;
import static org.mule.runtime.core.api.transaction.TransactionCoordination.isTransactionActive;
import static org.mule.runtime.extension.api.annotation.param.MediaType.ANY;
import static java.lang.String.format;
import static java.lang.Thread.currentThread;

import org.mule.runtime.api.profiling.ProfilingDataConsumer;
import org.mule.runtime.api.profiling.type.ProfilingEventType;
import org.mule.runtime.api.profiling.type.context.TransactionProfilingEventContext;
import org.mule.runtime.core.privileged.profiling.PrivilegedProfilingService;
import org.mule.runtime.extension.api.annotation.param.MediaType;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import javax.inject.Inject;


/**
 * This class is a container for operations, every public method in this class will be taken as an extension operation.
 */
public class TransactionProfilingConsumptionOperations {

    @Inject
    private PrivilegedProfilingService service;
    private static boolean alreadyAddedConsumers = false;

    // We keep the states in a map, because in case the transactions are created in a parallel scope
    // (e.g. ParallelForeach), the entries would be messed. Since a tx will always run in the same thread,
    // in such thread the states will be sorted by time of appearance. If there is no parallel execution,
    // the map will have just one entry.
    private static final Map<String, List<String>> obtainedStates = new HashMap<>();
    private static final ArrayList<Boolean> obtainedActive = new ArrayList<>();
    private static final ArrayList<String> expectedStates = new ArrayList<>();
    private static final ArrayList<Boolean> expectedActive = new ArrayList<>();
    private static final ArrayList<String> expectedAt = new ArrayList<>();


    /**
     * Clears the consumed {@link ProfilingEventType} for transactions, and registers the {@link ProfilingDataConsumer}
     * to get Transaction Profiling Events.
     */
    @MediaType(value = ANY, strict = false)
    public void startConsumingTransactions() {
        clear();
        if (alreadyAddedConsumers) {
            return;
        }
        alreadyAddedConsumers = true;
        service.registerProfilingDataConsumer(new ProfilingDataConsumer<TransactionProfilingEventContext>() {
            @Override
            public void onProfilingEvent(ProfilingEventType<TransactionProfilingEventContext> profilingEventType, TransactionProfilingEventContext profilingEventContext) {
                synchronized (obtainedStates) {
                    if (!obtainedStates.containsKey(currentThread().toString())) {
                        obtainedStates.put(currentThread().toString(), new LinkedList<>());
                    }
                    obtainedStates.get(currentThread().toString()).add(profilingEventType.toString());
                }
            }

            @Override
            public Set<ProfilingEventType<TransactionProfilingEventContext>> getProfilingEventTypes() {
                Set<ProfilingEventType<TransactionProfilingEventContext>> events = new HashSet<>();
                events.add(TX_START);
                events.add(TX_COMMIT);
                events.add(TX_CONTINUE);
                events.add(TX_ROLLBACK);
                return events;
            }

            @Override
            public Predicate<TransactionProfilingEventContext> getEventContextFilter() {
                return tx -> true;
            }
        });
    }

    /**
     * Adds an expectation for the transaction events to come
     * @param active The current transaction state (at the moment of the execution of this operation)
     * @param state The state of the last received event (commit, continue, rollback, start)
     */
    @MediaType(value = ANY, strict = false)
    public void addTransactionExpectation(boolean active, String state) {
        synchronized (expectedStates) {
            expectedActive.add(active);
            expectedStates.add(state);
            obtainedActive.add(isTransactionActive());
            expectedAt.add(currentThread().toString());
        }
    }


    private static void clear() {
        obtainedStates.clear();
        obtainedActive.clear();
        expectedStates.clear();
        expectedActive.clear();
        expectedAt.clear();
    }

    private String formattedState(boolean active, String state) {
        return format("(active: %b, state: %s)", active, state);
    }

    /**
     * Performs assertions between the expected states of transactions, and the actual received ones.
     * The expected and obtained states must be the same amount, and they should be coincide in "active" and "status"
     * in the order it was defined.
     * To avoid problems with parallel executions (e.g. Try within ParallelForeach), we check this order by incoming
     * {@link Thread}.
     */
    @MediaType(value = ANY, strict = false)
    public void assertTransactions() {
        int totalStates = obtainedStates.keySet().stream().map(thread -> obtainedStates.get(thread).size())
                .reduce(0, Integer::sum);
        if (expectedStates.size() != totalStates) {
            throw new AssertionError(createStaticMessage(format("Expected %d results, obtained %d",
                    expectedStates.size(), totalStates)));
        }
        for (int i = 0; i < expectedStates.size(); i++) {
            String runningThread = expectedAt.get(i);
            List<String> eventsInThread = obtainedStates.get(runningThread);
            String nextState = eventsInThread.remove(0);
            if (expectedActive.get(i) != obtainedActive.get(i) || !expectedStates.get(i).equals(nextState)) {
                throw new AssertionError(createStaticMessage(
                        format("Error in transaction expectations (number %d). Expected %s, obtained %s", i + 1,
                                formattedState(expectedActive.get(i), expectedStates.get(i)),
                                formattedState(obtainedActive.get(i), nextState)))
                );
            }
        }
    }

}
