/*
 * Decompiled with CFR 0.152.
 */
package org.apache.pinot.core.operator.transform.function;

import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.scalar.VectorFunctions;
import org.apache.pinot.core.operator.ColumnContext;
import org.apache.pinot.core.operator.blocks.ValueBlock;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import org.apache.pinot.core.operator.transform.function.BaseTransformFunction;
import org.apache.pinot.core.operator.transform.function.LiteralTransformFunction;
import org.apache.pinot.core.operator.transform.function.TransformFunction;

/**
 * Container for the vector transform functions: the single-argument {@code vectorNorm} and {@code vectorDims}, and
 * the two-argument {@code l2Distance}, {@code l1Distance}, {@code innerProduct} and {@code cosineDistance}.
 * Vector arguments are read via {@code transformToFloatValuesMV} and the per-row math is delegated to
 * {@link VectorFunctions}.
 */
public class VectorTransformFunctions {

    /**
     * Verifies that a vector argument produces multi-valued results.
     *
     * @param argument the argument transform function to check
     * @param functionName name of the calling function, included in the error message
     * @throws IllegalArgumentException if {@code argument} is single-valued
     */
    private static void checkVectorArgument(TransformFunction argument, String functionName) {
        Preconditions.checkArgument(!argument.getResultMetadata().isSingleValue(),
            "Argument must be multi-valued float vector for vector transform function: %s", functionName);
    }

    /**
     * {@code vectorNorm(vector)}: produces a single-valued DOUBLE per row by applying
     * {@link VectorFunctions#vectorNorm(float[])} to the row's vector.
     */
    public static class VectorNormTransformFunction
    extends BaseTransformFunction {
        public static final String FUNCTION_NAME = "vectorNorm";
        private TransformFunction _transformFunction;

        /**
         * Validates that exactly one multi-valued argument is supplied and stores it.
         *
         * @throws IllegalArgumentException if the argument count is not 1 or the argument is single-valued
         */
        @Override
        public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) {
            super.init(arguments, columnContextMap);
            if (arguments.size() != 1) {
                throw new IllegalArgumentException("Exactly 1 argument is required for Vector transform function");
            }
            _transformFunction = arguments.get(0);
            checkVectorArgument(_transformFunction, getName());
        }

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        public TransformResultMetadata getResultMetadata() {
            return DOUBLE_SV_NO_DICTIONARY_METADATA;
        }

        /**
         * Computes {@link VectorFunctions#vectorNorm(float[])} for every document in the block.
         *
         * @param valueBlock block of documents to evaluate
         * @return one double per document in {@code valueBlock}
         */
        @Override
        public double[] transformToDoubleValuesSV(ValueBlock valueBlock) {
            int length = valueBlock.getNumDocs();
            initDoubleValuesSV(length);
            float[][] values = _transformFunction.transformToFloatValuesMV(valueBlock);
            for (int i = 0; i < length; i++) {
                _doubleValuesSV[i] = VectorFunctions.vectorNorm(values[i]);
            }
            return _doubleValuesSV;
        }
    }

    /**
     * {@code vectorDims(vector)}: produces a single-valued INT per row by applying
     * {@link VectorFunctions#vectorDims(float[])} to the row's vector.
     */
    public static class VectorDimsTransformFunction
    extends BaseTransformFunction {
        public static final String FUNCTION_NAME = "vectorDims";
        private TransformFunction _transformFunction;

        /**
         * Validates that exactly one multi-valued argument is supplied and stores it.
         *
         * @throws IllegalArgumentException if the argument count is not 1 or the argument is single-valued
         */
        @Override
        public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) {
            super.init(arguments, columnContextMap);
            if (arguments.size() != 1) {
                throw new IllegalArgumentException("Exactly 1 argument is required for Vector transform function");
            }
            _transformFunction = arguments.get(0);
            checkVectorArgument(_transformFunction, getName());
        }

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        public TransformResultMetadata getResultMetadata() {
            return INT_SV_NO_DICTIONARY_METADATA;
        }

        /**
         * Computes {@link VectorFunctions#vectorDims(float[])} for every document in the block.
         *
         * @param valueBlock block of documents to evaluate
         * @return one int per document in {@code valueBlock}
         */
        @Override
        public int[] transformToIntValuesSV(ValueBlock valueBlock) {
            int length = valueBlock.getNumDocs();
            initIntValuesSV(length);
            float[][] values = _transformFunction.transformToFloatValuesMV(valueBlock);
            for (int i = 0; i < length; i++) {
                _intValuesSV[i] = VectorFunctions.vectorDims(values[i]);
            }
            return _intValuesSV;
        }
    }

    /**
     * Base class for functions that combine two vectors into a single-valued DOUBLE per row. Subclasses supply the
     * per-row computation via {@link #computeDistance(float[], float[])} and may widen the accepted argument count
     * by overriding {@link #checkArgumentSize(List)}.
     */
    public static abstract class VectorDistanceTransformFunction
    extends BaseTransformFunction {
        protected TransformFunction _leftTransformFunction;
        protected TransformFunction _rightTransformFunction;

        /**
         * Validates the argument count and stores the first two arguments as the left/right vectors.
         *
         * @throws IllegalArgumentException if the argument count is rejected by {@link #checkArgumentSize(List)} or
         *     either vector argument is single-valued
         */
        @Override
        public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) {
            super.init(arguments, columnContextMap);
            checkArgumentSize(arguments);
            _leftTransformFunction = arguments.get(0);
            _rightTransformFunction = arguments.get(1);
            Preconditions.checkArgument(!_leftTransformFunction.getResultMetadata().isSingleValue()
                    && !_rightTransformFunction.getResultMetadata().isSingleValue(),
                "Argument must be multi-valued float vector for vector distance transform function: %s", getName());
        }

        /**
         * Validates the number of arguments; the default accepts exactly two.
         *
         * @throws IllegalArgumentException if the argument count is not accepted
         */
        protected void checkArgumentSize(List<TransformFunction> arguments) {
            if (arguments.size() != 2) {
                throw new IllegalArgumentException("Exactly 2 arguments are required for Vector transform function");
            }
        }

        @Override
        public TransformResultMetadata getResultMetadata() {
            return DOUBLE_SV_NO_DICTIONARY_METADATA;
        }

        /**
         * Applies {@link #computeDistance(float[], float[])} to the left and right vectors of every document.
         *
         * @param valueBlock block of documents to evaluate
         * @return one double per document in {@code valueBlock}
         */
        @Override
        public double[] transformToDoubleValuesSV(ValueBlock valueBlock) {
            int length = valueBlock.getNumDocs();
            initDoubleValuesSV(length);
            float[][] leftValues = _leftTransformFunction.transformToFloatValuesMV(valueBlock);
            float[][] rightValues = _rightTransformFunction.transformToFloatValuesMV(valueBlock);
            for (int i = 0; i < length; i++) {
                _doubleValuesSV[i] = computeDistance(leftValues[i], rightValues[i]);
            }
            return _doubleValuesSV;
        }

        /**
         * Computes the result for one row.
         *
         * @param vector1 the row's left vector
         * @param vector2 the row's right vector
         * @return the value emitted for the row
         */
        protected abstract double computeDistance(float[] vector1, float[] vector2);
    }

    /** {@code l2Distance(v1, v2)}: delegates each row to {@link VectorFunctions#l2Distance(float[], float[])}. */
    public static class L2DistanceTransformFunction
    extends VectorDistanceTransformFunction {
        public static final String FUNCTION_NAME = "l2Distance";

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        protected double computeDistance(float[] vector1, float[] vector2) {
            return VectorFunctions.l2Distance(vector1, vector2);
        }
    }

    /** {@code l1Distance(v1, v2)}: delegates each row to {@link VectorFunctions#l1Distance(float[], float[])}. */
    public static class L1DistanceTransformFunction
    extends VectorDistanceTransformFunction {
        public static final String FUNCTION_NAME = "l1Distance";

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        protected double computeDistance(float[] vector1, float[] vector2) {
            return VectorFunctions.l1Distance(vector1, vector2);
        }
    }

    /** {@code innerProduct(v1, v2)}: delegates each row to {@link VectorFunctions#innerProduct(float[], float[])}. */
    public static class InnerProductTransformFunction
    extends VectorDistanceTransformFunction {
        public static final String FUNCTION_NAME = "innerProduct";

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        protected double computeDistance(float[] vector1, float[] vector2) {
            return VectorFunctions.innerProduct(vector1, vector2);
        }
    }

    /**
     * {@code cosineDistance(v1, v2[, defaultValue])}: delegates each row to {@code VectorFunctions.cosineDistance}.
     * When the optional third argument is given it must be a literal; its double value is forwarded as the third
     * parameter of {@link VectorFunctions#cosineDistance(float[], float[], double)} (presumably the value used when
     * the distance is undefined — confirm against {@code VectorFunctions}).
     */
    public static class CosineDistanceTransformFunction
    extends VectorDistanceTransformFunction {
        public static final String FUNCTION_NAME = "cosineDistance";
        // Kept as a primitive plus a presence flag so computeDistance does not unbox a Double on every row.
        private boolean _hasDefaultValue = false;
        private double _defaultValue;

        /**
         * Accepts two vectors plus an optional default-value literal.
         *
         * @throws IllegalArgumentException if fewer than 2 or more than 3 arguments are supplied
         */
        @Override
        protected void checkArgumentSize(List<TransformFunction> arguments) {
            if (arguments.size() < 2 || arguments.size() > 3) {
                throw new IllegalArgumentException("2 or 3 arguments are required for CosineDistance function");
            }
        }

        /**
         * Initializes the vector arguments and, when present, reads the third argument as a double literal.
         *
         * @throws IllegalArgumentException if the arguments are invalid or the third argument is not a literal
         */
        @Override
        public void init(List<TransformFunction> arguments, Map<String, ColumnContext> columnContextMap) {
            super.init(arguments, columnContextMap);
            if (arguments.size() == 3) {
                TransformFunction defaultValueArgument = arguments.get(2);
                // Reject non-literals explicitly instead of failing with a ClassCastException.
                if (!(defaultValueArgument instanceof LiteralTransformFunction)) {
                    throw new IllegalArgumentException(
                        "Third argument (default value) must be a literal for CosineDistance function");
                }
                _defaultValue = ((LiteralTransformFunction) defaultValueArgument).getDoubleLiteral();
                _hasDefaultValue = true;
            }
        }

        @Override
        public String getName() {
            return FUNCTION_NAME;
        }

        @Override
        protected double computeDistance(float[] vector1, float[] vector2) {
            if (_hasDefaultValue) {
                return VectorFunctions.cosineDistance(vector1, vector2, _defaultValue);
            }
            return VectorFunctions.cosineDistance(vector1, vector2);
        }
    }
}

