/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.math.expr;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.BooleanSupplier;
import java.util.function.DoubleSupplier;
import java.util.function.LongSupplier;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExprType;
import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.math.expr.SettableObjectBinding;
import org.apache.druid.math.expr.SettableVectorInputBinding;
import org.apache.druid.math.expr.vector.ExprEvalVector;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.query.expression.LookupExprMacro;
import org.apache.druid.query.expression.NestedDataExpressions;
import org.apache.druid.query.expression.TimestampFloorExprMacro;
import org.apache.druid.query.expression.TimestampShiftExprMacro;
import org.apache.druid.query.lookup.LookupExtractorFactory;
import org.apache.druid.query.lookup.LookupExtractorFactoryContainer;
import org.apache.druid.query.lookup.LookupExtractorFactoryContainerProvider;
import org.apache.druid.query.lookup.TestMapLookupExtractorFactory;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;

/**
 * Checks that vectorized expression evaluation agrees with row-by-row (non-vectorized) evaluation.
 *
 * <p>Each expression is parsed once. It is then evaluated both through its vector processor and, row by row,
 * through {@link Expr#eval}. For every row the output type and value must match.
 *
 * <p>Inputs come from two binding generators:
 * <ul>
 *   <li>random values, with roughly 10% nulls;</li>
 *   <li>sequential counter values, with roughly 50% nulls.</li>
 * </ul>
 */
public class VectorExprResultConsistencyTest extends InitializedNullHandlingTest {
    private static final Logger log = new Logger(VectorExprResultConsistencyTest.class);

    /** Number of batches evaluated per binding mode (randomized and sequential) for each expression. */
    private static final int NUM_ITERATIONS = 10;

    /** Rows per batch; also the max vector size reported to the vector processors. */
    private static final int VECTOR_SIZE = 512;

    /** Backing data for the non-injective {@code test-lookup}. */
    private static final Map<String, String> LOOKUP = Map.of(
            "1", "a",
            "12", "b",
            "33", "c",
            "111", "d",
            "123", "e",
            "124", "f"
    );

    /**
     * Backing data for {@code test-lookup-injective}: {@link Map#get} returns the key itself.
     * Only {@code get} is overridden; every other map operation sees an empty map.
     */
    private static final Map<String, String> INJECTIVE_LOOKUP = new HashMap<String, String>() {
        @Override
        public String get(Object key) {
            return (String) key;
        }
    };

    /**
     * Macros available to every tested expression.
     * Covers timestamp floor/shift, json_object, and lookup. The lookup macro is backed by the two test lookups above.
     */
    private static final ExprMacroTable MACRO_TABLE = new ExprMacroTable(
            ImmutableList.of(
                    new TimestampFloorExprMacro(),
                    new TimestampShiftExprMacro(),
                    new NestedDataExpressions.JsonObjectExprMacro(),
                    new LookupExprMacro(new LookupExtractorFactoryContainerProvider() {
                        @Override
                        public Set<String> getAllLookupNames() {
                            return Set.of("test-lookup", "test-lookup-injective");
                        }

                        @Override
                        public Optional<LookupExtractorFactoryContainer> get(String lookupName) {
                            final LookupExtractorFactory factory;
                            if ("test-lookup".equals(lookupName)) {
                                factory = new TestMapLookupExtractorFactory(LOOKUP, false);
                            } else if ("test-lookup-injective".equals(lookupName)) {
                                factory = new TestMapLookupExtractorFactory(INJECTIVE_LOOKUP, true);
                            } else {
                                return Optional.empty();
                            }
                            return Optional.of(new LookupExtractorFactoryContainer("v0", factory));
                        }

                        @Override
                        public String getCanonicalLookupName(String lookupName) {
                            return "";
                        }
                    })
            )
    );

    /**
     * Column types used by the tests.
     * Columns starting with {@code boolString} are filled with "true"/"false" strings (see populateBindings).
     */
    final Map<String, ExpressionType> types = ImmutableMap.<String, ExpressionType>builder()
            .put("l1", ExpressionType.LONG)
            .put("l2", ExpressionType.LONG)
            .put("d1", ExpressionType.DOUBLE)
            .put("d2", ExpressionType.DOUBLE)
            .put("s1", ExpressionType.STRING)
            .put("s2", ExpressionType.STRING)
            .put("boolString1", ExpressionType.STRING)
            .put("boolString2", ExpressionType.STRING)
            .build();

    @Test
    public void testConstants() {
        testExpression("null", types);
        testExpression("1", types);
        testExpression("1.1", types);
        testExpression("NaN", types);
        testExpression("Infinity", types);
        testExpression("-Infinity", types);
        testExpression("'hello'", types);
        testExpression("json_object('a', 1, 'b', 'abc', 'c', 3.3, 'd', array(1,2,3))", types);
    }

    @Test
    public void testIdentifiers() {
        List<String> columns = new ArrayList<>(types.keySet());
        columns.add("unknown");
        List<String> template = List.of("%s");
        testFunctions(types, template, columns);
    }

    @Test
    public void testCast() {
        Set<String> columns = Set.of("d1", "l1", "s1");
        Set<String> castTo = Set.of(
                "'STRING'", "'LONG'", "'DOUBLE'", "'ARRAY<STRING>'", "'ARRAY<LONG>'", "'ARRAY<DOUBLE>'"
        );
        Set<List<String>> args = Sets.cartesianProduct(columns, castTo);
        List<String> templates = List.of("cast(%s, %s)");
        testFunctions(types, templates, args);
    }

    @Test
    public void testCastArraysRoundTrip() {
        testExpression("cast(cast(s1, 'ARRAY<STRING>'), 'STRING')", types);
        testExpression("cast(cast(d1, 'ARRAY<DOUBLE>'), 'DOUBLE')", types);
        testExpression("cast(cast(d1, 'ARRAY<STRING>'), 'DOUBLE')", types);
        testExpression("cast(cast(l1, 'ARRAY<LONG>'), 'LONG')", types);
        testExpression("cast(cast(l1, 'ARRAY<STRING>'), 'LONG')", types);
    }

    @Test
    public void testUnaryOperators() {
        List<String> functions = List.of("-");
        List<String> templates = List.of("%sd1", "%sl1");
        testFunctions(types, templates, functions);
    }

    @Test
    public void testBinaryMathOperators() {
        Set<String> columns = Set.of("d1", "d2", "l1", "l2", "1", "1.0", "nonexistent", "null", "s1");
        Set<String> columns2 = Set.of("d1", "d2", "l1", "l2", "1", "1.0");
        List<String> templates = binaryTemplates(columns, columns2);
        List<String> args = List.of("+", "-", "*", "/", "^", "%");
        testFunctions(types, templates, args);
    }

    @Test
    public void testBinaryComparisonOperators() {
        Set<String> columns = Set.of("d1", "d2", "l1", "l2", "1", "1.0", "s1", "s2", "nonexistent", "null");
        Set<String> columns2 = Set.of("d1", "d2", "l1", "l2", "1", "1.0", "s1", "s2", "null");
        List<String> templates = binaryTemplates(columns, columns2);
        List<String> args = List.of(">", ">=", "<", "<=", "==", "!=");
        testFunctions(types, templates, args);
    }

    @Test
    public void testUnaryLogicOperators() {
        List<String> functions = List.of("!");
        List<String> templates = List.of("%sd1", "%sl1", "%sboolString1");
        testFunctions(types, templates, functions);
    }

    @Test
    public void testBinaryLogicOperators() {
        List<String> functions = List.of("&&", "||");
        List<String> templates = List.of(
                "d1 %s d2", "l1 %s l2", "boolString1 %s boolString2", "(d1 == d2) %s (l1 == l2)"
        );
        testFunctions(types, templates, functions);
    }

    @Test
    public void testBinaryOperatorTrees() {
        Set<String> columns = Set.of("d1", "l1", "1", "1.0", "nonexistent", "null");
        Set<String> columns2 = Set.of("d2", "l2", "2", "2.0");
        List<String> templates = new ArrayList<>();
        for (List<String> operands : Sets.cartesianProduct(columns, columns2, columns2)) {
            // "%s" is passed as an argument so the template keeps literal placeholders for the two operators
            templates.add(StringUtils.format(
                    "(%s %s %s) %s %s", operands.get(0), "%s", operands.get(1), "%s", operands.get(2)
            ));
        }
        Set<String> ops = Set.of("+", "-", "*", "/");
        Set<List<String>> args = Sets.cartesianProduct(ops, ops);
        testFunctions(types, templates, args);
    }

    @Test
    public void testUnivariateFunctions() {
        List<String> functions = List.of("parse_long", "isNull", "notNull");
        List<String> templates = List.of("%s(s1)", "%s(l1)", "%s(d1)", "%s(nonexistent)", "%s(null)");
        testFunctions(types, templates, functions);
    }

    @Test
    public void testUnivariateMathFunctions() {
        List<String> functions = List.of(
                "abs", "acos", "asin", "atan", "cbrt", "ceil", "cos", "cosh", "cot", "exp", "expm1", "floor",
                "getExponent", "log", "log10", "log1p", "nextUp", "rint", "signum", "sin", "sinh", "sqrt", "tan",
                "tanh", "toDegrees", "toRadians", "ulp", "bitwiseComplement", "bitwiseConvertDoubleToLongBits",
                "bitwiseConvertLongBitsToDouble"
        );
        List<String> templates = List.of("%s(l1)", "%s(d1)", "%s(pi())", "%s(null)", "%s(missing)");
        testFunctions(types, templates, functions);
    }

    @Test
    public void testBivariateMathFunctions() {
        List<String> functions = List.of(
                "atan2", "copySign", "div", "hypot", "remainder", "max", "min", "nextAfter", "scalb", "pow",
                "bitwiseAnd", "bitwiseOr", "bitwiseXor", "bitwiseShiftLeft", "bitwiseShiftRight"
        );
        List<String> templates = List.of(
                "%s(d1, d2)", "%s(d1, l1)", "%s(l1, d1)", "%s(l1, l2)", "%s(nonexistent, l1)", "%s(nonexistent, d1)"
        );
        testFunctions(types, templates, functions);
    }

    @Test
    public void testSymmetricalBivariateFunctions() {
        List<String> functions = List.of("nvl");
        List<String> templates = List.of(
                "%s(d1, d2)", "%s(l1, l2)", "%s(s1, s2)", "%s(nonexistent, l1)", "%s(nonexistent, d1)",
                "%s(nonexistent, s1)", "%s(nonexistent, nonexistent2)"
        );
        testFunctions(types, templates, functions);
    }

    @Test
    public void testStringFns() {
        testExpression("s1 + s2", types);
        testExpression("s1 + '-' + s2", types);
        testExpression("concat(s1, s2)", types);
        testExpression("concat(s1,'-',s2,'-',l1,'-',d1)", types);
    }

    @Test
    public void testLookup() {
        List<String> columns = new ArrayList<>(types.keySet());
        columns.add("unknown");
        List<String> templates = List.of(
                "lookup(%s, 'test-lookup')",
                "lookup(%s, 'test-lookup', 'missing')",
                "lookup(%s, 'test-lookup-injective')",
                "lookup(%s, 'test-lookup-injective', 'missing')"
        );
        testFunctions(types, templates, columns);
    }

    @Test
    public void testArrayFns() {
        try {
            // array functions are only checked with vectorize fallback enabled; restore test defaults afterwards
            ExpressionProcessing.initializeForFallback();
            testExpression("array(s1, s2)", types);
            testExpression("array(l1, l2)", types);
            testExpression("array(d1, d2)", types);
            testExpression("array(l1, d2)", types);
            testExpression("array(s1, l2)", types);
        } finally {
            ExpressionProcessing.initializeForTests();
        }
    }

    @Test
    public void testJsonFns() {
        Assume.assumeTrue(ExpressionProcessing.allowVectorizeFallback());
        testExpression("json_object('k1', s1, 'k2', l1)", types);
    }

    @Test
    public void testTimeFunctions() {
        testExpression("timestamp_floor(l1, 'PT1H')", types);
        testExpression("timestamp_shift(l1, 'P1M', 1)", types);
    }

    /**
     * Builds binary-expression templates of the form {@code "<lhs> %s <rhs>"} for every (lhs, rhs) pair.
     * The remaining {@code %s} is later filled with an operator by testFunctions.
     */
    private static List<String> binaryTemplates(Set<String> lhs, Set<String> rhs) {
        List<String> templates = new ArrayList<>();
        for (List<String> pair : Sets.cartesianProduct(lhs, rhs)) {
            // "%s" is passed as an argument so the template keeps a literal placeholder for the operator
            templates.add(StringUtils.format("%s %s %s", pair.get(0), "%s", pair.get(1)));
        }
        return templates;
    }

    /**
     * Tests every template formatted with every single argument.
     *
     * @param types     column types used for the bindings
     * @param templates format strings with exactly one {@code %s}
     * @param args      values substituted into each template
     */
    static void testFunctions(Map<String, ExpressionType> types, List<String> templates, List<String> args) {
        for (String template : templates) {
            for (String arg : args) {
                String expr = StringUtils.format(template, arg);
                testExpression(expr, types);
            }
        }
    }

    /**
     * Tests every template formatted with every argument tuple.
     *
     * @param types      column types used for the bindings
     * @param templates  format strings with one {@code %s} per tuple element
     * @param argsArrays tuples of values substituted positionally into each template
     */
    static void testFunctions(Map<String, ExpressionType> types, List<String> templates, Set<List<String>> argsArrays) {
        for (String template : templates) {
            for (List<String> args : argsArrays) {
                String expr = StringUtils.format(template, args.toArray());
                testExpression(expr, types);
            }
        }
    }

    /** Parses {@code expr} with the default test macro table and checks vectorized vs. row-by-row results. */
    public static void testExpression(String expr, Map<String, ExpressionType> types) {
        testExpression(expr, types, MACRO_TABLE);
    }

    /**
     * Parses {@code expr} with {@code macroTable}.
     * Checks vectorized vs. row-by-row results, first with randomized and then with sequential bindings,
     * {@link #NUM_ITERATIONS} batches each.
     */
    public static void testExpression(String expr, Map<String, ExpressionType> types, ExprMacroTable macroTable) {
        log.debug("running expression [%s]", expr);
        Expr parsed = Parser.parse(expr, macroTable);
        testExpressionRandomizedBindings(expr, parsed, types, NUM_ITERATIONS);
        testExpressionSequentialBindings(expr, parsed, types, NUM_ITERATIONS);
    }

    /**
     * Evaluates {@code parsed} against {@code numIterations} batches of sequentially generated input.
     * Each batch's value counters start at {@code 1 + iter * VECTOR_SIZE}. A new vector processor is built for
     * each batch, and its results are compared with row-by-row evaluation.
     */
    public static void testExpressionSequentialBindings(
            String expr,
            Expr parsed,
            Map<String, ExpressionType> types,
            int numIterations
    ) {
        for (int iter = 0; iter < numIterations; iter++) {
            NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> bindings =
                    makeSequentialBinding(VECTOR_SIZE, types, 1 + iter * VECTOR_SIZE);
            Expr.VectorInputBinding vectorBinding = bindings.rhs;
            Assert.assertTrue(StringUtils.format("Cannot vectorize %s", expr), parsed.canVectorize(vectorBinding));
            ExpressionType outputType = parsed.getOutputType(vectorBinding);
            ExprEvalVector<?> vectorEval = parsed.asVectorProcessor(vectorBinding).evalVector(vectorBinding);
            if (outputType != null) {
                Assert.assertEquals(expr, outputType, vectorEval.getType());
            }
            assertMatchesRowByRow(expr, parsed, outputType, bindings.lhs, vectorEval.getObjectVector());
        }
    }

    /**
     * Builds one vector processor for {@code parsed} from the declared column types.
     * Evaluates it against {@code numIterations} batches of randomly generated input, comparing each batch with
     * row-by-row evaluation.
     */
    public static void testExpressionRandomizedBindings(
            String expr,
            Expr parsed,
            Map<String, ExpressionType> types,
            int numIterations
    ) {
        final Expr.InputBindingInspector inspector = InputBindings.inspectorFromTypeMap(types);
        Expr.VectorInputBindingInspector vectorInspector = new Expr.VectorInputBindingInspector() {
            @Override
            public int getMaxVectorSize() {
                return VECTOR_SIZE;
            }

            @Nullable
            @Override
            public ExpressionType getType(String name) {
                return inspector.getType(name);
            }
        };
        Assert.assertTrue(StringUtils.format("Cannot vectorize %s", expr), parsed.canVectorize(inspector));
        ExpressionType outputType = parsed.getOutputType(inspector);
        ExprVectorProcessor<?> processor = parsed.asVectorProcessor(vectorInspector);
        if (outputType != null) {
            Assert.assertEquals(expr, outputType, processor.getOutputType());
        }
        for (int iteration = 0; iteration < numIterations; iteration++) {
            NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> bindings =
                    makeRandomizedBindings(VECTOR_SIZE, types);
            ExprEvalVector<?> vectorEval = processor.evalVector(bindings.rhs);
            assertMatchesRowByRow(expr, parsed, outputType, bindings.lhs, vectorEval.getObjectVector());
        }
    }

    /**
     * Evaluates {@code parsed} row by row against {@code objectBindings} and asserts that each row matches the
     * vectorized result.
     * When {@code outputType} is known, each row whose result is not a numeric null must also have that type.
     * Array outputs are compared element-wise.
     */
    private static void assertMatchesRowByRow(
            String expr,
            Expr parsed,
            @Nullable ExpressionType outputType,
            Expr.ObjectBinding[] objectBindings,
            Object[] vectorVals
    ) {
        for (int i = 0; i < objectBindings.length; i++) {
            ExprEval<?> eval = parsed.eval(objectBindings[i]);
            if (outputType != null && !eval.isNumericNull()) {
                Assert.assertEquals(
                        StringUtils.format("Output type mismatch for row %s for expression %s", i, expr),
                        outputType,
                        eval.type()
                );
            }
            String message = StringUtils.format("Values do not match for row %s for expression %s", i, expr);
            if (outputType != null && outputType.isArray()) {
                Assert.assertArrayEquals(message, (Object[]) eval.valueOrDefault(), (Object[]) vectorVals[i]);
            } else {
                Assert.assertEquals(message, eval.valueOrDefault(), vectorVals[i]);
            }
        }
    }

    /**
     * Generates random bindings for every column in {@code types}:
     * <ul>
     *   <li>longs in [0, 0x7FFFFFFE);</li>
     *   <li>doubles in [0, 1);</li>
     *   <li>strings from random ints;</li>
     *   <li>roughly 10% nulls.</li>
     * </ul>
     */
    public static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeRandomizedBindings(
            int vectorSize,
            Map<String, ExpressionType> types
    ) {
        ThreadLocalRandom r = ThreadLocalRandom.current();
        return populateBindings(
                vectorSize,
                types,
                () -> r.nextLong(0L, 0x7FFFFFFEL),
                r::nextDouble,
                () -> r.nextDouble(0.0, 1.0) > 0.9,
                () -> String.valueOf(r.nextInt())
        );
    }

    /**
     * Generates bindings whose long, double and string values each come from an independent counter starting at
     * {@code start}. A counter only advances when a non-null value is drawn.
     * Nulls are chosen with a fair coin flip.
     */
    public static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeSequentialBinding(
            int vectorSize,
            Map<String, ExpressionType> types,
            final int start
    ) {
        return populateBindings(
                vectorSize,
                types,
                new LongSupplier() {
                    int counter = start;

                    @Override
                    public long getAsLong() {
                        return counter++;
                    }
                },
                new DoubleSupplier() {
                    int counter = start;

                    @Override
                    public double getAsDouble() {
                        return counter++;
                    }
                },
                () -> ThreadLocalRandom.current().nextBoolean(),
                new Supplier<String>() {
                    int counter = start;

                    @Override
                    public String get() {
                        return String.valueOf(counter++);
                    }
                }
        );
    }

    /**
     * Fills a vector binding and a parallel array of per-row object bindings with identical values.
     *
     * <p>Only LONG, DOUBLE and STRING columns are populated; columns of any other type are left unbound.
     * String columns whose name starts with {@code boolString} get "true"/"false" values drawn from
     * {@code nullsFn}.
     *
     * @return pair of (row bindings, vector binding) describing the same {@code vectorSize} rows
     */
    static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> populateBindings(
            int vectorSize,
            Map<String, ExpressionType> types,
            LongSupplier longsFn,
            DoubleSupplier doublesFn,
            BooleanSupplier nullsFn,
            Supplier<String> stringFn
    ) {
        SettableVectorInputBinding vectorBinding = new SettableVectorInputBinding(vectorSize);
        SettableObjectBinding[] objectBindings = new SettableObjectBinding[vectorSize];
        // create every row binding up front so no row is ever null, even if no column type below is handled
        for (int i = 0; i < vectorSize; i++) {
            objectBindings[i] = new SettableObjectBinding();
        }
        for (Map.Entry<String, ExpressionType> entry : types.entrySet()) {
            String column = entry.getKey();
            ExprType columnType = entry.getValue().getType();
            boolean[] nulls = new boolean[vectorSize];
            switch (columnType) {
                case LONG: {
                    long[] longs = new long[vectorSize];
                    for (int i = 0; i < vectorSize; i++) {
                        nulls[i] = nullsFn.getAsBoolean();
                        longs[i] = nulls[i] ? 0L : longsFn.getAsLong();
                        objectBindings[i].withBinding(column, nulls[i] ? null : Long.valueOf(longs[i]));
                    }
                    vectorBinding.addLong(column, longs, nulls);
                    break;
                }
                case DOUBLE: {
                    double[] doubles = new double[vectorSize];
                    for (int i = 0; i < vectorSize; i++) {
                        nulls[i] = nullsFn.getAsBoolean();
                        doubles[i] = nulls[i] ? 0.0 : doublesFn.getAsDouble();
                        objectBindings[i].withBinding(column, nulls[i] ? null : Double.valueOf(doubles[i]));
                    }
                    vectorBinding.addDouble(column, doubles, nulls);
                    break;
                }
                case STRING: {
                    String[] strings = new String[vectorSize];
                    for (int i = 0; i < vectorSize; i++) {
                        nulls[i] = nullsFn.getAsBoolean();
                        if (nulls[i]) {
                            strings[i] = null;
                        } else if (column.startsWith("boolString")) {
                            strings[i] = String.valueOf(nullsFn.getAsBoolean());
                        } else {
                            strings[i] = String.valueOf(stringFn.get());
                        }
                        objectBindings[i].withBinding(column, strings[i]);
                    }
                    vectorBinding.addString(column, strings);
                    break;
                }
                default:
                    // other column types are not generated and remain unbound
                    break;
            }
        }
        return new NonnullPair<>(objectBindings, vectorBinding);
    }
}

