/*
 * Decompiled with CFR 0.152.
 */
package com.blazebit.persistence.impl.transform;

import com.blazebit.persistence.impl.AttributeHolder;
import com.blazebit.persistence.impl.ClauseType;
import com.blazebit.persistence.impl.JoinManager;
import com.blazebit.persistence.impl.JoinNode;
import com.blazebit.persistence.impl.JpaUtils;
import com.blazebit.persistence.impl.MainQuery;
import com.blazebit.persistence.impl.ResolvedExpression;
import com.blazebit.persistence.impl.SimplePathReference;
import com.blazebit.persistence.impl.SubqueryBuilderListenerImpl;
import com.blazebit.persistence.impl.SubqueryInitiatorFactory;
import com.blazebit.persistence.parser.EntityMetamodel;
import com.blazebit.persistence.parser.FunctionKind;
import com.blazebit.persistence.parser.expression.AggregateExpression;
import com.blazebit.persistence.parser.expression.Expression;
import com.blazebit.persistence.parser.expression.ExpressionCopyContext;
import com.blazebit.persistence.parser.expression.ExpressionModifierCollectingResultVisitorAdapter;
import com.blazebit.persistence.parser.expression.FunctionExpression;
import com.blazebit.persistence.parser.expression.ListIndexExpression;
import com.blazebit.persistence.parser.expression.MapKeyExpression;
import com.blazebit.persistence.parser.expression.PathExpression;
import com.blazebit.persistence.parser.expression.PathReference;
import com.blazebit.persistence.parser.expression.PropertyExpression;
import com.blazebit.persistence.parser.expression.StringLiteral;
import com.blazebit.persistence.parser.expression.Subquery;
import com.blazebit.persistence.parser.expression.SubqueryExpression;
import com.blazebit.persistence.parser.expression.modifier.ExpressionModifier;
import com.blazebit.persistence.parser.util.ExpressionUtils;
import com.blazebit.persistence.parser.util.JpaMetamodelUtils;
import com.blazebit.persistence.spi.JpaProvider;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import javax.persistence.metamodel.Attribute;
import javax.persistence.metamodel.EntityType;
import javax.persistence.metamodel.IdentifiableType;
import javax.persistence.metamodel.ManagedType;
import javax.persistence.metamodel.PluralAttribute;
import javax.persistence.metamodel.SingularAttribute;
import javax.persistence.metamodel.Type;

/**
 * Expression visitor that replaces {@code SIZE(collectionPath)} function invocations.
 *
 * <p>Each size function is rewritten into one of two forms:
 * <ul>
 *   <li>a {@code FUNCTION('COUNT_TUPLE', ...)} aggregate over the collection, in which case the joins that have
 *       to be added later are recorded in {@link #getLateJoins()} and the id paths of the collection owner are
 *       recorded in {@link #getRequiredGroupBys()};</li>
 *   <li>a {@code COUNT(*)} subquery built from the collection path, in which case the id paths of the collection
 *       owner are recorded in {@link #getSubqueryGroupBys()}.</li>
 * </ul>
 *
 * <p>The visitor is stateful: the clause currently being processed must be announced through
 * {@link #setClause(ClauseType)} before expressions are visited, and decisions taken for earlier size
 * invocations influence (and may be reverted by) later ones.
 */
public class SizeTransformationVisitor extends ExpressionModifierCollectingResultVisitorAdapter {

    private static final Set<Type.PersistenceType> IDENTIFIABLE_PERSISTENCE_TYPES = EnumSet.of(Type.PersistenceType.ENTITY, Type.PersistenceType.MAPPED_SUPERCLASS);
    /** Name of the generic function wrapper used for custom function invocations. */
    private static final String FUNCTION_NAME = "FUNCTION";
    /** Name of the custom count function emitted for transformed size invocations. */
    private static final String COUNT_TUPLE_FUNCTION = "COUNT_TUPLE";
    /** Literal inserted as second argument of the count function when distinct counting is needed. */
    private static final String DISTINCT_QUALIFIER = "DISTINCT";
    /** Name of the custom function used to wrap a subquery inside an aggregate function. */
    private static final String SUBQUERY_FUNCTION = "subquery";

    private final MainQuery mainQuery;
    private final EntityMetamodel metamodel;
    private final SubqueryInitiatorFactory subqueryInitFactory;
    private final JoinManager joinManager;
    private final JpaProvider jpaProvider;
    private boolean countTransformationDisabled;
    private boolean orderBySelectClause;
    private boolean distinctRequired;
    private ClauseType clause;
    // Entries use identity semantics (no equals/hashCode), every transformation is tracked separately.
    private final Set<TransformedExpressionEntry> transformedExpressions = new HashSet<TransformedExpressionEntry>();
    private final Map<String, LateJoinEntry> lateJoins = new HashMap<String, LateJoinEntry>();
    private final Map<ResolvedExpression, Set<ClauseType>> requiredGroupBys = new LinkedHashMap<ResolvedExpression, Set<ClauseType>>();
    private final Map<ResolvedExpression, Set<ClauseType>> subqueryGroupBys = new LinkedHashMap<ResolvedExpression, Set<ClauseType>>();
    // Base node of the size argument that was most recently transformed into a count function.
    private JoinNode currentJoinNode;
    // Join nodes (including all their parents) referenced by paths visited while in the SELECT clause.
    private final Set<JoinNode> joinNodeBlacklist = new HashSet<JoinNode>();
    // True while the arguments of an aggregate function are being visited.
    private boolean aggregateFunctionContext;

    /**
     * Creates a new visitor.
     *
     * @param mainQuery the main query, used to obtain the metamodel, the query configuration and the function registry
     * @param subqueryInitFactory the factory used to create the {@code COUNT(*)} subqueries
     * @param joinManager the join manager used for root inspection and implicit joining
     * @param jpaProvider the JPA provider used for element collection and bag detection
     */
    public SizeTransformationVisitor(MainQuery mainQuery, SubqueryInitiatorFactory subqueryInitFactory, JoinManager joinManager, JpaProvider jpaProvider) {
        this.mainQuery = mainQuery;
        this.metamodel = mainQuery.getMetamodel();
        this.subqueryInitFactory = subqueryInitFactory;
        this.joinManager = joinManager;
        this.jpaProvider = jpaProvider;
    }

    /**
     * @return the clause that is currently being processed, may be {@code null} if none was set
     */
    public ClauseType getClause() {
        return this.clause;
    }

    /**
     * Sets the clause whose expressions are visited next.
     *
     * @param clause the clause to process
     */
    public void setClause(ClauseType clause) {
        this.clause = clause;
    }

    /**
     * @return whether the size to count transformation was explicitly disabled on this visitor
     */
    public boolean isCountTransformationDisabled() {
        return this.countTransformationDisabled;
    }

    /**
     * Disables or re-enables the size to count transformation for this visitor. When disabled, every size
     * invocation is turned into a subquery.
     *
     * @param countTransformationDisabled {@code true} to disable the count transformation
     */
    public void setCountTransformationDisabled(boolean countTransformationDisabled) {
        this.countTransformationDisabled = countTransformationDisabled;
    }

    /**
     * Controls whether visited paths that match a recorded late join also register an {@code ORDER_BY}
     * clause dependency on that late join.
     *
     * @param orderBySelectClause {@code true} to register the {@code ORDER_BY} dependency
     */
    public void setOrderBySelectClause(boolean orderBySelectClause) {
        this.orderBySelectClause = orderBySelectClause;
    }

    /**
     * @return the live map of late joins, keyed by the absolute path of the joined collection
     */
    public Map<String, LateJoinEntry> getLateJoins() {
        return this.lateJoins;
    }

    /**
     * @return the live map of group by expressions required by count transformations, with the clauses that need them
     */
    public Map<ResolvedExpression, Set<ClauseType>> getRequiredGroupBys() {
        return this.requiredGroupBys;
    }

    /**
     * @return the live map of group by expressions recorded for generated subqueries, with the clauses that need them
     */
    public Map<ResolvedExpression, Set<ClauseType>> getSubqueryGroupBys() {
        return this.subqueryGroupBys;
    }

    private boolean isCountTransformationEnabled() {
        return !this.countTransformationDisabled && this.mainQuery.getQueryConfiguration().isCountTransformationEnabled();
    }

    /**
     * Tracks clause dependencies and blacklisted join nodes for the visited path before delegating to the
     * super implementation.
     *
     * @param expression the visited path
     * @return the result of the super implementation
     */
    @Override
    public Boolean visit(PathExpression expression) {
        if (this.orderBySelectClause) {
            LateJoinEntry lateJoinEntry = this.lateJoins.get(this.getJoinLookupKey(expression));
            if (lateJoinEntry != null) {
                lateJoinEntry.getClauseDependencies().add(ClauseType.ORDER_BY);
            }
        }
        if (this.clause == ClauseType.SELECT) {
            // Remember the base node and all of its ancestors; see requiresBlacklistedNode for how this is used.
            for (JoinNode current = (JoinNode) expression.getBaseNode(); current != null; current = current.getParent()) {
                this.joinNodeBlacklist.add(current);
            }
        }
        return super.visit(expression);
    }

    /**
     * Reports size functions outside of the WHERE clause as expressions to transform and tracks whether the
     * traversal is currently inside an aggregate function.
     *
     * @param expression the visited function
     * @return {@code true} for a size function outside of the WHERE clause, otherwise the result of the super implementation
     */
    @Override
    public Boolean visit(FunctionExpression expression) {
        if (this.clause != ClauseType.WHERE && ExpressionUtils.isSizeFunction((FunctionExpression) expression)) {
            return true;
        }
        if (!this.aggregateFunctionContext && this.isAggregateFunction(expression)) {
            this.aggregateFunctionContext = true;
            try {
                return super.visit(expression);
            } finally {
                // Always leave the aggregate context, even if visiting the arguments fails.
                this.aggregateFunctionContext = false;
            }
        }
        return super.visit(expression);
    }

    /**
     * Looks up the function in the registry of the criteria builder factory.
     *
     * <p>The name is lower cased with {@link Locale#ROOT} so that the lookup does not depend on the default
     * locale (e.g. {@code "MIN"} would become {@code "mın"} under a Turkish locale).
     * NOTE(review): assumes the registry keys are plain ASCII lower case names - confirm against the registration code.
     */
    private boolean isAggregateFunction(FunctionExpression expression) {
        String functionName = expression.getFunctionName().toLowerCase(Locale.ROOT);
        return FunctionKind.AGGREGATE == this.mainQuery.getCbf().getFunctions().get(functionName);
    }

    /**
     * Replaces the size function held by the given modifier with its transformed form and then visits the
     * size argument. Presumably invoked by the superclass for expressions whose visit returned {@code true}.
     *
     * @param parentModifier the modifier that holds the size function invocation
     */
    @Override
    protected void onModifier(ExpressionModifier parentModifier) {
        PathExpression sizeArg = (PathExpression) ((FunctionExpression) parentModifier.get()).getExpressions().get(0);
        parentModifier.set(this.getSizeExpression(parentModifier, sizeArg));
        sizeArg.accept((Expression.ResultVisitor) this);
    }

    /**
     * @return {@code true} if the base node of the size argument is blacklisted and already has a child node
     *         registered under the field name of the size argument
     */
    private boolean requiresBlacklistedNode(PathExpression sizeArg) {
        JoinNode sizeArgBaseNode = (JoinNode) sizeArg.getBaseNode();
        if (this.joinNodeBlacklist.contains(sizeArgBaseNode)) {
            return sizeArgBaseNode.getNodes().containsKey(sizeArg.getField());
        }
        return false;
    }

    /**
     * Builds the replacement for a size function invocation.
     *
     * <p>A subquery is produced when the count transformation is not applicable (see the condition list in the
     * body) or when it would conflict with an earlier count transformation on a different join node. Otherwise
     * a count function is produced and the required late joins and group bys are recorded.
     *
     * @param parentModifier the modifier holding the size function, remembered so the transformation can be reverted later
     * @param sizeArg the collection path passed to the size function
     * @return the count function or the (possibly wrapped) subquery
     * @throws IllegalArgumentException if the collection is owned by a non-entity type
     * @throws RuntimeException if the attribute can not be resolved or does not refer to an entity collection
     */
    private Expression getSizeExpression(ExpressionModifier parentModifier, PathExpression sizeArg) {
        JoinNode sizeArgJoin = (JoinNode) sizeArg.getBaseNode();
        String property = sizeArg.getPathReference().getField();
        Type<?> nodeType = sizeArgJoin.getNodeType();
        if (!(nodeType instanceof EntityType)) {
            throw new IllegalArgumentException("Size on a collection owned by a non-entity type is not supported yet: " + sizeArg);
        }
        EntityType startType = (EntityType) nodeType;
        AttributeHolder result = JpaUtils.getAttributeForJoining(this.metamodel, sizeArg);
        PluralAttribute targetAttribute = (PluralAttribute) result.getAttribute();
        if (targetAttribute == null) {
            throw new RuntimeException("Attribute [" + property + "] not found on class " + startType.getJavaType().getName());
        }
        PluralAttribute.CollectionType collectionType = targetAttribute.getCollectionType();
        boolean isElementCollection = this.jpaProvider.getJpaMetamodelAccessor().isElementCollection((Attribute) targetAttribute);
        boolean subqueryRequired;
        if (isElementCollection) {
            subqueryRequired = false;
        } else {
            ManagedType managedTargetType = (ManagedType) result.getAttributeType();
            if (!(managedTargetType instanceof EntityType)) {
                throw new RuntimeException("Path [" + sizeArg.toString() + "] does not refer to a collection");
            }
            // Target entities with an embeddable id type are always counted through a subquery.
            subqueryRequired = ((EntityType) managedTargetType).getIdType().getPersistenceType() == Type.PersistenceType.EMBEDDABLE;
        }

        List<PathExpression> groupByExprs = new ArrayList<PathExpression>();
        for (SingularAttribute idAttribute : JpaMetamodelUtils.getIdAttributes((IdentifiableType) startType)) {
            groupByExprs.add(createIdPath(sizeArgJoin.getAlias(), idAttribute.getName()));
        }

        subqueryRequired = subqueryRequired
                || !startType.hasSingleIdAttribute()
                || this.joinManager.getRoots().size() > 1
                || this.clause == ClauseType.JOIN
                || !this.isCountTransformationEnabled()
                || this.jpaProvider.isBag((EntityType) targetAttribute.getDeclaringType(), targetAttribute.getName())
                || this.requiresBlacklistedNode(sizeArg)
                || this.aggregateFunctionContext;
        if (subqueryRequired) {
            return this.wrapSubqueryConditionally(this.generateSubquery(sizeArg), this.aggregateFunctionContext);
        }

        if (this.currentJoinNode != null && !this.currentJoinNode.equals(sizeArgJoin)) {
            int currentJoinDepth = this.currentJoinNode.getJoinDepth();
            int sizeArgJoinDepth = sizeArgJoin.getJoinDepth();
            if (currentJoinDepth > sizeArgJoinDepth) {
                // The earlier count transformation sits on a deeper join node, keep it and use a subquery here.
                return this.wrapSubqueryConditionally(this.generateSubquery(sizeArg), this.aggregateFunctionContext);
            }
            // Otherwise the earlier count transformations are replaced by subqueries and all state is reset.
            this.revertTransformationsToSubqueries();
            if (currentJoinDepth == sizeArgJoinDepth) {
                return this.wrapSubqueryConditionally(this.generateSubquery(sizeArg), this.aggregateFunctionContext);
            }
        }

        for (PathExpression groupByExpr : groupByExprs) {
            this.joinManager.implicitJoin((Expression) groupByExpr, true, true, true, null, null, new HashSet<String>(), false, false, false, false);
        }

        // Keep an untouched copy of the argument so that the transformation can later be turned into a subquery.
        PathExpression originalSizeArg = sizeArg.copy(ExpressionCopyContext.EMPTY);
        originalSizeArg.setPathReference(sizeArg.getPathReference());
        sizeArg.setUsedInCollectionFunction(false);

        String joinLookupKey = this.getJoinLookupKey(sizeArg);
        LateJoinEntry lateJoin = this.lateJoins.get(joinLookupKey);
        if (lateJoin == null) {
            lateJoin = new LateJoinEntry();
            this.lateJoins.put(joinLookupKey, lateJoin);
        }
        lateJoin.getExpressionsToJoin().add((Expression) sizeArg);
        lateJoin.getClauseDependencies().add(this.clause);

        List<Expression> countArguments = new ArrayList<Expression>();
        if ((isElementCollection && collectionType != PluralAttribute.CollectionType.MAP) || collectionType == PluralAttribute.CollectionType.SET) {
            if (IDENTIFIABLE_PERSISTENCE_TYPES.contains(targetAttribute.getElementType().getPersistenceType()) && targetAttribute.isCollection()) {
                // Count the id attributes of the element type instead of the element itself.
                PluralAttribute sizeArgTargetAttribute = (PluralAttribute) JpaMetamodelUtils.getAttribute((ManagedType) startType, (String) sizeArg.getPathReference().getField());
                for (Attribute attribute : JpaMetamodelUtils.getIdAttributes((IdentifiableType) sizeArgTargetAttribute.getElementType())) {
                    ArrayList<PropertyExpression> pathElementExpressions = new ArrayList<PropertyExpression>(sizeArg.getExpressions().size() + 1);
                    pathElementExpressions.addAll(sizeArg.getExpressions());
                    pathElementExpressions.add(new PropertyExpression(attribute.getName()));
                    PathExpression pathExpression = new PathExpression(pathElementExpressions);
                    countArguments.add((Expression) pathExpression);
                    lateJoin.getExpressionsToJoin().add((Expression) pathExpression);
                }
            } else {
                countArguments.add((Expression) sizeArg);
            }
        } else {
            sizeArg.setCollectionQualifiedPath(true);
            if (collectionType == PluralAttribute.CollectionType.LIST) {
                countArguments.add((Expression) new ListIndexExpression(sizeArg));
            } else {
                countArguments.add((Expression) new MapKeyExpression(sizeArg));
            }
        }

        AggregateExpression countExpr = this.createCountFunction(this.distinctRequired, countArguments);
        this.transformedExpressions.add(new TransformedExpressionEntry(countExpr, originalSizeArg, parentModifier, this.aggregateFunctionContext));
        this.currentJoinNode = (JoinNode) originalSizeArg.getBaseNode();

        // As soon as more than one collection join is involved, every count has to be made distinct.
        if (!this.distinctRequired && this.lateJoins.size() + this.joinManager.getCollectionJoins().size() > 1) {
            this.distinctRequired = true;
            this.applyDistinctToTransformedExpressions();
        }

        for (Expression groupByExpr : groupByExprs) {
            ResolvedExpression resolvedExpression = new ResolvedExpression(groupByExpr.toString(), groupByExpr);
            Set<ClauseType> clauseTypes = this.requiredGroupBys.get(resolvedExpression);
            if (clauseTypes == null) {
                this.requiredGroupBys.put(resolvedExpression, EnumSet.of(this.clause));
            } else {
                clauseTypes.add(this.clause);
            }
        }
        return countExpr;
    }

    /** Replaces every recorded count transformation with a subquery and clears the count transformation state. */
    private void revertTransformationsToSubqueries() {
        for (TransformedExpressionEntry transformedExpressionEntry : this.transformedExpressions) {
            PathExpression originalSizeArg = transformedExpressionEntry.getOriginalSizeArg();
            Expression subquery = this.wrapSubqueryConditionally(this.generateSubquery(originalSizeArg), transformedExpressionEntry.isAggregateFunctionContext());
            transformedExpressionEntry.getParentModifier().set(subquery);
        }
        this.transformedExpressions.clear();
        this.requiredGroupBys.clear();
        this.lateJoins.clear();
        this.distinctRequired = false;
    }

    /** Makes every recorded count transformation distinct, unless it already is. */
    private void applyDistinctToTransformedExpressions() {
        for (TransformedExpressionEntry transformedExpressionEntry : this.transformedExpressions) {
            AggregateExpression transformedExpr = transformedExpressionEntry.getTransformedExpression();
            if (ExpressionUtils.isCustomFunctionInvocation((FunctionExpression) transformedExpr) && COUNT_TUPLE_FUNCTION.equalsIgnoreCase(((StringLiteral) transformedExpr.getExpressions().get(0)).getValue())) {
                // For the count tuple function the distinct marker is the literal right after the function name.
                Expression possibleDistinct = (Expression) transformedExpr.getExpressions().get(1);
                boolean alreadyDistinct = possibleDistinct instanceof StringLiteral && DISTINCT_QUALIFIER.equals(((StringLiteral) possibleDistinct).getValue());
                if (!alreadyDistinct) {
                    transformedExpr.getExpressions().add(1, new StringLiteral(DISTINCT_QUALIFIER));
                }
            } else {
                transformedExpr.setDistinct(true);
            }
        }
    }

    /** Creates the two element path {@code alias.attributeName}. */
    private static PathExpression createIdPath(String alias, String attributeName) {
        ArrayList<PropertyExpression> pathElementExpr = new ArrayList<PropertyExpression>(2);
        pathElementExpr.add(new PropertyExpression(alias));
        pathElementExpr.add(new PropertyExpression(attributeName));
        return new PathExpression(pathElementExpr);
    }

    /** Builds the late join lookup key: the absolute path of the base node followed by the field name. */
    private String getJoinLookupKey(PathExpression pathExpression) {
        JoinNode originalNode = (JoinNode) pathExpression.getBaseNode();
        return originalNode.getAliasInfo().getAbsolutePath() + "." + pathExpression.getField();
    }

    /**
     * Creates the count function invocation. Note that the given argument list is modified: the function name
     * and, if requested, the distinct marker are inserted at the front.
     */
    private AggregateExpression createCountFunction(boolean distinct, List<Expression> countTupleArguments) {
        countTupleArguments.add(0, (Expression) new StringLiteral(COUNT_TUPLE_FUNCTION));
        if (distinct) {
            countTupleArguments.add(1, (Expression) new StringLiteral(DISTINCT_QUALIFIER));
        }
        return new AggregateExpression(false, FUNCTION_NAME, countTupleArguments);
    }

    /**
     * Creates a {@code COUNT(*)} subquery selecting from the collection path of the size argument and records
     * the id paths of the collection owner in the subquery group bys for the current clause.
     *
     * @throws IllegalArgumentException if the collection is owned by a non-entity type
     */
    private SubqueryExpression generateSubquery(PathExpression sizeArg) {
        JoinNode sizeArgJoin = (JoinNode) sizeArg.getBaseNode();
        Type<?> nodeType = sizeArgJoin.getNodeType();
        if (!(nodeType instanceof EntityType)) {
            throw new IllegalArgumentException("Size on a collection owned by a non-entity type is not supported yet: " + sizeArg);
        }
        EntityType startType = (EntityType) nodeType;
        Subquery countSubquery = (Subquery) this.subqueryInitFactory.createSubqueryInitiator(null, new SubqueryBuilderListenerImpl(), false, this.getClause()).from(sizeArg.getPathReference().toString()).select("COUNT(*)");
        for (SingularAttribute idAttribute : JpaMetamodelUtils.getIdAttributes((IdentifiableType) startType)) {
            String groupByExprString = sizeArgJoin.getAlias() + "." + idAttribute.getName();
            // The lookup key carries no expression; it is only used to find an existing entry.
            Set<ClauseType> clauseTypes = this.subqueryGroupBys.get(new ResolvedExpression(groupByExprString, null));
            if (clauseTypes == null) {
                PathExpression pathExpression = createIdPath(sizeArgJoin.getAlias(), idAttribute.getName());
                pathExpression.setPathReference((PathReference) new SimplePathReference(sizeArgJoin, idAttribute.getName(), this.metamodel.type(JpaMetamodelUtils.resolveFieldClass((Class) startType.getJavaType(), (Attribute) idAttribute))));
                this.subqueryGroupBys.put(new ResolvedExpression(groupByExprString, (Expression) pathExpression), EnumSet.of(this.clause));
            } else {
                clauseTypes.add(this.clause);
            }
        }
        return new SubqueryExpression(countSubquery);
    }

    /**
     * @return the subquery itself, or {@code FUNCTION('subquery', subquery)} if {@code wrap} is {@code true}
     */
    private Expression wrapSubqueryConditionally(SubqueryExpression subquery, boolean wrap) {
        if (!wrap) {
            return subquery;
        }
        List<Expression> subqueryFunctionArguments = new ArrayList<Expression>(2);
        subqueryFunctionArguments.add(new StringLiteral(SUBQUERY_FUNCTION));
        subqueryFunctionArguments.add(subquery);
        return new FunctionExpression(FUNCTION_NAME, subqueryFunctionArguments);
    }

    /**
     * Describes a join that has to be added for a count transformation: the expressions to join and the
     * clauses that depend on the join.
     */
    static class LateJoinEntry {
        private final EnumSet<ClauseType> clauseDependencies = EnumSet.noneOf(ClauseType.class);
        private final List<Expression> expressionsToJoin = new ArrayList<Expression>();

        LateJoinEntry() {
        }

        /**
         * @return the live set of clauses that depend on this join
         */
        public EnumSet<ClauseType> getClauseDependencies() {
            return this.clauseDependencies;
        }

        /**
         * @return the live list of expressions that have to be joined
         */
        public List<Expression> getExpressionsToJoin() {
            return this.expressionsToJoin;
        }
    }

    /**
     * Remembers a performed count transformation together with everything needed to replace it with a
     * subquery later on.
     */
    private static class TransformedExpressionEntry {
        private final AggregateExpression transformedExpression;
        private final PathExpression originalSizeArg;
        private final ExpressionModifier parentModifier;
        private final boolean aggregateFunctionContext;

        public TransformedExpressionEntry(AggregateExpression transformedExpression, PathExpression originalSizeArg, ExpressionModifier parentModifier, boolean aggregateFunctionContext) {
            this.transformedExpression = transformedExpression;
            this.originalSizeArg = originalSizeArg;
            this.parentModifier = parentModifier;
            this.aggregateFunctionContext = aggregateFunctionContext;
        }

        public AggregateExpression getTransformedExpression() {
            return this.transformedExpression;
        }

        public PathExpression getOriginalSizeArg() {
            return this.originalSizeArg;
        }

        public ExpressionModifier getParentModifier() {
            return this.parentModifier;
        }

        public boolean isAggregateFunctionContext() {
            return this.aggregateFunctionContext;
        }
    }
}

