// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

package com.microsoft.azure.javamsalruntime;

import com.sun.jna.Native;
import com.sun.jna.Platform;
import com.sun.jna.Pointer;
import com.sun.jna.WString;
import com.sun.jna.platform.win32.Kernel32;
import com.sun.jna.platform.win32.User32;
import java.net.URL;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A set of APIs intended to map more-or-less directly to MSALRuntime's APIs, particularly those
 * related to signing in/signing out users, acquiring tokens, and getting account information.
 *
 * <p>The sign-in, token and account APIs are non-static, therefore an MsalRuntimeInterop instance
 * must be created before those calls to MSALRuntime are made. The shutdown and logging
 * configuration APIs are static.
 *
 * <p>This class has a static initializer that runs once, when the class is first initialized (for
 * example when the first MsalRuntimeInterop instance is created, or a static member is first
 * accessed). It loads the MSALRuntime dll and registers a JVM shutdown hook that calls
 * MSALRuntime's shutdown API, ensuring that any required initialization steps are performed before
 * the first call to any MSALRuntime API.
 */
public class MsalRuntimeInterop {
    private static final Logger LOG = LoggerFactory.getLogger(MsalRuntimeInterop.class);
    private static final String MSALRUNTIME_DLL_PATH = "external-dlls/MSALRuntimeDlls/";

    // Win32 GetAncestor flag GA_ROOTOWNER: walk parent and owner windows up to the root owner
    private static final int GA_ROOTOWNER = 3;

    // This static instance of MsalRuntimeLibrary helps to ensure that startup is called on a
    // per-process basis, and to simplify the many calls to MsalRuntimeLibrary's API
    public static final MsalRuntimeLibrary MSALRUNTIME_LIBRARY;

    // Used to simplify the many calls made to ErrorHandler during error checking and exception
    // handling
    public static final ErrorHelper ERROR_HELPER;

    // Handle for the registered MSALRuntime log callback; null while logging is disabled.
    // Only read or written while holding the MsalRuntimeInterop.class monitor.
    private static LogCallbackHandle logCallbackHandle;

    // Kept in a static field so the same callback object stays strongly reachable for as long as
    // it may be registered with native code (presumably a JNA callback — confirm in Callbacks)
    private static final Callbacks.LogCallback logCallback = new Callbacks.LogCallback();

    static {
        LOG.info("Setting up MSALRuntime.");
        MSALRUNTIME_LIBRARY = loadMsalRuntimeLibrary();
        ERROR_HELPER = new ErrorHelper();

        // Add shutdown hook to call the MSALRuntime shutdown API when the JVM process exits
        Runtime.getRuntime().addShutdownHook(new Thread(MsalRuntimeInterop::shutdownMsalRuntime));
    }

    /**
     * Calls MSALRuntime's startup API.
     *
     * <p>The returned error handle is passed to {@code ERROR_HELPER.checkMsalRuntimeError}.
     */
    public void startupMsalRuntime() {
        ERROR_HELPER.checkMsalRuntimeError(MSALRUNTIME_LIBRARY.MSALRUNTIME_Startup());
        LOG.info("MSALRuntime startup API called successfully.");
    }

    /**
     * Calls MSALRuntime's shutdown API.
     *
     * <p>Before doing so, every pending {@link MsalRuntimeFuture} is completed exceptionally, its
     * native handle is released, and it is removed from the tracking map. Any registered log
     * callback handle is released as well, to avoid memory leaks.
     */
    public static void shutdownMsalRuntime() {
        LOG.info("Shutting down MSALRuntime.");

        // Remove through the iterator itself so entries can be dropped while iterating, whatever
        // map implementation backs msalRuntimeFutures
        Iterator<Entry<Integer, MsalRuntimeFuture>> iterator =
                MsalRuntimeFuture.msalRuntimeFutures.entrySet().iterator();
        while (iterator.hasNext()) {
            MsalRuntimeFuture future = iterator.next().getValue();
            future.completeExceptionally(new MsalInteropException(
                    "MSALRuntime shutdown API called before operation could complete",
                    "msalruntime_shutdown"));
            future.handle.release();
            iterator.remove();
        }

        // Same monitor as enableLogging, since this usually runs on the shutdown-hook thread
        synchronized (MsalRuntimeInterop.class) {
            if (logCallbackHandle != null) {
                logCallbackHandle.release();
                // Clear the handle so a later enableLogging(false) cannot release it twice
                logCallbackHandle = null;
            }
        }

        // MSALRUNTIME_Shutdown doesn't return an error handle, nothing for us to check
        MSALRUNTIME_LIBRARY.MSALRUNTIME_Shutdown();
        LOG.info("MSALRuntime shutdown API called successfully.");
    }

    /**
     * Retrieves any cached account information for a given account ID.
     *
     * @param accountId     the ID of the account we will retrieve data for
     * @param correlationId unique ID used to identify a certain request throughout various
     *         telemetry and logs
     * @return an {@link MsalRuntimeFuture} tracking the asynchronous operation, which can be
     *         treated as a CompletableFuture
     */
    public MsalRuntimeFuture readAccountById(String accountId, String correlationId) {
        MsalRuntimeFuture msalRuntimeFuture =
                new MsalRuntimeFuture(new Callbacks.ReadAccountResultCallback());

        ERROR_HELPER.checkMsalRuntimeError(MSALRUNTIME_LIBRARY.MSALRUNTIME_ReadAccountByIdAsync(
                new WString(accountId), new WString(correlationId),
                (Callbacks.ReadAccountResultCallback)msalRuntimeFuture.callback,
                msalRuntimeFuture.msalRuntimeFuturesKey, msalRuntimeFuture.handle));

        return msalRuntimeFuture;
    }

    /**
     * Calls MSALRuntime's SignIn API, which will attempt a silent sign in and fall back to
     * interactive if needed.
     *
     * <p>This API essentially combines the behavior of signInSilently and signInInteractively.
     *
     * @param windowHandle   the parent window handle that will be used to coordinate UI elements
     *         shown to the user; 0 means "use the console window's root owner"
     * @param authParameters a number of parameters to be used in this request
     * @param correlationId  unique ID used to identify a certain request throughout various
     *         telemetry and logs
     * @param loginHint      a login hint (such as a username) that may be shown to the user; null
     *         is sent as an empty string
     * @return an {@link MsalRuntimeFuture} tracking the asynchronous operation, which can be
     *         treated as a CompletableFuture
     */
    public MsalRuntimeFuture signIn(
            long windowHandle, AuthParameters authParameters, String correlationId, String loginHint) {
        // Resolve the window handle first so a failure here does not leave a future behind
        long resolvedWindowHandle = checkWindowHandle(windowHandle);

        MsalRuntimeFuture msalRuntimeFuture =
                new MsalRuntimeFuture(new Callbacks.AuthResultCallback());

        ERROR_HELPER.checkMsalRuntimeError(MSALRUNTIME_LIBRARY.MSALRUNTIME_SignInAsync(
                resolvedWindowHandle, authParameters.getHandle().value(), new WString(correlationId),
                new WString(loginHint == null ? "" : loginHint),
                (Callbacks.AuthResultCallback)msalRuntimeFuture.callback,
                msalRuntimeFuture.msalRuntimeFuturesKey, msalRuntimeFuture.handle));

        return msalRuntimeFuture;
    }

    /**
     * Calls MSALRuntime's SignInSilently API to attempt a sign in without showing a UI to the
     * user.
     *
     * @param authParameters a number of parameters to be used in this request
     * @param correlationId  unique ID used to identify a certain request throughout various
     *         telemetry and logs
     * @return an {@link MsalRuntimeFuture} tracking the asynchronous operation, which can be
     *         treated as a CompletableFuture
     */
    public MsalRuntimeFuture signInSilently(AuthParameters authParameters, String correlationId) {
        MsalRuntimeFuture msalRuntimeFuture =
                new MsalRuntimeFuture(new Callbacks.AuthResultCallback());

        ERROR_HELPER.checkMsalRuntimeError(MSALRUNTIME_LIBRARY.MSALRUNTIME_SignInSilentlyAsync(
                authParameters.getHandle().value(), new WString(correlationId),
                (Callbacks.AuthResultCallback)msalRuntimeFuture.callback,
                msalRuntimeFuture.msalRuntimeFuturesKey, msalRuntimeFuture.handle));

        return msalRuntimeFuture;
    }

    /**
     * Calls MSALRuntime's SignInInteractively API to attempt a sign in by showing a UI to the
     * user.
     *
     * @param windowHandle   the parent window handle that will be used to coordinate UI elements
     *         shown to the user; 0 means "use the console window's root owner"
     * @param authParameters a number of parameters to be used in this request
     * @param correlationId  unique ID used to identify a certain request throughout various
     *         telemetry and logs
     * @param loginHint      a login hint (such as a username) that may be shown to the user; null
     *         is sent as an empty string
     * @return an {@link MsalRuntimeFuture} tracking the asynchronous operation, which can be
     *         treated as a CompletableFuture
     */
    public MsalRuntimeFuture signInInteractively(
            long windowHandle, AuthParameters authParameters, String correlationId, String loginHint) {
        // Resolve the window handle first so a failure here does not leave a future behind
        long resolvedWindowHandle = checkWindowHandle(windowHandle);

        MsalRuntimeFuture msalRuntimeFuture =
                new MsalRuntimeFuture(new Callbacks.AuthResultCallback());

        ERROR_HELPER.checkMsalRuntimeError(MSALRUNTIME_LIBRARY.MSALRUNTIME_SignInInteractivelyAsync(
                resolvedWindowHandle, authParameters.getHandle().value(), new WString(correlationId),
                new WString(loginHint == null ? "" : loginHint),
                (Callbacks.AuthResultCallback)msalRuntimeFuture.callback,
                msalRuntimeFuture.msalRuntimeFuturesKey, msalRuntimeFuture.handle));

        return msalRuntimeFuture;
    }

    /**
     * Calls MSALRuntime's AcquireTokenSilently API to retrieve tokens for a given account,
     * without showing a UI to the user.
     *
     * @param authParameters a number of parameters to be used in this request
     * @param correlationId  unique ID used to identify a certain request throughout various
     *         telemetry and logs
     * @param account        an Account whose handle must already be populated (for example by a
     *         successful sign in or account read)
     * @return an {@link MsalRuntimeFuture} tracking the asynchronous operation, which can be
     *         treated as a CompletableFuture
     * @throws MsalInteropException if the account or its handle is null
     */
    public MsalRuntimeFuture acquireTokenSilently(
            AuthParameters authParameters, String correlationId, Account account) {
        requireAccountHandle(account,
                "Account handle is null, sign in or account discovery failed. Cannot retrieve tokens.");

        MsalRuntimeFuture msalRuntimeFuture =
                new MsalRuntimeFuture(new Callbacks.AuthResultCallback());

        ERROR_HELPER.checkMsalRuntimeError(MSALRUNTIME_LIBRARY.MSALRUNTIME_AcquireTokenSilentlyAsync(
                authParameters.getHandle().value(), new WString(correlationId),
                account.getHandle().value(), (Callbacks.AuthResultCallback)msalRuntimeFuture.callback,
                msalRuntimeFuture.msalRuntimeFuturesKey, msalRuntimeFuture.handle));

        return msalRuntimeFuture;
    }

    /**
     * Calls MSALRuntime's AcquireTokenInteractively API to retrieve tokens for a given account,
     * by showing a UI to the user.
     *
     * @param windowHandle   the parent window handle that will be used to coordinate UI elements
     *         shown to the user; 0 means "use the console window's root owner"
     * @param authParameters a number of parameters to be used in this request
     * @param correlationId  unique ID used to identify a certain request throughout various
     *         telemetry and logs
     * @param account        an Account whose handle must already be populated (for example by a
     *         successful sign in or account read)
     * @return an {@link MsalRuntimeFuture} tracking the asynchronous operation, which can be
     *         treated as a CompletableFuture
     * @throws MsalInteropException if the account or its handle is null, or if no window handle was
     *         given and the console window's handle could not be resolved
     */
    public MsalRuntimeFuture acquireTokenInteractively(
            long windowHandle, AuthParameters authParameters, String correlationId, Account account) {
        requireAccountHandle(account,
                "Account handle is null, sign in or account discovery failed. Cannot retrieve tokens.");

        // Resolve the window handle first so a failure here does not leave a future behind
        long resolvedWindowHandle = checkWindowHandle(windowHandle);

        MsalRuntimeFuture msalRuntimeFuture =
                new MsalRuntimeFuture(new Callbacks.AuthResultCallback());

        ERROR_HELPER.checkMsalRuntimeError(
                MSALRUNTIME_LIBRARY.MSALRUNTIME_AcquireTokenInteractivelyAsync(
                        resolvedWindowHandle, authParameters.getHandle().value(),
                        new WString(correlationId), account.getHandle().value(),
                        (Callbacks.AuthResultCallback)msalRuntimeFuture.callback,
                        msalRuntimeFuture.msalRuntimeFuturesKey, msalRuntimeFuture.handle));

        return msalRuntimeFuture;
    }

    /**
     * Calls MSALRuntime's SignOut API, which will delete cached tokens for a given account and
     * require this account to perform a new sign in.
     *
     * @param clientId      client ID used in the call that created the account information
     * @param correlationId unique ID used to identify a certain request throughout various
     *         telemetry and logs
     * @param account       an Account object, which must be populated with a valid handle
     * @return an {@link MsalRuntimeFuture} tracking the asynchronous operation, which can be
     *         treated as a CompletableFuture
     * @throws MsalInteropException if the account or its handle is null
     */
    public MsalRuntimeFuture signOutSilently(String clientId, String correlationId, Account account) {
        requireAccountHandle(account, "Account handle is null, cannot sign out.");

        MsalRuntimeFuture msalRuntimeFuture =
                new MsalRuntimeFuture(new Callbacks.SignOutResultCallback());

        ERROR_HELPER.checkMsalRuntimeError(MSALRUNTIME_LIBRARY.MSALRUNTIME_SignOutSilentlyAsync(
                new WString(clientId), new WString(correlationId), account.getHandle().value(),
                (Callbacks.SignOutResultCallback)msalRuntimeFuture.callback,
                msalRuntimeFuture.msalRuntimeFuturesKey, msalRuntimeFuture.handle));

        return msalRuntimeFuture;
    }

    /**
     * Sets up the necessary callbacks and handles to integrate MSALRuntime logs into MSAL Java's
     * logging framework, allowing detailed logs from MSALRuntime's native code to appear in
     * javamsalruntime's logs.
     *
     * <p>By default, detailed logs from MSALRuntime will not be shown, though some MSALRuntime
     * error messages may appear as part of a thrown exception. Calling this with the current state
     * (enabling twice, or disabling when not enabled) is a no-op.
     *
     * @param enableLogging true enables MSALRuntime logging, false disables it
     */
    public static synchronized void enableLogging(boolean enableLogging) {
        if (enableLogging) {
            // Avoid calling logging APIs if logging is already enabled
            if (logCallbackHandle == null) {
                LogCallbackHandle handle = new LogCallbackHandle();

                ERROR_HELPER.checkMsalRuntimeError(MSALRUNTIME_LIBRARY.MSALRUNTIME_RegisterLogCallback(
                        logCallback, null, handle));

                // Assign the handle to the global variable only after a successful call to
                // RegisterLogCallback
                logCallbackHandle = handle;
            }
        } else if (logCallbackHandle != null) {
            // According to comments in MSALRuntime's MSALRuntimeLogging.h, releasing the log
            // callback handle will de-register it
            logCallbackHandle.release();
            logCallbackHandle = null;
        }
    }

    /**
     * Allows PII data to appear in MSALRuntime error messages and logs.
     *
     * <p>By default, PII logging is disabled.
     *
     * @param enablePIILogging true enables PII logging, false disables it
     */
    public static synchronized void enableLoggingPii(boolean enablePIILogging) {
        // The MSALRUNTIME_SetIsPiiEnabled API enables PII logging if sent a value of '1', otherwise
        // it disables PII logging
        MSALRUNTIME_LIBRARY.MSALRUNTIME_SetIsPiiEnabled(enablePIILogging ? 1 : 0);
    }

    /**
     * Returns the given window handle, or resolves a fallback when none was provided.
     *
     * <p>When {@code windowHandle} is 0, the handle of the root owner (GetAncestor with
     * GA_ROOTOWNER) of the current console window is used instead.
     *
     * @param windowHandle a caller-supplied native window handle, or 0 if none was provided
     * @return a non-zero native window handle
     * @throws MsalInteropException if no handle was provided and no console window handle could be
     *         resolved
     */
    long checkWindowHandle(long windowHandle) {
        if (windowHandle != 0) {
            return windowHandle;
        }

        // Either native call may return NULL (e.g. no attached console); check explicitly rather
        // than relying on a NullPointerException
        long resolvedHandle = Optional.ofNullable(Kernel32.INSTANCE.GetConsoleWindow())
                .map(consoleWindow -> User32.INSTANCE.GetAncestor(consoleWindow, GA_ROOTOWNER))
                .map(rootOwner -> Pointer.nativeValue(rootOwner.getPointer()))
                .orElse(0L);

        if (resolvedHandle == 0) {
            throw new MsalInteropException(
                    "Window handle not provided, and could not retrieve console's window handle. Window handles must be provided if the application is not running in a Windows terminal.",
                    "msalruntime_client_error");
        }
        return resolvedHandle;
    }

    /**
     * Throws if the given account has no native handle.
     *
     * @param account the account to validate; a null account is treated like a null handle
     * @param message the error message to use if validation fails
     * @throws MsalInteropException with code "msalruntime_account_error" when validation fails
     */
    private static void requireAccountHandle(Account account, String message) {
        if (account == null || account.getHandle() == null) {
            throw new MsalInteropException(message, "msalruntime_account_error");
        }
    }

    /**
     * Performs any checks necessary to identify the system architecture and choose the correct
     * MSALRuntime dll, and uses JNA to associate that dll with our MsalRuntimeLibrary interface.
     *
     * <p>64-bit ARM Windows uses msalruntime_arm64.dll, other 64-bit Windows uses msalruntime.dll,
     * and 32-bit Windows uses msalruntime_x86.dll.
     *
     * @return an MsalRuntimeLibrary instance that can be used to call into a C++ dll from Java
     * @throws MsalInteropException if the platform is not Windows, the dll directory is not on the
     *         classpath, or the dll could not be loaded
     */
    static MsalRuntimeLibrary loadMsalRuntimeLibrary() {
        if (!Platform.isWindows()) {
            throw new MsalInteropException(
                    "Could not detect platform, or platform was not supported.",
                    "msalruntime_initialization_error");
        }

        String dllName;
        if (!Platform.is64Bit()) {
            dllName = "msalruntime_x86.dll";
        } else if (Platform.isARM()) {
            dllName = "msalruntime_arm64.dll";
        } else {
            dllName = "msalruntime.dll";
        }

        // getResource returns null when the directory is missing; report that explicitly instead
        // of failing class initialization with a NullPointerException
        URL dllDirectory = MsalRuntimeInterop.class.getClassLoader().getResource(MSALRUNTIME_DLL_PATH);
        if (dllDirectory == null) {
            LOG.error("MSALRuntime dll directory {} was not found on the classpath.",
                    MSALRUNTIME_DLL_PATH);
            throw new MsalInteropException(
                    "Could not find or load MSALRuntime dll.", "msalruntime_initialization_error");
        }

        try {
            System.setProperty("jna.library.path", dllDirectory.toString());
            return Native.load(MSALRUNTIME_DLL_PATH + dllName, MsalRuntimeLibrary.class);
        } catch (UnsatisfiedLinkError e) {
            // MsalInteropException does not carry a cause here, so log the underlying error
            LOG.error("Failed to load MSALRuntime dll {}.", dllName, e);
            throw new MsalInteropException(
                    "Could not find or load MSALRuntime dll.", "msalruntime_initialization_error");
        }
    }
}
