/*
 * Decompiled with CFR 0.152.
 */
package com.linkedin.coral.trino.rel2trino.transformers;

import com.linkedin.coral.calcite.$internal.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.functions.GenericProjectFunction;
import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.common.utils.TypeDerivationUtil;
import com.linkedin.coral.trino.rel2trino.functions.RelDataTypeToTrinoTypeStringConverter;
import com.linkedin.coral.trino.rel2trino.functions.TrinoArrayTransformFunction;
import com.linkedin.coral.trino.rel2trino.functions.TrinoKeywordsConverter;
import com.linkedin.coral.trino.rel2trino.functions.TrinoMapTransformValuesFunction;
import com.linkedin.coral.trino.rel2trino.functions.TrinoStructCastRowFunction;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelRecordType;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.MapSqlType;

/**
 * Rewrites calls to {@link GenericProjectFunction} into Trino-compatible SQL.
 *
 * <p>The first operand is the column being projected. Its current type is obtained via
 * {@code deriveRelDatatype}, and the target type is the operator's inferred return type. Depending on
 * the target type, the projection is expressed as a generated expression string that is passed as the
 * single operand of a Trino-specific function:
 * <ul>
 *   <li>ROW: {@code row("col"."f1", ...) as ROW(...)}, passed to {@link TrinoStructCastRowFunction};</li>
 *   <li>ARRAY: {@code col, x -> expr}, passed to {@link TrinoArrayTransformFunction};</li>
 *   <li>MAP: {@code col, (k, v) -> expr}, passed to {@link TrinoMapTransformValuesFunction}.</li>
 * </ul>
 * Nested ROW/ARRAY/MAP fields whose type differs between source and target are rewritten recursively
 * into {@code cast(...)}, {@code transform(...)} and {@code transform_values(...)} sub-expressions.
 * Fields whose type is unchanged are referenced directly. Calls with any other target type are
 * returned unchanged.
 */
public class GenericProjectTransformer extends SqlCallTransformer {

    /** Lambda parameter name bound to each array element in generated {@code transform} expressions. */
    private static final String ARRAY_ELEMENT_REFERENCE = "x";

    /** Lambda parameter name bound to each map key in generated {@code transform_values} expressions. */
    private static final String MAP_KEY_REFERENCE = "k";

    /** Lambda parameter name bound to each map value in generated {@code transform_values} expressions. */
    private static final String MAP_VALUE_REFERENCE = "v";

    /**
     * Creates the transformer.
     *
     * @param typeDerivationUtil utility used by the base class to derive operand types
     */
    public GenericProjectTransformer(TypeDerivationUtil typeDerivationUtil) {
        super(typeDerivationUtil);
    }

    /** Matches only calls whose operator is a {@link GenericProjectFunction}. */
    @Override
    protected boolean condition(SqlCall sqlCall) {
        return sqlCall.getOperator() instanceof GenericProjectFunction;
    }

    @Override
    protected SqlCall transform(SqlCall sqlCall) {
        return convertGenericProject(sqlCall);
    }

    /**
     * Builds the Trino call that projects the first operand of {@code call} onto the operator's return type.
     *
     * @param call a {@code generic_project} call whose first operand is a column identifier
     * @return the Trino-specific call for ROW/ARRAY/MAP targets, or {@code call} itself otherwise
     * @throws IllegalArgumentException if the first operand is not an identifier, or if the source and
     *     target types do not have the same shape (e.g. projecting an ARRAY column onto a ROW type)
     */
    private SqlCall convertGenericProject(SqlCall call) {
        SqlNode transformColumn = call.getOperandList().get(0);
        if (!(transformColumn instanceof SqlIdentifier)) {
            throw new IllegalArgumentException(String.format(
                "generic_project expects a column identifier as its first operand, but got: %s", transformColumn));
        }
        ImmutableList<String> transformColumnFullName = ((SqlIdentifier) transformColumn).names;
        // Only the last (unqualified) name component is referenced in the generated expression.
        String columnName = transformColumnFullName.get(transformColumnFullName.size() - 1);
        RelDataType fromDataType = deriveRelDatatype(transformColumn);
        // The target type comes from the operator itself; it is queried without an operator binding.
        RelDataType toDataType = call.getOperator().inferReturnType(null);

        switch (toDataType.getSqlTypeName()) {
            case ROW: {
                String argument = structDataTypeArgumentString(
                    requireType(fromDataType, RelRecordType.class, columnName),
                    requireType(toDataType, RelRecordType.class, columnName),
                    columnName);
                return new TrinoStructCastRowFunction(toDataType)
                    .createCall(SqlParserPos.ZERO, new SqlIdentifier(argument, SqlParserPos.ZERO));
            }
            case ARRAY: {
                // NOTE(review): unlike the ROW case, the top-level column name is not keyword-quoted here.
                // Confirm that this is intended.
                String argument = arrayDataTypeArgumentString(
                    requireType(fromDataType, ArraySqlType.class, columnName),
                    requireType(toDataType, ArraySqlType.class, columnName),
                    columnName);
                return new TrinoArrayTransformFunction(toDataType)
                    .createCall(SqlParserPos.ZERO, new SqlIdentifier(argument, SqlParserPos.ZERO));
            }
            case MAP: {
                String argument = mapDataTypeArgumentString(
                    requireType(fromDataType, MapSqlType.class, columnName),
                    requireType(toDataType, MapSqlType.class, columnName),
                    columnName);
                return new TrinoMapTransformValuesFunction(toDataType)
                    .createCall(SqlParserPos.ZERO, new SqlIdentifier(argument, SqlParserPos.ZERO));
            }
            default:
                // Non-container targets need no rewriting.
                return call;
        }
    }

    /** Returns {@code transform_values(<map argument string>)} for a nested MAP field. */
    private String mapDataTypeString(MapSqlType fromDataType, MapSqlType toDataType, String fieldNameReference) {
        return String.format("transform_values(%s)",
            mapDataTypeArgumentString(fromDataType, toDataType, fieldNameReference));
    }

    /**
     * Returns {@code <ref>, (k, v) -> <value expression>}.
     *
     * <p>Only the value type is projected. Keys are bound but not rewritten.
     */
    private String mapDataTypeArgumentString(MapSqlType fromDataType, MapSqlType toDataType, String fieldNameReference) {
        String valueExpression = relDataTypeFieldAccessString(
            fromDataType.getValueType(), toDataType.getValueType(), MAP_VALUE_REFERENCE);
        return String.format("%s, (%s, %s) -> %s",
            fieldNameReference, MAP_KEY_REFERENCE, MAP_VALUE_REFERENCE, valueExpression);
    }

    /** Returns {@code transform(<array argument string>)} for a nested ARRAY field. */
    private String arrayDataTypeString(ArraySqlType fromDataType, ArraySqlType toDataType, String fieldNameReference) {
        return String.format("transform(%s)",
            arrayDataTypeArgumentString(fromDataType, toDataType, fieldNameReference));
    }

    /** Returns {@code <ref>, x -> <element expression>}, projecting each element onto the target component type. */
    private String arrayDataTypeArgumentString(ArraySqlType fromDataType, ArraySqlType toDataType, String fieldNameReference) {
        String elementExpression = relDataTypeFieldAccessString(
            fromDataType.getComponentType(), toDataType.getComponentType(), ARRAY_ELEMENT_REFERENCE);
        return String.format("%s, %s -> %s", fieldNameReference, ARRAY_ELEMENT_REFERENCE, elementExpression);
    }

    /** Returns {@code cast(<struct argument string>)} for a nested ROW field. */
    private String structDataTypeString(RelRecordType fromDataType, RelRecordType toDataType, String fieldNameReference) {
        return String.format("cast(%s)", structDataTypeArgumentString(fromDataType, toDataType, fieldNameReference));
    }

    /** Returns {@code row(<selected fields>) as <Trino type of toDataType>}. */
    private String structDataTypeArgumentString(RelRecordType fromDataType, RelRecordType toDataType, String fieldNameReference) {
        String rowExpression = buildStructRelDataTypeFieldAccessString(fromDataType, toDataType, fieldNameReference);
        String targetType = RelDataTypeToTrinoTypeStringConverter.buildTrinoTypeString(toDataType);
        return String.format("%s as %s", rowExpression, targetType);
    }

    /**
     * Returns an expression that projects {@code fieldNameReference} from {@code fromDataType} onto
     * {@code toDataType}.
     *
     * <p>The reference is returned as-is when the two types are equal, or when the target is not a
     * ROW, ARRAY or MAP type.
     *
     * @throws IllegalArgumentException if the target is a container type but the source type has a
     *     different shape
     */
    private String relDataTypeFieldAccessString(RelDataType fromDataType, RelDataType toDataType, String fieldNameReference) {
        if (fromDataType.equals(toDataType)) {
            return fieldNameReference;
        }
        switch (toDataType.getSqlTypeName()) {
            case ROW:
                return structDataTypeString(
                    requireType(fromDataType, RelRecordType.class, fieldNameReference),
                    requireType(toDataType, RelRecordType.class, fieldNameReference),
                    fieldNameReference);
            case ARRAY:
                return arrayDataTypeString(
                    requireType(fromDataType, ArraySqlType.class, fieldNameReference),
                    requireType(toDataType, ArraySqlType.class, fieldNameReference),
                    fieldNameReference);
            case MAP:
                return mapDataTypeString(
                    requireType(fromDataType, MapSqlType.class, fieldNameReference),
                    requireType(toDataType, MapSqlType.class, fieldNameReference),
                    fieldNameReference);
            default:
                return fieldNameReference;
        }
    }

    /**
     * Builds {@code row(...)} with one entry per field of {@code toDataType}, in target order.
     *
     * <p>Each entry is the corresponding source field, looked up by name without case sensitivity. It is
     * accessed as {@code "<ref>"."<field>"} (quoted where needed) and recursively projected onto the
     * target field type.
     *
     * @throws IllegalArgumentException if a target field has no counterpart in {@code fromDataType}
     */
    private String buildStructRelDataTypeFieldAccessString(RelRecordType fromDataType, RelRecordType toDataType, String fieldNameReference) {
        String quotedReference = TrinoKeywordsConverter.quoteWordIfNotQuoted(fieldNameReference);
        List<String> projectedFields = new ArrayList<>(toDataType.getFieldCount());
        for (RelDataTypeField toField : toDataType.getFieldList()) {
            RelDataTypeField fromField = fromDataType.getField(toField.getName(), false, false);
            if (fromField == null) {
                throw new IllegalArgumentException(String.format(
                    "Field %s was not found in column %s.", toField.getName(), fieldNameReference));
            }
            String fromFieldReference = String.format("%s.%s",
                quotedReference, TrinoKeywordsConverter.quoteWordIfNotQuoted(fromField.getName()));
            projectedFields.add(relDataTypeFieldAccessString(fromField.getType(), toField.getType(), fromFieldReference));
        }
        return String.format("row(%s)", String.join(", ", projectedFields));
    }

    /**
     * Narrows {@code type} to {@code expected}.
     *
     * <p>A mismatch between the source and target type shapes therefore surfaces as a descriptive error
     * rather than a bare {@link ClassCastException}.
     *
     * @throws IllegalArgumentException if {@code type} is not an instance of {@code expected}
     */
    private static <T extends RelDataType> T requireType(RelDataType type, Class<T> expected, String fieldNameReference) {
        if (!expected.isInstance(type)) {
            throw new IllegalArgumentException(String.format(
                "Cannot project %s: expected a %s but found %s.", fieldNameReference, expected.getSimpleName(), type));
        }
        return expected.cast(type);
    }
}

