/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2013, Red Hat, Inc., and individual contributors
 * as indicated by the @author tags. See the copyright.txt file in the
 * distribution for a full listing of individual contributors.
 *
 * This is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1 of
 * the License, or (at your option) any later version.
 *
 * This software is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this software; if not, write to the Free
 * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
 * 02110-1301 USA, or see the FSF site: http://www.fsf.org.
 */

package org.wildfly.extension.undertow.security;

import io.undertow.predicate.Predicates;
import io.undertow.server.HandlerWrapper;
import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange;
import io.undertow.server.handlers.PredicateHandler;
import io.undertow.servlet.handlers.ServletChain;
import io.undertow.servlet.handlers.ServletRequestContext;
import io.undertow.servlet.predicate.DispatcherTypePredicate;
import org.jboss.metadata.javaee.jboss.RunAsIdentityMetaData;
import org.jboss.security.RunAs;
import org.jboss.security.RunAsIdentity;
import org.jboss.security.SecurityContext;
import org.wildfly.extension.undertow.UndertowLogger;
import org.wildfly.security.manager.WildFlySecurityManager;

import javax.security.jacc.PolicyContext;
import java.security.PrivilegedAction;
import java.util.Map;

import static org.wildfly.extension.undertow.security.SecurityActions.setRunAsIdentity;

/**
 * Undertow {@link HttpHandler} that, around the invocation of the next handler, applies the run-as
 * identity configured for the current servlet (looked up by servlet name) to the exchange's
 * {@link SecurityContext} and installs a JACC {@link PolicyContext} context ID, restoring the
 * previous state once the next handler returns or throws.
 */
public class SecurityContextAssociationHandler implements HttpHandler {

    /** Run-as metadata keyed by servlet name. */
    private final Map<String, RunAsIdentityMetaData> runAsIdentityMetaDataMap;
    /** JACC context ID installed for the duration of each handled request. */
    private final String contextId;
    /** The handler invoked while the run-as identity and context ID are in place. */
    private final HttpHandler next;

    /** Pre-built action that installs {@link #contextId}; reused for every request. */
    private final PrivilegedAction<String> setContextIdAction;

    /**
     * @param runAsIdentityMetaDataMap run-as metadata keyed by servlet name
     * @param contextId                JACC context ID to set while each request is processed
     * @param next                     the handler to delegate to
     */
    public SecurityContextAssociationHandler(final Map<String, RunAsIdentityMetaData> runAsIdentityMetaDataMap, final String contextId, final HttpHandler next) {
        this.runAsIdentityMetaDataMap = runAsIdentityMetaDataMap;
        this.contextId = contextId;
        this.next = next;
        this.setContextIdAction = new SetContextIDAction(contextId);
    }

    /**
     * Applies the servlet's run-as identity and the JACC context ID, delegates to the next handler,
     * then restores the previous values.
     */
    @Override
    public void handleRequest(final HttpServerExchange exchange) throws Exception {
        final SecurityContext sc = exchange.getAttachment(UndertowSecurityAttachments.SECURITY_CONTEXT_ATTACHMENT);
        String previousContextID = null;
        // Records whether the JACC context ID was actually changed. previousContextID alone cannot
        // tell us this because the prior context ID may legitimately be null; without the flag, an
        // exception thrown before the ID is set would make the finally block reset it to null.
        boolean contextIdChanged = false;
        RunAsIdentityMetaData identity = null;
        RunAs old = null;
        try {
            final ServletChain servlet = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY).getCurrentServlet();
            final String servletName = servlet.getManagedServlet().getServletInfo().getName();
            identity = runAsIdentityMetaDataMap.get(servletName);
            RunAsIdentity runAsIdentity = null;
            if (identity != null) {
                UndertowLogger.ROOT_LOGGER.tracef("%s, runAs: %s", servletName, identity);
                runAsIdentity = new RunAsIdentity(identity.getRoleName(), identity.getPrincipalName(), identity.getRunAsRoles());
            }
            // Invoked even when no run-as mapping exists (runAsIdentity is then null).
            old = setRunAsIdentity(runAsIdentity, sc);

            // set JACC contextID
            previousContextID = setContextID(setContextIdAction);
            contextIdChanged = true;

            // Perform the request
            next.handleRequest(exchange);
        } finally {
            try {
                // NOTE(review): the previous run-as value is only re-applied when a mapping was found,
                // although setRunAsIdentity is also called for servlets without one; confirm intentional.
                if (identity != null) {
                    setRunAsIdentity(old, sc);
                }
            } finally {
                // Restore the JACC context ID even if restoring the run-as identity failed, but only
                // if this handler actually changed it.
                if (contextIdChanged) {
                    setContextID(new SetContextIDAction(previousContextID));
                }
            }
        }
    }

    /**
     * Privileged action that sets the {@link PolicyContext} context ID to a fixed value and returns
     * the context ID that was current before the change.
     */
    private static class SetContextIDAction implements PrivilegedAction<String> {

        private final String contextID;

        SetContextIDAction(String contextID) {
            this.contextID = contextID;
        }

        @Override
        public String run() {
            String currentContextID = PolicyContext.getContextID();
            PolicyContext.setContextID(this.contextID);
            return currentContextID;
        }
    }

    /**
     * Runs the given context-ID action, via {@link WildFlySecurityManager#doUnchecked} when a
     * security manager check is active, or directly otherwise.
     *
     * @param action the action to run
     * @return the value returned by the action (the previous context ID for {@link SetContextIDAction})
     */
    private static String setContextID(PrivilegedAction<String> action) {
        if (WildFlySecurityManager.isChecking()) {
            return WildFlySecurityManager.doUnchecked(action);
        } else {
            return action.run();
        }
    }

    /**
     * Creates a {@link HandlerWrapper} that applies a {@link SecurityContextAssociationHandler} only
     * for REQUEST and ASYNC dispatches; all other dispatch types go straight to the wrapped handler.
     *
     * @param runAsIdentityMetaDataMap run-as metadata keyed by servlet name
     * @param contextId                JACC context ID to set while each request is processed
     * @return the wrapper
     */
    public static HandlerWrapper wrapper(final Map<String, RunAsIdentityMetaData> runAsIdentityMetaDataMap, final String contextId) {
        return new HandlerWrapper() {
            @Override
            public HttpHandler wrap(final HttpHandler handler) {
                //we only run this on REQUEST or ASYNC invocations
                return new PredicateHandler(Predicates.or(DispatcherTypePredicate.REQUEST, DispatcherTypePredicate.ASYNC), new SecurityContextAssociationHandler(runAsIdentityMetaDataMap, contextId, handler), handler);
            }
        };
    }
}
