package org.mule.soap.internal.rm;

import static org.mule.soap.internal.client.AbstractSoapCxfClient.MULE_ATTACHMENTS_KEY;
import static org.mule.soap.internal.client.AbstractSoapCxfClient.MULE_HEADERS_KEY;
import static org.mule.soap.internal.client.AbstractSoapCxfClient.MESSAGE_DISPATCHER;
import static org.mule.soap.internal.client.AbstractSoapCxfClient.MULE_SOAP_OPERATION_STYLE;
import static org.mule.soap.internal.client.AbstractSoapCxfClient.MULE_TRANSPORT_HEADERS_KEY;
import static org.mule.soap.internal.client.AbstractSoapCxfClient.MULE_WSC_ADDRESS;
import static org.mule.soap.internal.rm.RMUtils.MULE_ADDRESSING_ENABLE;
import static org.mule.soap.internal.rm.RMUtils.MULE_RM_ENABLE;

import static java.lang.Boolean.TRUE;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static org.apache.cxf.interceptor.StaxOutInterceptor.FORCE_START_DOCUMENT;
import static org.apache.cxf.message.Message.ENCODING;

import org.mule.soap.api.client.BadRequestException;
import org.mule.soap.api.rm.CreateSequenceRequest;
import org.mule.soap.api.rm.TerminateSequenceRequest;
import org.mule.soap.api.transport.TransportDispatcher;
import org.mule.wsdl.parser.model.operation.OperationType;

import java.util.HashMap;
import java.util.Map;

import javax.xml.datatype.DatatypeConfigurationException;
import javax.xml.datatype.DatatypeFactory;
import javax.xml.datatype.Duration;
import javax.xml.namespace.QName;

import org.apache.cxf.Bus;
import org.apache.cxf.endpoint.Client;
import org.apache.cxf.endpoint.ClientImpl;
import org.apache.cxf.endpoint.ConduitSelector;
import org.apache.cxf.endpoint.DeferredConduitSelector;
import org.apache.cxf.endpoint.Endpoint;
import org.apache.cxf.message.Exchange;
import org.apache.cxf.message.ExchangeImpl;
import org.apache.cxf.message.Message;
import org.apache.cxf.message.MessageImpl;
import org.apache.cxf.service.model.BindingInfo;
import org.apache.cxf.service.model.BindingOperationInfo;
import org.apache.cxf.service.model.EndpointInfo;
import org.apache.cxf.service.model.OperationInfo;
import org.apache.cxf.transport.Conduit;
import org.apache.cxf.ws.addressing.EndpointReferenceType;
import org.apache.cxf.ws.addressing.MAPAggregator;
import org.apache.cxf.ws.addressing.WSAddressingFeature;
import org.apache.cxf.ws.rm.Destination;
import org.apache.cxf.ws.rm.DestinationSequence;
import org.apache.cxf.ws.rm.EncoderDecoder;
import org.apache.cxf.ws.rm.ProtocolVariation;
import org.apache.cxf.ws.rm.Source;
import org.apache.cxf.ws.rm.SourceSequence;
import org.apache.cxf.ws.rm.RetransmissionQueue;
import org.apache.cxf.ws.rm.RMConstants;
import org.apache.cxf.ws.rm.RMManager;
import org.apache.cxf.ws.rm.RMEndpoint;
import org.apache.cxf.ws.rm.RMUtils;
import org.apache.cxf.ws.rm.manager.SourcePolicyType;
import org.apache.cxf.ws.rm.v200702.AcceptType;
import org.apache.cxf.ws.rm.v200702.CreateSequenceType;
import org.apache.cxf.ws.rm.v200702.CreateSequenceResponseType;
import org.apache.cxf.ws.rm.v200702.TerminateSequenceType;
import org.apache.cxf.ws.rm.v200702.Expires;
import org.apache.cxf.ws.rm.v200702.Identifier;
import org.apache.cxf.ws.rm.v200702.OfferType;

/**
 * Drives the WS-ReliableMessaging sequence lifecycle (CreateSequence / TerminateSequence) against a single target
 * address, using the {@link RMManager} registered as an extension on the bus of the wrapped CXF {@link Client}.
 */
public class RMClient {

  private final RMManager manager;
  private final Client client;
  private final String address;

  /**
   * @param client  CXF client whose bus provides the {@link RMManager} and whose bus/endpoint are used to resolve the
   *                RM {@link Source}.
   * @param address address placed in the invocation context of every protocol message and set as the target of every
   *                sequence created by this client.
   */
  public RMClient(Client client, String address) {
    this.address = address;
    this.client = client;
    this.manager = client.getBus().getExtension(RMManager.class);
  }

  /**
   * Sends a CreateSequence protocol message and registers the resulting sequence with the RM source.
   * <p>
   * When the source policy includes an offer, and the peer accepts it with an AcksTo address other than the
   * WS-Addressing "none" URI, the offered inbound sequence is also registered with the RM destination.
   *
   * @param request    sequence TTL plus the WS-RM and (optional) WS-Addressing namespaces to use.
   * @param dispatcher transport dispatcher placed in the invocation context.
   * @return the registered source sequence, with its target set to this client's address.
   * @throws BadRequestException if the requested WS-RM / WS-Addressing namespace combination is not supported.
   * @throws Exception           if building, sending or decoding the protocol message fails.
   */
  public SourceSequence createSequence(CreateSequenceRequest request, TransportDispatcher dispatcher) throws Exception {
    Source source = manager.getSource(getMessage());
    RMEndpoint reliableEndpoint = source.getReliableEndpoint();
    SourcePolicyType sourcePolicy = manager.getSourcePolicy();

    CreateSequenceType create = buildCreateSequence(request, sourcePolicy, reliableEndpoint);

    String namespaceUri = request.getNamespaceUri();
    String addressingNamespaceUri = request.getAddressingNamespaceUri().orElse(null);
    ProtocolVariation protocol = resolveProtocol(namespaceUri, addressingNamespaceUri);
    EncoderDecoder codec = protocol.getCodec();

    final OperationInfo oi =
        getOperationInfo(reliableEndpoint, protocol, codec.getConstants().getCreateSequenceOperationName());

    Object resp = invoke(reliableEndpoint, oi, protocol, new Object[] {codec.convertToSend(create)},
                         getInvocationContext(oi.isOneWay(), dispatcher, namespaceUri, addressingNamespaceUri),
                         new ExchangeImpl());
    CreateSequenceResponseType response = codec.convertReceivedCreateSequenceResponse(resp);

    SourceSequence seq = new SourceSequence(response.getIdentifier(), protocol);
    seq.setExpires(response.getExpires());
    source.addSequence(seq);

    if (sourcePolicy.isIncludeOffer()) {
      registerAcceptedOffer(response.getAccept(), create.getOffer(), reliableEndpoint, protocol);
    }

    SourceSequence sourceSequence = source.getSequence(response.getIdentifier());
    sourceSequence.setTarget(RMUtils.createReference(address));
    return sourceSequence;
  }

  /**
   * Builds the CreateSequence payload. The payload has an anonymous AcksTo and an expiry taken from the request TTL.
   * When the source policy asks for it, the payload also offers an inbound sequence with an anonymous endpoint.
   */
  private CreateSequenceType buildCreateSequence(CreateSequenceRequest request, SourcePolicyType sourcePolicy,
                                                 RMEndpoint reliableEndpoint)
      throws DatatypeConfigurationException {
    CreateSequenceType create = new CreateSequenceType();
    create.setAcksTo(RMUtils.createAnonymousReference());

    // DatatypeFactory#newDuration never returns null (it throws on invalid input), so the expiry is always set.
    Duration sequenceTtl = DatatypeFactory.newInstance().newDuration(request.getSequenceTtl());
    create.setExpires(toExpires(sequenceTtl));

    if (sourcePolicy.isIncludeOffer()) {
      OfferType offer = new OfferType();
      Duration offeredExpiration = sourcePolicy.getOfferedSequenceExpiration();
      if (offeredExpiration != null) {
        offer.setExpires(toExpires(offeredExpiration));
      }
      offer.setIdentifier(reliableEndpoint.getSource().generateSequenceIdentifier());
      offer.setEndpoint(RMUtils.createAnonymousReference());
      create.setOffer(offer);
    }
    return create;
  }

  /** Wraps a duration into a WS-RM {@link Expires} element. */
  private static Expires toExpires(Duration duration) {
    Expires expires = new Expires();
    expires.setValue(duration);
    return expires;
  }

  /**
   * Resolves the CXF protocol variation for the given namespaces.
   *
   * @throws BadRequestException if CXF knows no variation for this combination, in which case
   *                             {@code findVariant} returns {@code null}.
   */
  private static ProtocolVariation resolveProtocol(String namespaceUri, String addressingNamespaceUri)
      throws BadRequestException {
    ProtocolVariation protocol = ProtocolVariation.findVariant(namespaceUri, addressingNamespaceUri);
    if (protocol == null) {
      throw new BadRequestException(format("Unsupported WS-RM namespace [%s] and WS-Addressing namespace [%s] combination.",
                                           namespaceUri, addressingNamespaceUri));
    }
    return protocol;
  }

  /**
   * Registers the offered inbound sequence with the RM destination. Registration happens only when the peer returned
   * an Accept whose AcksTo address is not the WS-Addressing "none" URI.
   */
  private void registerAcceptedOffer(AcceptType accept, OfferType offer, RMEndpoint reliableEndpoint,
                                     ProtocolVariation protocol) {
    if (accept == null) {
      return;
    }
    String acksToAddress = accept.getAcksTo().getAddress().getValue();
    if (RMUtils.getAddressingConstants().getNoneURI().equals(acksToAddress)) {
      return;
    }
    Destination dest = reliableEndpoint.getDestination();
    dest.addSequence(new DestinationSequence(offer.getIdentifier(), accept.getAcksTo(), dest, protocol));
  }

  /**
   * Terminates a sequence. The method:
   * <ol>
   * <li>stops its retransmission queue,</li>
   * <li>sends a TerminateSequence message carrying the current message number as the last message number,</li>
   * <li>removes the sequence from its source.</li>
   * </ol>
   *
   * @param request    carries the identifier of the sequence to terminate.
   * @param dispatcher transport dispatcher placed in the invocation context.
   * @throws BadRequestException if no sequence is found for the identifier.
   * @throws Exception           if sending the protocol message fails; the sequence is then not removed.
   */
  public void terminateSequence(TerminateSequenceRequest request, TransportDispatcher dispatcher) throws Exception {
    SourceSequence sourceSequence = getSourceSequence(request.getSequenceIdentifier());

    if (sourceSequence == null) {
      throw new BadRequestException(format("Error at sequence [%s] termination, no sequence found for that identifier.",
                                           request.getSequenceIdentifier()));
    }

    RetransmissionQueue retransmissionQueue = manager.getRetransmissionQueue();
    retransmissionQueue.stop(sourceSequence);

    RMEndpoint reliableEndpoint = sourceSequence.getSource().getReliableEndpoint();

    ProtocolVariation protocol = sourceSequence.getProtocol();
    RMConstants constants = protocol.getConstants();
    OperationInfo oi = getOperationInfo(reliableEndpoint, protocol, constants.getTerminateSequenceOperationName());

    TerminateSequenceType ts = new TerminateSequenceType();
    ts.setIdentifier(sourceSequence.getIdentifier());
    ts.setLastMsgNumber(sourceSequence.getCurrentMessageNr());

    invoke(reliableEndpoint, oi, protocol, new Object[] {protocol.getCodec().convertToSend(ts)},
           getInvocationContext(oi.isOneWay(), dispatcher),
           new ExchangeImpl());

    sourceSequence.getSource().removeSequence(sourceSequence);
  }

  /** Builds an invocation context without explicit WS-RM / WS-Addressing namespace properties. */
  private Map<String, Object> getInvocationContext(boolean isOneWay, TransportDispatcher dispatcher) {
    return getInvocationContext(isOneWay, dispatcher, null, null);
  }

  /**
   * Builds the invocation context passed to {@link Client#invoke}. The context is a map holding the request context
   * under {@link Client#REQUEST_CONTEXT}. The request context carries:
   * <ul>
   * <li>the Mule transport properties;</li>
   * <li>the operation style derived from {@code isOneWay};</li>
   * <li>the enabled addressing/RM flags;</li>
   * <li>the WS-RM / WS-Addressing namespace overrides, each set only when non-null.</li>
   * </ul>
   */
  private Map<String, Object> getInvocationContext(boolean isOneWay, TransportDispatcher dispatcher, String namespaceUri,
                                                   String addressingNamespaceUri) {
    OperationType operationType = isOneWay
        ? OperationType.ONE_WAY
        : OperationType.REQUEST_RESPONSE;

    Map<String, Object> props = new HashMap<>();
    props.put(MULE_ATTACHMENTS_KEY, emptyMap());
    props.put(MULE_WSC_ADDRESS, address);
    props.put(ENCODING, "UTF-8");
    props.put(MULE_HEADERS_KEY, emptyList());
    props.put(MULE_TRANSPORT_HEADERS_KEY, emptyMap());
    props.put(MESSAGE_DISPATCHER, dispatcher);
    props.put(MULE_SOAP_OPERATION_STYLE, operationType);
    props.put(FORCE_START_DOCUMENT, false);
    props.put(MULE_ADDRESSING_ENABLE, TRUE);
    props.put(MULE_RM_ENABLE, TRUE);

    if (addressingNamespaceUri != null) {
      props.put(MAPAggregator.ADDRESSING_NAMESPACE, addressingNamespaceUri);
      props.put(RMManager.WSRM_WSA_VERSION_PROPERTY, addressingNamespaceUri);
    }
    if (namespaceUri != null) {
      props.put(RMManager.WSRM_VERSION_PROPERTY, namespaceUri);
    }

    Map<String, Object> ctx = new HashMap<>();
    ctx.put(Client.REQUEST_CONTEXT, props);
    return ctx;
  }

  /**
   * Looks up an RM protocol operation, such as CreateSequence or TerminateSequence, on the service interface of the
   * reliable endpoint for the given protocol.
   */
  private OperationInfo getOperationInfo(RMEndpoint reliableEndpoint, ProtocolVariation protocol,
                                         QName operationName) {
    return reliableEndpoint.getEndpoint(protocol).getEndpointInfo().getService().getInterface()
        .getOperation(operationName);
  }

  /**
   * Finds a source sequence by identifier value.
   *
   * @param sequence identifier value of the sequence.
   * @return the sequence, or {@code null} when it is unknown. Any exception raised during the lookup also yields
   *         {@code null}, so callers treat failures as "not found".
   */
  public SourceSequence getSourceSequence(String sequence) {
    try {
      Source source = manager.getSource(getMessage());

      Identifier identifier = new Identifier();
      identifier.setValue(sequence);

      return source.getSequence(identifier);
    } catch (Exception e) {
      // Deliberate best-effort lookup: callers interpret null as "no such sequence".
      return null;
    }
  }

  /** Creates a message whose exchange carries the client's bus and endpoint, as needed to resolve the RM source. */
  private Message getMessage() {
    Exchange exchange = new ExchangeImpl();
    exchange.put(Bus.class, client.getBus());
    exchange.put(Endpoint.class, client.getEndpoint());

    Message message = new MessageImpl();
    message.setExchange(exchange);
    return message;
  }

  /**
   * Sends an RM protocol message for {@code oi}. The message goes through a new inner client bound to the reliable
   * endpoint's conduit, with an anonymous reply-to reference.
   *
   * @return the first element of the invocation result, or {@code null} when the result is null or empty.
   */
  Object invoke(RMEndpoint reliableEndpoint, OperationInfo oi, ProtocolVariation protocol,
                Object[] params, Map<String, Object> context,
                Exchange exchange)
      throws Exception {
    RMManager endpointManager = reliableEndpoint.getManager();
    Bus bus = endpointManager.getBus();
    Endpoint endpoint = reliableEndpoint.getEndpoint(protocol);
    Endpoint applicationEndpoint = reliableEndpoint.getApplicationEndpoint();
    BindingInfo bi = reliableEndpoint.getBindingInfo(protocol);
    Conduit c = reliableEndpoint.getConduit();
    EndpointReferenceType replyTo = RMUtils.createAnonymousReference();
    Client innerClient = createClient(bus, endpoint, applicationEndpoint, protocol, c, replyTo);
    BindingOperationInfo boi = bi.getOperation(oi);

    Object[] result = innerClient.invoke(boi, params, context, exchange);
    if (result != null && result.length > 0) {
      return result[0];
    }
    return null;
  }

  /**
   * Creates an inner client that requires WS-Addressing and carries the protocol's WS-RM / WS-Addressing namespaces in
   * its request context.
   * <p>
   * The conduit selector does the following while selecting a conduit:
   * <ol>
   * <li>temporarily sets the endpoint's address to {@code target}, when {@code target} is non-null;</li>
   * <li>restores the original target afterwards.</li>
   * </ol>
   */
  private Client createClient(Bus bus, Endpoint endpoint, Endpoint applicationEndpoint, final ProtocolVariation protocol,
                              Conduit conduit, final EndpointReferenceType target) {
    ConduitSelector cs = new DeferredConduitSelector(conduit) {

      @Override
      public synchronized Conduit selectConduit(Message message) {
        Conduit selected = null;
        EndpointInfo endpointInfo = getEndpoint().getEndpointInfo();
        EndpointReferenceType original = endpointInfo.getTarget();
        try {
          if (null != target) {
            endpointInfo.setAddress(target);
          }
          selected = super.selectConduit(message);
        } finally {
          endpointInfo.setAddress(original);
        }
        return selected;
      }
    };
    InnerClient innerClient = new InnerClient(bus, endpoint, applicationEndpoint, cs);
    WSAddressingFeature wsa = new WSAddressingFeature();
    wsa.setAddressingRequired(true);
    wsa.initialize(innerClient, bus);
    Map<String, Object> context = innerClient.getRequestContext();
    context.put(MAPAggregator.ADDRESSING_NAMESPACE, protocol.getWSANamespace());
    context.put(RMManager.WSRM_VERSION_PROPERTY, protocol.getWSRMNamespace());
    context.put(RMManager.WSRM_WSA_VERSION_PROPERTY, protocol.getWSANamespace());
    return innerClient;
  }


  /**
   * Client for RM protocol messages. Before default processing of an incoming message, it replaces the
   * {@link Endpoint} stored in the message's exchange with the application endpoint.
   */
  class InnerClient extends ClientImpl {

    private final Endpoint applicationEndpoint;

    InnerClient(Bus bus, Endpoint endpoint, Endpoint applicationEndpoint, ConduitSelector cs) {
      super(bus, endpoint, cs);
      this.applicationEndpoint = applicationEndpoint;
    }

    @Override
    public void onMessage(Message m) {
      m.getExchange().put(Endpoint.class, applicationEndpoint);
      super.onMessage(m);
    }
  }
}
