/*
 * Decompiled with CFR 0.152.
 */
package io.trino.metadata;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.hash.Hashing;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.connector.system.GlobalSystemConnector;
import io.trino.execution.TaskId;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.LanguageFunctionProvider;
import io.trino.metadata.LanguageScalarFunctionData;
import io.trino.metadata.ResolvedFunction;
import io.trino.operator.scalar.SpecializedSqlScalarFunction;
import io.trino.security.AccessControl;
import io.trino.security.ViewAccessControl;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.connector.CatalogSchemaName;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.function.FunctionDependencies;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.LanguageFunction;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.spi.function.SchemaFunctionName;
import io.trino.spi.security.GroupProvider;
import io.trino.spi.security.Identity;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeId;
import io.trino.spi.type.TypeManager;
import io.trino.sql.PlannerContext;
import io.trino.sql.SqlPath;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.routine.SqlRoutineAnalysis;
import io.trino.sql.routine.SqlRoutineAnalyzer;
import io.trino.sql.routine.SqlRoutineCompiler;
import io.trino.sql.routine.SqlRoutinePlanner;
import io.trino.sql.routine.ir.IrRoutine;
import io.trino.sql.tree.FunctionSpecification;
import io.trino.sql.tree.ParameterDeclaration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * Manages SQL language functions (SQL routines) on a per-query basis.
 * <p>
 * Each registered query gets its own {@link QueryFunctions} registry, which holds the inline
 * functions declared by the query and the catalog functions loaded while resolving names.
 * Function definitions are parsed when registered, analyzed and planned lazily into an
 * {@link IrRoutine}, and compiled with {@link SqlRoutineCompiler} when specialized.
 */
public class LanguageFunctionManager
implements LanguageFunctionProvider {
    /** Schema under which query-local (inline) functions are listed in the {@code system} catalog. */
    public static final String QUERY_LOCAL_SCHEMA = "$query";
    /** Prefix of every {@link FunctionId} produced by {@link #createSqlLanguageFunctionId(String)}. */
    private static final String SQL_FUNCTION_PREFIX = "$trino_sql_";

    private final SqlParser parser;
    private final TypeManager typeManager;
    private final GroupProvider groupProvider;
    // Assigned once by setPlannerContext rather than the constructor, hence not final.
    private SqlRoutineAnalyzer analyzer;
    private SqlRoutinePlanner planner;
    private final Map<QueryId, QueryFunctions> queryFunctions = new ConcurrentHashMap<>();

    @Inject
    public LanguageFunctionManager(SqlParser parser, TypeManager typeManager, GroupProvider groupProvider) {
        this.parser = Objects.requireNonNull(parser, "parser is null");
        this.typeManager = Objects.requireNonNull(typeManager, "typeManager is null");
        this.groupProvider = Objects.requireNonNull(groupProvider, "groupProvider is null");
    }

    /**
     * Supplies the planner context used to analyze and plan function bodies. May be called only once.
     *
     * @throws IllegalStateException if the planner context was already set
     */
    public synchronized void setPlannerContext(PlannerContext plannerContext) {
        Preconditions.checkState(this.analyzer == null, "plannerContext already set");
        this.analyzer = new SqlRoutineAnalyzer(plannerContext, WarningCollector.NOOP);
        this.planner = new SqlRoutinePlanner(plannerContext);
    }

    /** Registers the session's query if it is not registered yet; otherwise does nothing. */
    public void tryRegisterQuery(Session session) {
        this.queryFunctions.putIfAbsent(session.getQueryId(), new QueryFunctions(session));
    }

    /**
     * Registers the session's query.
     *
     * @throws IllegalStateException if the query is already registered
     */
    public void registerQuery(Session session) {
        boolean alreadyRegistered = this.queryFunctions.putIfAbsent(session.getQueryId(), new QueryFunctions(session)) != null;
        if (alreadyRegistered) {
            throw new IllegalStateException("Query already registered: " + session.getQueryId());
        }
    }

    /** Drops all function state held for the session's query. */
    public void unregisterQuery(Session session) {
        this.queryFunctions.remove(session.getQueryId());
    }

    /** No-op: this manager keeps no per-task state. */
    @Override
    public void registerTask(TaskId taskId, List<LanguageScalarFunctionData> languageFunctions) {
    }

    /** No-op: this manager keeps no per-task state. */
    @Override
    public void unregisterTask(TaskId taskId) {
    }

    /**
     * Returns the function registry of the session's query.
     *
     * @throws IllegalStateException if the query was never registered (or was already unregistered)
     */
    private QueryFunctions getQueryFunctions(Session session) {
        QueryFunctions functions = this.queryFunctions.get(session.getQueryId());
        if (functions == null) {
            throw new IllegalStateException("Query not registered: " + session.getQueryId());
        }
        return functions;
    }

    /**
     * Parses each language function's SQL and extracts its metadata, without analyzing it.
     *
     * @return immutable list of metadata, in the iteration order of {@code languageFunctions}
     */
    public List<FunctionMetadata> listFunctions(Collection<LanguageFunction> languageFunctions) {
        return languageFunctions.stream()
                .map(LanguageFunction::sql)
                .map(sql -> SqlRoutineAnalyzer.extractFunctionMetadata(createSqlLanguageFunctionId(sql), this.parser.createFunctionSpecification(sql)))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Returns the functions named {@code name} in the given catalog for the session's query.
     * Catalog functions are loaded through {@code languageFunctionLoader} on first lookup and cached per query.
     */
    public List<FunctionMetadata> getFunctions(Session session, CatalogHandle catalogHandle, SchemaFunctionName name, LanguageFunctionLoader languageFunctionLoader, RunAsIdentityLoader identityLoader) {
        return this.getQueryFunctions(session).getFunctions(catalogHandle, name, languageFunctionLoader, identityLoader);
    }

    /**
     * Returns the metadata of a function known to the session's query.
     *
     * @throws IllegalArgumentException if the function id is unknown to the query
     */
    public FunctionMetadata getFunctionMetadata(Session session, FunctionId functionId) {
        return this.getQueryFunctions(session).getFunctionMetadata(functionId);
    }

    /**
     * Analyzes the function (if not done yet) and returns the functions its body resolves to.
     *
     * @throws IllegalArgumentException if the function id is unknown to the query
     */
    public Set<ResolvedFunction> getDependencies(Session session, FunctionId functionId, AccessControl accessControl) {
        return this.getQueryFunctions(session).getDependencies(functionId, accessControl);
    }

    /**
     * Compiles the implementation of a resolved language function.
     * <p>
     * The resolved function may belong to any registered query, so all of them are searched and
     * the first match wins. {@code functionDependencies} is not used by this implementation.
     *
     * @throws IllegalStateException if no registered query has registered {@code resolvedFunction}
     */
    @Override
    public ScalarFunctionImplementation specialize(FunctionManager functionManager, ResolvedFunction resolvedFunction, FunctionDependencies functionDependencies, InvocationConvention invocationConvention) {
        return this.queryFunctions.values().stream()
                .map(functions -> functions.specialize(resolvedFunction, functionManager, invocationConvention))
                .flatMap(Optional::stream)
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("Unknown function implementation: " + resolvedFunction.getFunctionId()));
    }

    /**
     * Associates a resolved function with its implementation so it can later be specialized and
     * serialized for workers.
     *
     * @throws IllegalArgumentException if the function id is unknown to the query
     */
    public void registerResolvedFunction(Session session, ResolvedFunction resolvedFunction) {
        this.getQueryFunctions(session).registerResolvedFunction(resolvedFunction);
    }

    /** Returns the planned routine of every function registered via {@link #registerResolvedFunction}. */
    public List<LanguageScalarFunctionData> serializeFunctionsForWorkers(Session session) {
        return this.getQueryFunctions(session).serializeFunctionsForWorkers();
    }

    /**
     * Validates a function definition for CREATE FUNCTION: it is analyzed and planned without
     * run-as security, then compiled.
     */
    public void verifyForCreate(Session session, String sql, FunctionManager functionManager, AccessControl accessControl) {
        this.getQueryFunctions(session).verifyForCreate(sql, functionManager, accessControl);
    }

    /** Registers a query-local (inline) function under {@code system.$query} and analyzes it. */
    public void addInlineFunction(Session session, String sql, AccessControl accessControl) {
        this.getQueryFunctions(session).addInlineFunction(sql, accessControl);
    }

    /** Returns whether the name refers to the {@code system.$query} schema used for inline functions. */
    public static boolean isInlineFunction(CatalogSchemaFunctionName functionName) {
        return functionName.getCatalogName().equals("system") && functionName.getSchemaName().equals(QUERY_LOCAL_SCHEMA);
    }

    /** Returns whether the id was produced by {@link #createSqlLanguageFunctionId(String)}. */
    public static boolean isTrinoSqlLanguageFunction(FunctionId functionId) {
        return functionId.toString().startsWith(SQL_FUNCTION_PREFIX);
    }

    /**
     * Derives a function id from the SHA-256 of the SQL text (hashed as unencoded chars), so identical
     * SQL text always yields the same id.
     */
    private static FunctionId createSqlLanguageFunctionId(String sql) {
        String hash = Hashing.sha256().hashUnencodedChars(sql).toString();
        return new FunctionId(SQL_FUNCTION_PREFIX + hash);
    }

    /**
     * Builds a token such as {@code (typeId1,typeId2)} from the declared parameter types, resolved
     * through the type manager. An empty parameter list yields {@code ()}.
     */
    public String getSignatureToken(List<ParameterDeclaration> parameters) {
        return parameters.stream()
                .map(ParameterDeclaration::getType)
                .map(TypeSignatureTranslator::toTypeSignature)
                .map(this.typeManager::getType)
                .map(Type::getTypeId)
                .map(TypeId::getId)
                .collect(Collectors.joining(",", "(", ")"));
    }

    /** Function state of a single query. */
    private class QueryFunctions {
        private final Session session;
        private final Map<FunctionKey, FunctionListing> functionListing = new ConcurrentHashMap<>();
        private final Map<FunctionId, LanguageFunctionImplementation> implementationsById = new ConcurrentHashMap<>();
        private final Map<ResolvedFunction, LanguageFunctionImplementation> implementationsByResolvedFunction = new ConcurrentHashMap<>();

        public QueryFunctions(Session session) {
            this.session = session;
        }

        public void verifyForCreate(String sql, FunctionManager functionManager, AccessControl accessControl) {
            this.implementationWithoutSecurity(sql).verifyForCreate(functionManager, accessControl);
        }

        public void addInlineFunction(String sql, AccessControl accessControl) {
            LanguageFunctionImplementation implementation = this.implementationWithoutSecurity(sql);
            FunctionMetadata metadata = implementation.getFunctionMetadata();
            // Register before analyzing. Presumably this lets a self-reference in the body resolve
            // to this implementation and trip the recursion check in analyzeAndPlan.
            this.implementationsById.put(metadata.getFunctionId(), implementation);
            SchemaFunctionName name = new SchemaFunctionName(QUERY_LOCAL_SCHEMA, metadata.getCanonicalName());
            this.getFunctionListing(GlobalSystemConnector.CATALOG_HANDLE, name).addFunction(metadata);
            implementation.analyzeAndPlan(accessControl);
        }

        public synchronized List<FunctionMetadata> getFunctions(CatalogHandle catalogHandle, SchemaFunctionName name, LanguageFunctionLoader languageFunctionLoader, RunAsIdentityLoader identityLoader) {
            return this.getFunctionListing(catalogHandle, name).getFunctions(languageFunctionLoader, identityLoader);
        }

        public Set<ResolvedFunction> getDependencies(FunctionId functionId, AccessControl accessControl) {
            return this.getImplementation(functionId).getFunctionDependencies(accessControl);
        }

        /** Returns the compiled implementation, or empty if this query never registered {@code resolvedFunction}. */
        public Optional<ScalarFunctionImplementation> specialize(ResolvedFunction resolvedFunction, FunctionManager functionManager, InvocationConvention invocationConvention) {
            LanguageFunctionImplementation function = this.implementationsByResolvedFunction.get(resolvedFunction);
            if (function == null) {
                return Optional.empty();
            }
            return Optional.of(function.specialize(functionManager, invocationConvention));
        }

        public FunctionMetadata getFunctionMetadata(FunctionId functionId) {
            return this.getImplementation(functionId).getFunctionMetadata();
        }

        public void registerResolvedFunction(ResolvedFunction resolvedFunction) {
            this.implementationsByResolvedFunction.put(resolvedFunction, this.getImplementation(resolvedFunction.getFunctionId()));
        }

        public List<LanguageScalarFunctionData> serializeFunctionsForWorkers() {
            return this.implementationsByResolvedFunction.entrySet().stream()
                    .map(entry -> new LanguageScalarFunctionData(entry.getKey(), entry.getValue().getRoutine()))
                    .collect(ImmutableList.toImmutableList());
        }

        /**
         * Looks up an implementation by id.
         *
         * @throws IllegalArgumentException if this query has no implementation with that id
         */
        private LanguageFunctionImplementation getImplementation(FunctionId functionId) {
            LanguageFunctionImplementation function = this.implementationsById.get(functionId);
            Preconditions.checkArgument(function != null, "Unknown function implementation: %s", functionId);
            return function;
        }

        private FunctionListing getFunctionListing(CatalogHandle catalogHandle, SchemaFunctionName name) {
            return this.functionListing.computeIfAbsent(new FunctionKey(catalogHandle, name), FunctionListing::new);
        }

        /** Implementation that always runs with the current session identity (inline functions, CREATE checks). */
        private LanguageFunctionImplementation implementationWithoutSecurity(String sql) {
            return new LanguageFunctionImplementation(sql, this.session.getPath(), Optional.empty(), Optional.empty());
        }

        /** Implementation of a catalog function whose run-as identity is resolved via {@code identityLoader}. */
        private LanguageFunctionImplementation implementationWithSecurity(String sql, List<CatalogSchemaName> path, Optional<String> owner, RunAsIdentityLoader identityLoader) {
            return new LanguageFunctionImplementation(sql, this.session.getPath().forView(path), owner, Optional.of(identityLoader));
        }

        /** A parsed function definition that is analyzed and planned lazily, at most once. */
        private class LanguageFunctionImplementation {
            private final FunctionMetadata functionMetadata;
            private final FunctionSpecification functionSpecification;
            private final SqlPath path;
            private final Optional<String> owner;
            private final Optional<RunAsIdentityLoader> identityLoader;
            // Set together by a successful analyzeAndPlan; all guarded by this object's monitor.
            private SqlRoutineAnalysis analysis;
            private Set<ResolvedFunction> dependencies;
            private IrRoutine routine;
            // True while analyzeAndPlan is running; a re-entrant call in that state means recursion.
            private boolean analyzing;

            private LanguageFunctionImplementation(String sql, SqlPath path, Optional<String> owner, Optional<RunAsIdentityLoader> identityLoader) {
                this.functionSpecification = LanguageFunctionManager.this.parser.createFunctionSpecification(sql);
                this.functionMetadata = SqlRoutineAnalyzer.extractFunctionMetadata(createSqlLanguageFunctionId(sql), this.functionSpecification);
                this.path = Objects.requireNonNull(path, "path is null");
                this.owner = Objects.requireNonNull(owner, "owner is null");
                this.identityLoader = Objects.requireNonNull(identityLoader, "identityLoader is null");
            }

            public FunctionMetadata getFunctionMetadata() {
                return this.functionMetadata;
            }

            /** Analyzes, plans and compiles the function to validate it; only valid for security-free implementations. */
            public void verifyForCreate(FunctionManager functionManager, AccessControl accessControl) {
                Preconditions.checkState(this.identityLoader.isEmpty(), "create should not enforce security");
                this.analyzeAndPlan(accessControl);
                new SqlRoutineCompiler(functionManager).compile(this.getRoutine());
            }

            /**
             * Analyzes and plans the function body once; later calls after success are no-ops.
             *
             * @throws TrinoException with NOT_SUPPORTED if called re-entrantly while analysis is in progress
             */
            private synchronized void analyzeAndPlan(AccessControl accessControl) {
                if (this.analysis != null) {
                    return;
                }
                if (this.analyzing) {
                    throw new TrinoException(StandardErrorCode.NOT_SUPPORTED, "Recursive language functions are not supported: %s%s".formatted(this.functionMetadata.getCanonicalName(), this.functionMetadata.getSignature()));
                }
                this.analyzing = true;
                boolean succeeded = false;
                try {
                    FunctionContext context = this.functionContext(accessControl);
                    this.analysis = LanguageFunctionManager.this.analyzer.analyze(context.session(), context.accessControl(), this.functionSpecification);
                    this.dependencies = this.analysis.analysis().getResolvedFunctions();
                    this.routine = LanguageFunctionManager.this.planner.planSqlFunction(QueryFunctions.this.session, this.functionSpecification, this.analysis);
                    succeeded = true;
                }
                finally {
                    // Always clear the flag. Otherwise a failed attempt would make every later call
                    // report a bogus "recursive function" error.
                    this.analyzing = false;
                    if (!succeeded) {
                        // Roll back a partial result so the next call retries instead of treating a
                        // function whose planning failed as ready.
                        this.analysis = null;
                        this.dependencies = null;
                        this.routine = null;
                    }
                }
            }

            public synchronized Set<ResolvedFunction> getFunctionDependencies(AccessControl accessControl) {
                this.analyzeAndPlan(accessControl);
                return this.dependencies;
            }

            /** @throws IllegalStateException if the function has not been successfully analyzed and planned */
            public synchronized IrRoutine getRoutine() {
                if (this.routine == null) {
                    throw new IllegalStateException("Function not analyzed: " + this.functionMetadata.getSignature());
                }
                return this.routine;
            }

            /** Compiles the planned routine and adapts it to the requested invocation convention. */
            public ScalarFunctionImplementation specialize(FunctionManager functionManager, InvocationConvention invocationConvention) {
                SpecializedSqlScalarFunction function = new SqlRoutineCompiler(functionManager).compile(this.getRoutine());
                return function.getScalarFunctionImplementation(invocationConvention);
            }

            /**
             * Chooses the session and access control used to analyze the function body.
             * <p>
             * The current session identity is used in two cases: when there is no identity loader, or
             * when SqlRoutineAnalyzer.isRunAsInvoker says the function runs as the invoker. Otherwise the
             * body runs as the loaded run-as identity, with that user's groups. When that user differs
             * from the session user, access control is wrapped in ViewAccessControl.
             */
            private FunctionContext functionContext(AccessControl accessControl) {
                if (this.identityLoader.isEmpty() || SqlRoutineAnalyzer.isRunAsInvoker(this.functionSpecification)) {
                    return new FunctionContext(this.createFunctionSession(QueryFunctions.this.session.getIdentity()), accessControl);
                }
                Identity identity = this.identityLoader.get().getFunctionRunAsIdentity(this.owner);
                Identity newIdentity = Identity.from(identity)
                        .withGroups(LanguageFunctionManager.this.groupProvider.getGroups(identity.getUser()))
                        .build();
                Session functionSession = this.createFunctionSession(newIdentity);
                AccessControl effectiveAccessControl = identity.getUser().equals(QueryFunctions.this.session.getUser())
                        ? accessControl
                        : new ViewAccessControl(accessControl);
                return new FunctionContext(functionSession, effectiveAccessControl);
            }

            private Session createFunctionSession(Identity identity) {
                return QueryFunctions.this.session.createViewSession(Optional.empty(), Optional.empty(), identity, this.path);
            }

            private record FunctionContext(Session session, AccessControl accessControl) {
            }
        }

        /** Cached functions for one (catalog, schema-qualified name) pair within the query. */
        private class FunctionListing {
            private final CatalogHandle catalogHandle;
            private final SchemaFunctionName name;
            private final List<FunctionMetadata> functions = new ArrayList<>();
            private boolean loaded;

            public FunctionListing(FunctionKey key) {
                this.catalogHandle = key.catalogHandle();
                this.name = key.name();
            }

            /** Adds a function directly (used for inline functions) and marks the listing loaded so the catalog loader is not consulted. */
            public synchronized void addFunction(FunctionMetadata function) {
                this.functions.add(function);
                this.loaded = true;
            }

            /**
             * Returns the functions of this listing, loading them from the catalog on first use.
             *
             * @throws TrinoException with FUNCTION_IMPLEMENTATION_ERROR if the catalog returns functions
             * whose name differs from the requested one
             */
            public synchronized List<FunctionMetadata> getFunctions(LanguageFunctionLoader languageFunctionLoader, RunAsIdentityLoader identityLoader) {
                if (this.loaded) {
                    return ImmutableList.copyOf(this.functions);
                }
                List<LanguageFunctionImplementation> implementations = languageFunctionLoader.getLanguageFunction(QueryFunctions.this.session.toConnectorSession(), this.name).stream()
                        .map(function -> QueryFunctions.this.implementationWithSecurity(function.sql(), function.path(), function.owner(), identityLoader))
                        .collect(ImmutableList.toImmutableList());

                // The catalog was asked for functions with this exact name; reject anything else it returns.
                Set<String> names = implementations.stream()
                        .map(implementation -> implementation.getFunctionMetadata().getCanonicalName())
                        .collect(ImmutableSet.toImmutableSet());
                if (!names.isEmpty() && !names.equals(Set.of(this.name.getFunctionName()))) {
                    throw new TrinoException(StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR, "Catalog %s returned functions named %s when listing functions named %s".formatted(this.catalogHandle.getCatalogName(), names, this.name));
                }

                for (LanguageFunctionImplementation implementation : implementations) {
                    FunctionMetadata metadata = implementation.getFunctionMetadata();
                    this.functions.add(metadata);
                    QueryFunctions.this.implementationsById.put(metadata.getFunctionId(), implementation);
                }
                // Mark loaded only after a successful load. A failed load is then retried and fails
                // again, instead of later lookups silently reporting no functions.
                this.loaded = true;
                return ImmutableList.copyOf(this.functions);
            }
        }

        private record FunctionKey(CatalogHandle catalogHandle, SchemaFunctionName name) {
        }
    }

    /** Loads the language functions a catalog defines under a schema-qualified function name. */
    public interface LanguageFunctionLoader {
        Collection<LanguageFunction> getLanguageFunction(ConnectorSession session, SchemaFunctionName name);
    }

    /** Resolves the identity a (non-invoker) function body runs as, given the function's optional owner. */
    public interface RunAsIdentityLoader {
        Identity getFunctionRunAsIdentity(Optional<String> owner);
    }
}

