/*
 * Copyright The OpenZipkin Authors
 * SPDX-License-Identifier: Apache-2.0
 */
package brave.jakarta.jms;

import jakarta.jms.Connection;
import jakarta.jms.ConnectionFactory;
import jakarta.jms.JMSContext;
import jakarta.jms.JMSException;
import jakarta.jms.QueueConnection;
import jakarta.jms.QueueConnectionFactory;
import jakarta.jms.TopicConnection;
import jakarta.jms.TopicConnectionFactory;
import jakarta.jms.XAConnectionFactory;
import jakarta.jms.XAQueueConnectionFactory;
import jakarta.jms.XATopicConnectionFactory;

/**
 * Implements all interfaces according to ActiveMQ, which is typical of JMS 1.1.
 *
 * <p>Every factory method forwards to the wrapped delegate and wraps the resulting connection or
 * context for tracing. Because the delegate may implement only some of the factory interfaces,
 * each method first verifies that the delegate supports the corresponding interface, throwing
 * {@link IllegalStateException} when it does not.
 */
class TracingConnectionFactory implements QueueConnectionFactory, TopicConnectionFactory {
  // Bit flags recording which factory interfaces the delegate implements (see constructor).
  static final int
    TYPE_CF = 1 << 1,
    TYPE_QUEUE_CF = 1 << 2,
    TYPE_TOPIC_CF = 1 << 3,
    TYPE_XA_CF = 1 << 4,
    TYPE_XA_QUEUE_CF = 1 << 5,
    TYPE_XA_TOPIC_CF = 1 << 6;

  /**
   * Wraps {@code delegate} for tracing, returning it unchanged when it is already wrapped.
   *
   * @param delegate the factory to wrap; must not be null
   * @param jmsTracing tracing configuration passed to the created wrappers
   * @return a tracing {@link ConnectionFactory}
   * @throws NullPointerException if {@code delegate} is null
   */
  static ConnectionFactory create(ConnectionFactory delegate, JmsTracing jmsTracing) {
    if (delegate == null) throw new NullPointerException("connectionFactory == null");
    if (delegate instanceof TracingConnectionFactory) return delegate;
    return new TracingConnectionFactory(delegate, jmsTracing);
  }

  // Object because ConnectionFactory and XAConnectionFactory share no common root
  final Object delegate;
  final JmsTracing jmsTracing;
  final int types;

  TracingConnectionFactory(Object delegate, JmsTracing jmsTracing) {
    this.delegate = delegate;
    this.jmsTracing = jmsTracing;
    int types = 0;
    if (delegate instanceof ConnectionFactory) types |= TYPE_CF;
    if (delegate instanceof QueueConnectionFactory) types |= TYPE_QUEUE_CF;
    if (delegate instanceof TopicConnectionFactory) types |= TYPE_TOPIC_CF;
    if (delegate instanceof XAConnectionFactory) types |= TYPE_XA_CF;
    if (delegate instanceof XAQueueConnectionFactory) types |= TYPE_XA_QUEUE_CF;
    if (delegate instanceof XATopicConnectionFactory) types |= TYPE_XA_TOPIC_CF;
    this.types = types;
  }

  // ConnectionFactory

  /** Creates a connection via the delegate and wraps it for tracing. */
  @Override public Connection createConnection() throws JMSException {
    checkConnectionFactory();
    ConnectionFactory cf = (ConnectionFactory) delegate;
    return TracingConnection.create(cf.createConnection(), jmsTracing);
  }

  /** Creates a connection with the given credentials via the delegate and wraps it for tracing. */
  @Override public Connection createConnection(String userName, String password)
    throws JMSException {
    checkConnectionFactory();
    ConnectionFactory cf = (ConnectionFactory) delegate;
    return TracingConnection.create(cf.createConnection(userName, password), jmsTracing);
  }

  /** Creates a context via the delegate and wraps it for tracing. */
  @Override public JMSContext createContext() {
    checkConnectionFactory();
    ConnectionFactory cf = (ConnectionFactory) delegate;
    return TracingJMSContext.create(cf.createContext(), jmsTracing);
  }

  /** Creates a context with the given credentials via the delegate and wraps it for tracing. */
  @Override public JMSContext createContext(String userName, String password) {
    checkConnectionFactory();
    ConnectionFactory cf = (ConnectionFactory) delegate;
    return TracingJMSContext.create(cf.createContext(userName, password), jmsTracing);
  }

  /**
   * Creates a context with the given credentials and session mode via the delegate and wraps it
   * for tracing.
   */
  @Override public JMSContext createContext(String userName, String password, int sessionMode) {
    checkConnectionFactory();
    ConnectionFactory cf = (ConnectionFactory) delegate;
    return TracingJMSContext.create(cf.createContext(userName, password, sessionMode), jmsTracing);
  }

  /** Creates a context with the given session mode via the delegate and wraps it for tracing. */
  @Override public JMSContext createContext(int sessionMode) {
    checkConnectionFactory();
    ConnectionFactory cf = (ConnectionFactory) delegate;
    return TracingJMSContext.create(cf.createContext(sessionMode), jmsTracing);
  }

  /**
   * Verifies the delegate is a {@link ConnectionFactory}.
   *
   * <p>This seemingly base-case check is needed because the constructor accepts any object. The
   * delegate may be an {@link XAConnectionFactory}, which does not necessarily implement
   * {@link ConnectionFactory}.
   *
   * @throws IllegalStateException if the delegate is not a {@link ConnectionFactory}
   */
  void checkConnectionFactory() {
    requireType(TYPE_CF, "ConnectionFactory");
  }

  // QueueConnectionFactory

  /** Creates a queue connection via the delegate and wraps it for tracing. */
  @Override public QueueConnection createQueueConnection() throws JMSException {
    checkQueueConnectionFactory();
    QueueConnectionFactory qcf = (QueueConnectionFactory) delegate;
    return TracingConnection.create(qcf.createQueueConnection(), jmsTracing);
  }

  /**
   * Creates a queue connection with the given credentials via the delegate and wraps it for
   * tracing.
   */
  @Override public QueueConnection createQueueConnection(String userName, String password)
    throws JMSException {
    checkQueueConnectionFactory();
    QueueConnectionFactory qcf = (QueueConnectionFactory) delegate;
    return TracingConnection.create(qcf.createQueueConnection(userName, password), jmsTracing);
  }

  /**
   * Verifies the delegate is a {@link QueueConnectionFactory}.
   *
   * @throws IllegalStateException if it is not
   */
  void checkQueueConnectionFactory() {
    requireType(TYPE_QUEUE_CF, "QueueConnectionFactory");
  }

  // TopicConnectionFactory

  /** Creates a topic connection via the delegate and wraps it for tracing. */
  @Override public TopicConnection createTopicConnection() throws JMSException {
    checkTopicConnectionFactory();
    TopicConnectionFactory tcf = (TopicConnectionFactory) delegate;
    return TracingConnection.create(tcf.createTopicConnection(), jmsTracing);
  }

  /**
   * Creates a topic connection with the given credentials via the delegate and wraps it for
   * tracing.
   */
  @Override public TopicConnection createTopicConnection(String userName, String password)
    throws JMSException {
    checkTopicConnectionFactory();
    TopicConnectionFactory tcf = (TopicConnectionFactory) delegate;
    return TracingConnection.create(tcf.createTopicConnection(userName, password), jmsTracing);
  }

  /**
   * Verifies the delegate is a {@link TopicConnectionFactory}.
   *
   * @throws IllegalStateException if it is not
   */
  void checkTopicConnectionFactory() {
    requireType(TYPE_TOPIC_CF, "TopicConnectionFactory");
  }

  /**
   * Throws unless every bit of {@code type} is set in {@link #types}.
   *
   * @param type one of the {@code TYPE_*} flags
   * @param typeName simple interface name used in the exception message
   * @throws IllegalStateException if the delegate does not implement the interface
   */
  private void requireType(int type, String typeName) {
    if ((types & type) != type) {
      throw new IllegalStateException(delegate + " is not a " + typeName);
    }
  }
}
