/*
 * Decompiled with CFR 0.152.
 */
package io.substrait.dsl;

import io.substrait.expression.AggregateFunctionInvocation;
import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
import io.substrait.expression.ImmutableExpression;
import io.substrait.expression.ImmutableFieldReference;
import io.substrait.expression.WindowBound;
import io.substrait.extension.SimpleExtension;
import io.substrait.function.ToTypeString;
import io.substrait.plan.ImmutablePlan;
import io.substrait.plan.ImmutableRoot;
import io.substrait.plan.Plan;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.Expand;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.Join;
import io.substrait.relation.NamedScan;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.physical.HashJoin;
import io.substrait.relation.physical.MergeJoin;
import io.substrait.relation.physical.NestedLoopJoin;
import io.substrait.type.ImmutableType;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Convenience builder for assembling Substrait relations, expressions and plans.
 *
 * <p>Most relation factories come in two public flavours: one without a {@link Rel.Remap} and
 * one taking an explicit remap. Both delegate to a private overload taking an
 * {@code Optional<Rel.Remap>}. Expressions that depend on an input relation (filter conditions,
 * projections, sort fields, join conditions, ...) are supplied as callbacks that receive the
 * input relation (or a {@link JoinInput} for joins), so field references can be built against it.
 */
public class SubstraitBuilder {
    /** Type creator for non-nullable types (used where outputs are required, e.g. {@code or}). */
    static final TypeCreator R = TypeCreator.of(false);
    /** Type creator for nullable types. */
    static final TypeCreator N = TypeCreator.of(true);

    private static final String FUNCTIONS_ARITHMETIC = "/functions_arithmetic.yaml";
    private static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml";
    private static final String FUNCTIONS_COMPARISON = "/functions_comparison.yaml";
    private static final String FUNCTIONS_BOOLEAN = "/functions_boolean.yaml";

    private final SimpleExtension.ExtensionCollection extensions;

    /**
     * Creates a builder that resolves function declarations against {@code extensions}.
     *
     * @param extensions collection used to look up scalar, aggregate and window functions
     */
    public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) {
        this.extensions = extensions;
    }

    // ------------------------------------------------------------------ Aggregate

    /** Wraps an aggregate function invocation in a measure. */
    public Aggregate.Measure measure(AggregateFunctionInvocation aggFn) {
        return Aggregate.Measure.builder().function(aggFn).build();
    }

    /** Wraps an aggregate function invocation in a measure with a pre-measure filter. */
    public Aggregate.Measure measure(AggregateFunctionInvocation aggFn, Expression preMeasureFilter) {
        return Aggregate.Measure.builder().function(aggFn).preMeasureFilter(preMeasureFilter).build();
    }

    /**
     * Builds an aggregate with a single grouping.
     *
     * @param groupingFn produces the grouping from the input relation
     * @param measuresFn produces the measures from the input relation
     * @param input the relation being aggregated
     */
    public Aggregate aggregate(Function<Rel, Aggregate.Grouping> groupingFn, Function<Rel, List<Aggregate.Measure>> measuresFn, Rel input) {
        return this.aggregate(singleGrouping(groupingFn), measuresFn, Optional.empty(), input);
    }

    /** Same as {@link #aggregate(Function, Function, Rel)} with an output remap. */
    public Aggregate aggregate(Function<Rel, Aggregate.Grouping> groupingFn, Function<Rel, List<Aggregate.Measure>> measuresFn, Rel.Remap remap, Rel input) {
        return this.aggregate(singleGrouping(groupingFn), measuresFn, Optional.of(remap), input);
    }

    private Aggregate aggregate(Function<Rel, List<Aggregate.Grouping>> groupingsFn, Function<Rel, List<Aggregate.Measure>> measuresFn, Optional<Rel.Remap> remap, Rel input) {
        List<Aggregate.Grouping> groupings = groupingsFn.apply(input);
        List<Aggregate.Measure> measures = measuresFn.apply(input);
        return Aggregate.builder().groupings(groupings).measures(measures).remap(remap).input(input).build();
    }

    /** Lifts a single-grouping callback into one producing a one-element grouping list. */
    private static Function<Rel, List<Aggregate.Grouping>> singleGrouping(Function<Rel, Aggregate.Grouping> groupingFn) {
        return groupingFn.andThen(grouping -> Arrays.asList(grouping));
    }

    // ---------------------------------------------------------------------- Cross

    /** Builds a cross product of {@code left} and {@code right}. */
    public Cross cross(Rel left, Rel right) {
        return this.cross(left, right, Optional.empty());
    }

    /** Builds a cross product of {@code left} and {@code right} with an output remap. */
    public Cross cross(Rel left, Rel right, Rel.Remap remap) {
        return this.cross(left, right, Optional.of(remap));
    }

    private Cross cross(Rel left, Rel right, Optional<Rel.Remap> remap) {
        return Cross.builder().left(left).right(right).remap(remap).build();
    }

    // ---------------------------------------------------------------------- Fetch

    /** Builds a fetch that skips {@code offset} rows and returns at most {@code count} rows. */
    public Fetch fetch(long offset, long count, Rel input) {
        return this.fetch(offset, OptionalLong.of(count), Optional.empty(), input);
    }

    /** Same as {@link #fetch(long, long, Rel)} with an output remap. */
    public Fetch fetch(long offset, long count, Rel.Remap remap, Rel input) {
        return this.fetch(offset, OptionalLong.of(count), Optional.of(remap), input);
    }

    /** Builds a fetch with offset 0 and the given row count. */
    public Fetch limit(long limit, Rel input) {
        return this.fetch(0L, OptionalLong.of(limit), Optional.empty(), input);
    }

    /** Same as {@link #limit(long, Rel)} with an output remap. */
    public Fetch limit(long limit, Rel.Remap remap, Rel input) {
        return this.fetch(0L, OptionalLong.of(limit), Optional.of(remap), input);
    }

    /** Builds a fetch with the given offset and no row count. */
    public Fetch offset(long offset, Rel input) {
        return this.fetch(offset, OptionalLong.empty(), Optional.empty(), input);
    }

    /** Same as {@link #offset(long, Rel)} with an output remap. */
    public Fetch offset(long offset, Rel.Remap remap, Rel input) {
        return this.fetch(offset, OptionalLong.empty(), Optional.of(remap), input);
    }

    private Fetch fetch(long offset, OptionalLong count, Optional<Rel.Remap> remap, Rel input) {
        return Fetch.builder().offset(offset).count(count).input(input).remap(remap).build();
    }

    // --------------------------------------------------------------------- Filter

    /**
     * Builds a filter over {@code input}.
     *
     * @param conditionFn produces the filter condition from the input relation
     * @param input the relation being filtered
     */
    public Filter filter(Function<Rel, Expression> conditionFn, Rel input) {
        return this.filter(conditionFn, Optional.empty(), input);
    }

    /** Same as {@link #filter(Function, Rel)} with an output remap. */
    public Filter filter(Function<Rel, Expression> conditionFn, Rel.Remap remap, Rel input) {
        return this.filter(conditionFn, Optional.of(remap), input);
    }

    private Filter filter(Function<Rel, Expression> conditionFn, Optional<Rel.Remap> remap, Rel input) {
        Expression condition = conditionFn.apply(input);
        return Filter.builder().input(input).condition(condition).remap(remap).build();
    }

    // ---------------------------------------------------------------------- Joins

    /** Builds an {@code INNER} join; the condition callback receives both join inputs. */
    public Join innerJoin(Function<JoinInput, Expression> conditionFn, Rel left, Rel right) {
        return this.join(conditionFn, Join.JoinType.INNER, left, right);
    }

    /** Same as {@link #innerJoin(Function, Rel, Rel)} with an output remap. */
    public Join innerJoin(Function<JoinInput, Expression> conditionFn, Rel.Remap remap, Rel left, Rel right) {
        return this.join(conditionFn, Join.JoinType.INNER, remap, left, right);
    }

    /** Builds a join of the given type; the condition callback receives both join inputs. */
    public Join join(Function<JoinInput, Expression> conditionFn, Join.JoinType joinType, Rel left, Rel right) {
        return this.join(conditionFn, joinType, Optional.empty(), left, right);
    }

    /** Same as {@link #join(Function, Join.JoinType, Rel, Rel)} with an output remap. */
    public Join join(Function<JoinInput, Expression> conditionFn, Join.JoinType joinType, Rel.Remap remap, Rel left, Rel right) {
        return this.join(conditionFn, joinType, Optional.of(remap), left, right);
    }

    private Join join(Function<JoinInput, Expression> conditionFn, Join.JoinType joinType, Optional<Rel.Remap> remap, Rel left, Rel right) {
        Expression condition = conditionFn.apply(new JoinInput(left, right));
        return Join.builder().left(left).right(right).condition(condition).joinType(joinType).remap(remap).build();
    }

    /**
     * Builds a hash join keyed on field indexes of each side.
     *
     * @param leftKeys field indexes into {@code left} (must not contain nulls)
     * @param rightKeys field indexes into {@code right} (must not contain nulls)
     */
    public HashJoin hashJoin(List<Integer> leftKeys, List<Integer> rightKeys, HashJoin.JoinType joinType, Rel left, Rel right) {
        return this.hashJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right);
    }

    /** Same as {@link #hashJoin(List, List, HashJoin.JoinType, Rel, Rel)} with an optional remap. */
    public HashJoin hashJoin(List<Integer> leftKeys, List<Integer> rightKeys, HashJoin.JoinType joinType, Optional<Rel.Remap> remap, Rel left, Rel right) {
        return HashJoin.builder()
            .left(left)
            .right(right)
            .leftKeys(this.fieldReferences(left, toIntArray(leftKeys)))
            .rightKeys(this.fieldReferences(right, toIntArray(rightKeys)))
            .joinType(joinType)
            .remap(remap)
            .build();
    }

    /**
     * Builds a merge join keyed on field indexes of each side.
     *
     * @param leftKeys field indexes into {@code left} (must not contain nulls)
     * @param rightKeys field indexes into {@code right} (must not contain nulls)
     */
    public MergeJoin mergeJoin(List<Integer> leftKeys, List<Integer> rightKeys, MergeJoin.JoinType joinType, Rel left, Rel right) {
        return this.mergeJoin(leftKeys, rightKeys, joinType, Optional.empty(), left, right);
    }

    /** Same as {@link #mergeJoin(List, List, MergeJoin.JoinType, Rel, Rel)} with an optional remap. */
    public MergeJoin mergeJoin(List<Integer> leftKeys, List<Integer> rightKeys, MergeJoin.JoinType joinType, Optional<Rel.Remap> remap, Rel left, Rel right) {
        return MergeJoin.builder()
            .left(left)
            .right(right)
            .leftKeys(this.fieldReferences(left, toIntArray(leftKeys)))
            .rightKeys(this.fieldReferences(right, toIntArray(rightKeys)))
            .joinType(joinType)
            .remap(remap)
            .build();
    }

    /** Builds a nested-loop join; the condition callback receives both join inputs. */
    public NestedLoopJoin nestedLoopJoin(Function<JoinInput, Expression> conditionFn, NestedLoopJoin.JoinType joinType, Rel left, Rel right) {
        return this.nestedLoopJoin(conditionFn, joinType, Optional.empty(), left, right);
    }

    /**
     * Same as {@link #nestedLoopJoin(Function, NestedLoopJoin.JoinType, Rel, Rel)} with an output
     * remap. Added so nested-loop joins have the same remap overload as every other relation.
     */
    public NestedLoopJoin nestedLoopJoin(Function<JoinInput, Expression> conditionFn, NestedLoopJoin.JoinType joinType, Rel.Remap remap, Rel left, Rel right) {
        return this.nestedLoopJoin(conditionFn, joinType, Optional.of(remap), left, right);
    }

    private NestedLoopJoin nestedLoopJoin(Function<JoinInput, Expression> conditionFn, NestedLoopJoin.JoinType joinType, Optional<Rel.Remap> remap, Rel left, Rel right) {
        Expression condition = conditionFn.apply(new JoinInput(left, right));
        return NestedLoopJoin.builder().left(left).right(right).condition(condition).joinType(joinType).remap(remap).build();
    }

    /** Unboxes a list of field indexes; a null element triggers a NullPointerException. */
    private static int[] toIntArray(List<Integer> values) {
        return values.stream().mapToInt(Integer::intValue).toArray();
    }

    // ------------------------------------------------------------------ NamedScan

    /**
     * Builds a named-table scan with a non-nullable struct schema.
     *
     * @param tableName qualified table name parts
     * @param columnNames column names, paired with {@code types}
     * @param types column types
     */
    public NamedScan namedScan(Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types) {
        return this.namedScan(tableName, columnNames, types, Optional.empty());
    }

    /** Same as {@link #namedScan(Iterable, Iterable, Iterable)} with an output remap. */
    public NamedScan namedScan(Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types, Rel.Remap remap) {
        return this.namedScan(tableName, columnNames, types, Optional.of(remap));
    }

    private NamedScan namedScan(Iterable<String> tableName, Iterable<String> columnNames, Iterable<Type> types, Optional<Rel.Remap> remap) {
        ImmutableType.Struct struct = Type.Struct.builder().addAllFields(types).nullable(false).build();
        NamedStruct namedStruct = NamedStruct.of(columnNames, struct);
        return NamedScan.builder().names(tableName).initialSchema(namedStruct).remap(remap).build();
    }

    // -------------------------------------------------------------------- Project

    /** Builds a projection whose expressions are produced from the input relation. */
    public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel input) {
        return this.project(expressionsFn, Optional.empty(), input);
    }

    /** Same as {@link #project(Function, Rel)} with an output remap. */
    public Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Rel.Remap remap, Rel input) {
        return this.project(expressionsFn, Optional.of(remap), input);
    }

    private Project project(Function<Rel, Iterable<? extends Expression>> expressionsFn, Optional<Rel.Remap> remap, Rel input) {
        Iterable<? extends Expression> expressions = expressionsFn.apply(input);
        return Project.builder().input(input).expressions(expressions).remap(remap).build();
    }

    // --------------------------------------------------------------------- Expand

    /** Builds an expand relation whose fields are produced from the input relation. */
    public Expand expand(Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn, Rel input) {
        return this.expand(fieldsFn, Optional.empty(), input);
    }

    /** Same as {@link #expand(Function, Rel)} with an output remap. */
    public Expand expand(Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn, Rel.Remap remap, Rel input) {
        return this.expand(fieldsFn, Optional.of(remap), input);
    }

    private Expand expand(Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn, Optional<Rel.Remap> remap, Rel input) {
        Iterable<? extends Expand.ExpandField> fields = fieldsFn.apply(input);
        return Expand.builder().input(input).fields(fields).remap(remap).build();
    }

    // ------------------------------------------------------------------------ Set

    /** Builds a set operation over the given inputs, in order. */
    public Set set(Set.SetOp op, Rel ... inputs) {
        return this.set(op, Optional.empty(), inputs);
    }

    /** Same as {@link #set(Set.SetOp, Rel...)} with an output remap. */
    public Set set(Set.SetOp op, Rel.Remap remap, Rel ... inputs) {
        return this.set(op, Optional.of(remap), inputs);
    }

    private Set set(Set.SetOp op, Optional<Rel.Remap> remap, Rel ... inputs) {
        return Set.builder().setOp(op).remap(remap).addAllInputs(Arrays.asList(inputs)).build();
    }

    // ----------------------------------------------------------------------- Sort

    /** Builds a sort whose sort fields are produced from the input relation. */
    public Sort sort(Function<Rel, Iterable<? extends Expression.SortField>> sortFieldFn, Rel input) {
        return this.sort(sortFieldFn, Optional.empty(), input);
    }

    /** Same as {@link #sort(Function, Rel)} with an output remap. */
    public Sort sort(Function<Rel, Iterable<? extends Expression.SortField>> sortFieldFn, Rel.Remap remap, Rel input) {
        return this.sort(sortFieldFn, Optional.of(remap), input);
    }

    private Sort sort(Function<Rel, Iterable<? extends Expression.SortField>> sortFieldFn, Optional<Rel.Remap> remap, Rel input) {
        Iterable<? extends Expression.SortField> sortFields = sortFieldFn.apply(input);
        return Sort.builder().input(input).sortFields(sortFields).remap(remap).build();
    }

    // ------------------------------------------------------------------- Literals

    /** Builds a boolean literal. */
    public Expression.BoolLiteral bool(boolean v) {
        return Expression.BoolLiteral.builder().value(v).build();
    }

    /** Builds a 32-bit integer literal. */
    public Expression.I32Literal i32(int v) {
        return Expression.I32Literal.builder().value(v).build();
    }

    /** Builds a 64-bit floating-point literal. */
    public Expression.FP64Literal fp64(double v) {
        return Expression.FP64Literal.builder().value(v).build();
    }

    // ---------------------------------------------------------------- Expressions

    /** Casts {@code input} to {@code type} with {@code FailureBehavior.UNSPECIFIED}. */
    public Expression cast(Expression input, Type type) {
        return ImmutableExpression.Cast.builder().input(input).type(type).failureBehavior(Expression.FailureBehavior.UNSPECIFIED).build();
    }

    /** References field {@code index} of {@code input}. */
    public FieldReference fieldReference(Rel input, int index) {
        return ImmutableFieldReference.newInputRelReference(index, input);
    }

    /** References each of {@code indexes} of {@code input}, preserving order. */
    public List<FieldReference> fieldReferences(Rel input, int ... indexes) {
        return Arrays.stream(indexes).mapToObj(index -> this.fieldReference(input, index)).collect(Collectors.toList());
    }

    /** References field {@code index} across the given input relations. */
    public FieldReference fieldReference(List<Rel> inputs, int index) {
        return ImmutableFieldReference.newInputRelReference(index, inputs);
    }

    /** References each of {@code indexes} across the given input relations, preserving order. */
    public List<FieldReference> fieldReferences(List<Rel> inputs, int ... indexes) {
        return Arrays.stream(indexes).mapToObj(index -> this.fieldReference(inputs, index)).collect(Collectors.toList());
    }

    /** Builds an if/then expression from its clauses and an else branch. */
    public Expression.IfThen ifThen(Iterable<? extends Expression.IfClause> ifClauses, Expression elseClause) {
        return Expression.IfThen.builder().addAllIfClauses(ifClauses).elseClause(elseClause).build();
    }

    /** Builds a single if clause. */
    public Expression.IfClause ifClause(Expression condition, Expression then) {
        return Expression.IfClause.builder().condition(condition).then(then).build();
    }

    /** Builds a single-or-list expression testing {@code condition} against {@code options}. */
    public Expression singleOrList(Expression condition, Expression ... options) {
        return ImmutableExpression.SingleOrList.builder().condition(condition).addOptions(options).build();
    }

    /** Builds an IN predicate of {@code needles} against the {@code haystack} relation. */
    public Expression.InPredicate inPredicate(Rel haystack, Expression ... needles) {
        return Expression.InPredicate.builder().addAllNeedles(Arrays.asList(needles)).haystack(haystack).build();
    }

    /** Builds {@code ASC_NULLS_LAST} sort fields over the given field indexes of {@code input}. */
    public List<Expression.SortField> sortFields(Rel input, int ... indexes) {
        return Arrays.stream(indexes)
            .mapToObj(index -> this.sortField(this.fieldReference(input, index), Expression.SortDirection.ASC_NULLS_LAST))
            .collect(Collectors.toList());
    }

    /** Builds a sort field over an arbitrary expression. */
    public Expression.SortField sortField(Expression expression, Expression.SortDirection sortDirection) {
        return Expression.SortField.builder().expr(expression).direction(sortDirection).build();
    }

    /** Builds a switch clause matching a literal. */
    public Expression.SwitchClause switchClause(Expression.Literal condition, Expression then) {
        return Expression.SwitchClause.builder().condition(condition).then(then).build();
    }

    /** Builds a switch expression over {@code match} with a default branch. */
    public ImmutableExpression.Switch switchExpression(Expression match, Iterable<? extends Expression.SwitchClause> clauses, Expression defaultClause) {
        return ImmutableExpression.Switch.builder().match(match).addAllSwitchClauses(clauses).defaultClause(defaultClause).build();
    }

    // ------------------------------------------------------- Aggregate functions

    /**
     * Builds an {@code INITIAL_TO_RESULT}/{@code ALL} invocation of the aggregate function
     * identified by {@code namespace} and {@code key}.
     */
    public AggregateFunctionInvocation aggregateFn(String namespace, String key, Type outputType, Expression ... args) {
        SimpleExtension.AggregateFunctionVariant declaration = this.extensions.getAggregateFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
        return this.aggregateInvocation(declaration, outputType, Arrays.asList(args));
    }

    /** Groups by the given field indexes of {@code input}. */
    public Aggregate.Grouping grouping(Rel input, int ... indexes) {
        List<FieldReference> columns = this.fieldReferences(input, indexes);
        return Aggregate.Grouping.builder().addAllExpressions(columns).build();
    }

    /** Groups by arbitrary expressions. */
    public Aggregate.Grouping grouping(Expression ... expressions) {
        return Aggregate.Grouping.builder().addExpressions(expressions).build();
    }

    /** Builds a {@code count:any} measure over one field, with a non-nullable i64 output. */
    public Aggregate.Measure count(Rel input, int field) {
        SimpleExtension.AggregateFunctionVariant declaration = this.extensions.getAggregateFunction(SimpleExtension.FunctionAnchor.of(FUNCTIONS_AGGREGATE_GENERIC, "count:any"));
        return this.measure(this.aggregateInvocation(declaration, R.I64, this.fieldReferences(input, field)));
    }

    /** Builds a {@code min} measure over a field of {@code input}. */
    public Aggregate.Measure min(Rel input, int field) {
        return this.min(this.fieldReference(input, field));
    }

    /** Builds a {@code min} measure; output is the nullable form of the input type. */
    public Aggregate.Measure min(Expression expr) {
        return this.singleArgumentArithmeticAggregate(expr, "min", TypeCreator.asNullable(expr.getType()));
    }

    /** Builds a {@code max} measure over a field of {@code input}. */
    public Aggregate.Measure max(Rel input, int field) {
        return this.max(this.fieldReference(input, field));
    }

    /** Builds a {@code max} measure; output is the nullable form of the input type. */
    public Aggregate.Measure max(Expression expr) {
        return this.singleArgumentArithmeticAggregate(expr, "max", TypeCreator.asNullable(expr.getType()));
    }

    /** Builds an {@code avg} measure over a field of {@code input}. */
    public Aggregate.Measure avg(Rel input, int field) {
        return this.avg(this.fieldReference(input, field));
    }

    /** Builds an {@code avg} measure; output is the nullable form of the input type. */
    public Aggregate.Measure avg(Expression expr) {
        return this.singleArgumentArithmeticAggregate(expr, "avg", TypeCreator.asNullable(expr.getType()));
    }

    /** Builds a {@code sum} measure over a field of {@code input}. */
    public Aggregate.Measure sum(Rel input, int field) {
        return this.sum(this.fieldReference(input, field));
    }

    /** Builds a {@code sum} measure; output is the nullable form of the input type. */
    public Aggregate.Measure sum(Expression expr) {
        return this.singleArgumentArithmeticAggregate(expr, "sum", TypeCreator.asNullable(expr.getType()));
    }

    /** Builds a {@code sum0} measure over a field of {@code input}. */
    public Aggregate.Measure sum0(Rel input, int field) {
        // Previously delegated to sum(...), silently producing the wrong aggregate function.
        return this.sum0(this.fieldReference(input, field));
    }

    /** Builds a {@code sum0} measure with a non-nullable i64 output type. */
    public Aggregate.Measure sum0(Expression expr) {
        return this.singleArgumentArithmeticAggregate(expr, "sum0", R.I64);
    }

    /**
     * Builds a measure for {@code functionName:<argType>} from the arithmetic extension, where
     * {@code <argType>} is the type string of {@code expr}.
     */
    private Aggregate.Measure singleArgumentArithmeticAggregate(Expression expr, String functionName, Type outputType) {
        String typeString = ToTypeString.apply(expr.getType());
        SimpleExtension.AggregateFunctionVariant declaration = this.extensions.getAggregateFunction(SimpleExtension.FunctionAnchor.of(FUNCTIONS_ARITHMETIC, String.format("%s:%s", functionName, typeString)));
        return this.measure(this.aggregateInvocation(declaration, outputType, Arrays.asList(expr)));
    }

    /** Shared builder for {@code INITIAL_TO_RESULT}/{@code ALL} aggregate invocations. */
    private AggregateFunctionInvocation aggregateInvocation(SimpleExtension.AggregateFunctionVariant declaration, Type outputType, List<? extends Expression> args) {
        return AggregateFunctionInvocation.builder()
            .arguments(args)
            .outputType(outputType)
            .declaration(declaration)
            .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
            .invocation(Expression.AggregationInvocation.ALL)
            .build();
    }

    // ---------------------------------------------------------- Scalar functions

    /** Builds {@code negate:<type>}; the output type equals the operand type. */
    public Expression.ScalarFunctionInvocation negate(Expression expr) {
        Type outputType = expr.getType();
        return this.scalarFn(FUNCTIONS_ARITHMETIC, String.format("negate:%s", ToTypeString.apply(outputType)), outputType, expr);
    }

    /** Builds {@code add:<left>_<right>}. */
    public Expression.ScalarFunctionInvocation add(Expression left, Expression right) {
        return this.arithmeticFunction("add", left, right);
    }

    /** Builds {@code subtract:<left>_<right>}. */
    public Expression.ScalarFunctionInvocation subtract(Expression left, Expression right) {
        // The extension function is named "subtract"; the old "substract" key never matched.
        return this.arithmeticFunction("subtract", left, right);
    }

    /** Builds {@code multiply:<left>_<right>}. */
    public Expression.ScalarFunctionInvocation multiply(Expression left, Expression right) {
        return this.arithmeticFunction("multiply", left, right);
    }

    /** Builds {@code divide:<left>_<right>}. */
    public Expression.ScalarFunctionInvocation divide(Expression left, Expression right) {
        return this.arithmeticFunction("divide", left, right);
    }

    /**
     * Builds a binary arithmetic call. The output type is the left operand's type, made nullable
     * if either operand is nullable and non-nullable otherwise.
     */
    private Expression.ScalarFunctionInvocation arithmeticFunction(String fname, Expression left, Expression right) {
        String leftTypeStr = ToTypeString.apply(left.getType());
        String rightTypeStr = ToTypeString.apply(right.getType());
        String key = String.format("%s:%s_%s", fname, leftTypeStr, rightTypeStr);
        boolean isOutputNullable = left.getType().nullable() || right.getType().nullable();
        Type outputType = isOutputNullable ? TypeCreator.asNullable(left.getType()) : TypeCreator.asNotNullable(left.getType());
        return this.scalarFn(FUNCTIONS_ARITHMETIC, key, outputType, left, right);
    }

    /**
     * Builds {@code equal:any_any} with a non-nullable boolean output.
     *
     * <p>NOTE(review): unlike {@link #or}, nullability of the operands is not propagated to the
     * output type; confirm this is intended.
     */
    public Expression.ScalarFunctionInvocation equal(Expression left, Expression right) {
        return this.scalarFn(FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right);
    }

    /** Builds {@code or:bool}; the output is nullable iff any argument is nullable. */
    public Expression.ScalarFunctionInvocation or(Expression ... args) {
        boolean isOutputNullable = Arrays.stream(args).anyMatch(a -> a.getType().nullable());
        Type outputType = isOutputNullable ? N.BOOLEAN : R.BOOLEAN;
        return this.scalarFn(FUNCTIONS_BOOLEAN, "or:bool", outputType, args);
    }

    /** Builds a call to the scalar function identified by {@code namespace} and {@code key}. */
    public Expression.ScalarFunctionInvocation scalarFn(String namespace, String key, Type outputType, Expression ... args) {
        SimpleExtension.ScalarFunctionVariant declaration = this.extensions.getScalarFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
        return Expression.ScalarFunctionInvocation.builder().declaration(declaration).outputType(outputType).arguments(Arrays.asList(args)).build();
    }

    /** Builds a call to the window function identified by {@code namespace} and {@code key}. */
    public Expression.WindowFunctionInvocation windowFn(String namespace, String key, Type outputType, Expression.AggregationPhase aggregationPhase, Expression.AggregationInvocation invocation, Expression.WindowBoundsType boundsType, WindowBound lowerBound, WindowBound upperBound, Expression ... args) {
        SimpleExtension.WindowFunctionVariant declaration = this.extensions.getWindowFunction(SimpleExtension.FunctionAnchor.of(namespace, key));
        return Expression.WindowFunctionInvocation.builder()
            .declaration(declaration)
            .outputType(outputType)
            .aggregationPhase(aggregationPhase)
            .invocation(invocation)
            .boundsType(boundsType)
            .lowerBound(lowerBound)
            .upperBound(upperBound)
            .arguments(Arrays.asList(args))
            .build();
    }

    // ------------------------------------------------------------- Types & plans

    /** Builds a non-nullable user-defined type; {@code namespace} is used as the type URI. */
    public Type.UserDefined userDefinedType(String namespace, String typeName) {
        return ImmutableType.UserDefined.builder().uri(namespace).name(typeName).nullable(false).build();
    }

    /** Wraps a relation as a plan root. */
    public Plan.Root root(Rel rel) {
        return ImmutableRoot.builder().input(rel).build();
    }

    /** Builds a plan with a single root. */
    public Plan plan(Plan.Root root) {
        return ImmutablePlan.builder().addRoots(root).build();
    }

    /** Builds an output remap from field indexes, in order. */
    public Rel.Remap remap(Integer ... fields) {
        return Rel.Remap.of(Arrays.asList(fields));
    }

    /** Pair of join inputs handed to join-condition callbacks. */
    public static final class JoinInput {
        private final Rel left;
        private final Rel right;

        public JoinInput(Rel left, Rel right) {
            this.left = left;
            this.right = right;
        }

        @Override
        public String toString() {
            return "JoinInput[" + "left=" + this.left + "," + "right=" + this.right + "]";
        }

        @Override
        public int hashCode() {
            int result = 0;
            result = 31 * result + (this.left != null ? this.left.hashCode() : 0);
            result = 31 * result + (this.right != null ? this.right.hashCode() : 0);
            return result;
        }

        @Override
        public final boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (other == null || other.getClass() != this.getClass()) {
                return false;
            }
            JoinInput that = (JoinInput) other;
            // The stray unreachable block the decompiler emitted after this return was removed.
            return Objects.equals(that.left, this.left) && Objects.equals(that.right, this.right);
        }

        /** Left join input. */
        public Rel left() {
            return this.left;
        }

        /** Right join input. */
        public Rel right() {
            return this.right;
        }
    }
}

