/*
 * DATABRICKS CONFIDENTIAL & PROPRIETARY
 * __________________
 *
 * Copyright 2023-present Databricks, Inc.
 * All Rights Reserved.
 *
 * NOTICE:  All information contained herein is, and remains the property of Databricks, Inc.
 * and its suppliers, if any.  The intellectual and technical concepts contained herein are
 * proprietary to Databricks, Inc. and its suppliers and may be covered by U.S. and foreign Patents,
 * patents in process, and are protected by trade secret and/or copyright law. Dissemination, use,
 * or reproduction of this information is strictly forbidden unless prior written permission is
 * obtained from Databricks, Inc.
 *
 * If you view or obtain a copy of this information and believe Databricks, Inc. may not have
 * intended it to be made available, please promptly report it to Databricks Legal Department
 * @ legal@databricks.com.
 */
// This package declaration intentionally does not match the directory structure: it is needed so
// that this class can override SparkConnectClient's internals. Because SparkConnectClient is
// packaged and distributed to users, we do not want to expose Databricks-internal connection
// details to them, so the override lives here instead.
package org.apache.spark.sql.connect.client

import io.grpc._

import org.apache.spark.internal.Logging

/**
 * A [[SparkConnectClient.Builder]] that can build an mTLS-enabled SparkConnectClient.
 *
 * In mTLS mode the client uses a plaintext (insecure) gRPC channel. It only ever reaches the mTLS
 * proxy on the same node, and that proxy performs the TLS encryption on its behalf. The key part
 * of this class is building the client's channel: in mTLS mode it creates an insecure channel that
 * still attaches the access token to every call through call credentials.
 *
 * When mTLS mode is disabled (the default), `token` and `build` defer to the standard
 * [[SparkConnectClient.Builder]] behavior.
 */
class MTlsBuilder extends SparkConnectClient.Builder with Logging {
  // Whether the mTLS-proxy code path is active. `token` reads this flag when it is called, so
  // enable mTLS mode before setting the token.
  private var mTLSEnabled: Boolean = false

  /**
   * Enables or disables the mTLS-proxy mode.
   *
   * @param mTLSEnabled whether to talk to the local mTLS proxy over a plaintext channel
   * @return this builder, for chaining
   */
  def withMTlsEnabled(mTLSEnabled: Boolean): SparkConnectClient.Builder = {
    this.mTLSEnabled = mTLSEnabled
    this
  }

  /**
   * Sets the access token used to authenticate calls.
   *
   * In mTLS mode the token is stored and SSL is explicitly marked as disabled, because the local
   * proxy handles encryption. Otherwise this defers to the parent builder.
   *
   * @param inputToken the access token; in mTLS mode a null token is stored as "no token"
   * @return this builder, for chaining
   */
  override def token(inputToken: String): SparkConnectClient.Builder = {
    if (!mTLSEnabled) {
      super.token(inputToken)
    } else {
      // Option(...) maps a null token to None, so build() then falls back to super.build().
      _configuration = _configuration.copy(token = Option(inputToken), isSslEnabled = Some(false))
      this
    }
  }

  /**
   * Builds the client.
   *
   * If mTLS mode is enabled and a token has been configured, the client gets an insecure channel
   * that carries the token as call credentials. In every other case this defers to the parent
   * builder.
   */
  override def build(): SparkConnectClient = _configuration.token match {
    case Some(accessToken) if mTLSEnabled =>
      new SparkConnectClient(
        _configuration,
        _configuration.createChannelAndEventLoopGroupOpt(mTlsCredentials(accessToken)))
    case _ =>
      super.build()
  }

  /**
   * Plaintext channel credentials combined with the bearer token as call credentials.
   *
   * Only use this for the mTLS proxy that runs on the same machine, because the channel itself
   * is not encrypted.
   *
   * @param accessToken the token to attach to every call
   */
  private def mTlsCredentials(accessToken: String) =
    CompositeChannelCredentials.create(
      InsecureChannelCredentials.create(),
      new SparkConnectClient.AccessTokenCallCredentials(accessToken))
}
