/*
 * Copyright (c) 2021-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.nvidia.spark.rapids

import scala.collection.mutable.ArrayBuffer

import ai.rapids.cudf.ColumnVector
import com.nvidia.spark.rapids.Arm.withResource

import org.apache.spark.sql.vectorized.ColumnarBatch

object GatherUtils {
  /**
   * Builds a new batch holding the rows of `cb` at the positions listed in `rows`,
   * in the order they appear in `rows`.
   *
   * Three cases are distinguished:
   *  - `rows` is empty: an empty batch with the same column types as `cb` is returned.
   *  - `cb` has no columns (as happens for a count aggregation): a column-less batch is
   *    returned whose row count is set to `rows.length`.
   *  - otherwise: the indices are uploaded as an INT32 cuDF column and used to gather
   *    from a cuDF table view of `cb`; the gathered table is re-wrapped with `cb`'s types.
   *
   * NOTE(review): indices are not range-checked here; they are presumably expected to be
   * valid row positions in `cb` — confirm against callers.
   *
   * @param cb   source batch; its column data types determine the result's types
   * @param rows row indices into `cb` to copy, in output order
   * @return a batch with `rows.length` rows and the same column types as `cb`
   */
  def gather(cb: ColumnarBatch, rows: ArrayBuffer[Int]): ColumnarBatch = {
    val dataTypes = GpuColumnVector.extractTypes(cb)

    // Device-side gather. Every intermediate (index column, source table, gathered
    // table) is released by withResource once the result batch has been built.
    def gatherOnGpu(): ColumnarBatch =
      withResource(ColumnVector.fromInts(rows: _*)) { gatherMap =>
        withResource(GpuColumnVector.from(cb)) { srcTable =>
          withResource(srcTable.gather(gatherMap)) { selected =>
            GpuColumnVector.from(selected, dataTypes)
          }
        }
      }

    if (rows.isEmpty) {
      // Nothing selected: hand back an empty batch with the source schema.
      GpuColumnVector.emptyBatchFromTypes(dataTypes)
    } else if (cb.numCols() == 0) {
      // No columns to gather from (e.g. a count aggregation): only the row count carries over.
      val countOnly = GpuColumnVector.emptyBatchFromTypes(dataTypes)
      countOnly.setNumRows(rows.length)
      countOnly
    } else {
      gatherOnGpu()
    }
  }
}
