/*
 * SPDX-License-Identifier: Apache-2.0
 * Copyright Blazebit
 */

package com.blazebit.persistence.parser;

import com.blazebit.persistence.parser.expression.ArithmeticExpression;
import com.blazebit.persistence.parser.expression.ArithmeticFactor;
import com.blazebit.persistence.parser.expression.ArrayExpression;
import com.blazebit.persistence.parser.expression.DateLiteral;
import com.blazebit.persistence.parser.expression.EntityLiteral;
import com.blazebit.persistence.parser.expression.EnumLiteral;
import com.blazebit.persistence.parser.expression.Expression;
import com.blazebit.persistence.parser.expression.FunctionExpression;
import com.blazebit.persistence.parser.expression.GeneralCaseExpression;
import com.blazebit.persistence.parser.expression.ListIndexExpression;
import com.blazebit.persistence.parser.expression.MapEntryExpression;
import com.blazebit.persistence.parser.expression.MapKeyExpression;
import com.blazebit.persistence.parser.expression.MapValueExpression;
import com.blazebit.persistence.parser.expression.NullExpression;
import com.blazebit.persistence.parser.expression.NumericLiteral;
import com.blazebit.persistence.parser.expression.ParameterExpression;
import com.blazebit.persistence.parser.expression.PathElementExpression;
import com.blazebit.persistence.parser.expression.PathExpression;
import com.blazebit.persistence.parser.expression.PropertyExpression;
import com.blazebit.persistence.parser.expression.SimpleCaseExpression;
import com.blazebit.persistence.parser.expression.StringLiteral;
import com.blazebit.persistence.parser.expression.SubqueryExpression;
import com.blazebit.persistence.parser.expression.TimeLiteral;
import com.blazebit.persistence.parser.expression.TimestampLiteral;
import com.blazebit.persistence.parser.expression.TreatExpression;
import com.blazebit.persistence.parser.expression.TrimExpression;
import com.blazebit.persistence.parser.expression.TypeFunctionExpression;
import com.blazebit.persistence.parser.expression.WhenClauseExpression;
import com.blazebit.persistence.parser.predicate.BetweenPredicate;
import com.blazebit.persistence.parser.predicate.BooleanLiteral;
import com.blazebit.persistence.parser.predicate.CompoundPredicate;
import com.blazebit.persistence.parser.predicate.EqPredicate;
import com.blazebit.persistence.parser.predicate.ExistsPredicate;
import com.blazebit.persistence.parser.predicate.GePredicate;
import com.blazebit.persistence.parser.predicate.GtPredicate;
import com.blazebit.persistence.parser.predicate.InPredicate;
import com.blazebit.persistence.parser.predicate.IsEmptyPredicate;
import com.blazebit.persistence.parser.predicate.IsNullPredicate;
import com.blazebit.persistence.parser.predicate.LePredicate;
import com.blazebit.persistence.parser.predicate.LikePredicate;
import com.blazebit.persistence.parser.predicate.LtPredicate;
import com.blazebit.persistence.parser.predicate.MemberOfPredicate;
import com.blazebit.persistence.parser.util.JpaMetamodelUtils;
import com.blazebit.reflection.ReflectionUtils;

import javax.persistence.metamodel.Attribute;
import javax.persistence.metamodel.EntityType;
import javax.persistence.metamodel.ListAttribute;
import javax.persistence.metamodel.ManagedType;
import javax.persistence.metamodel.MapAttribute;
import javax.persistence.metamodel.PluralAttribute;
import javax.persistence.metamodel.SingularAttribute;
import javax.persistence.metamodel.Type;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A visitor that can determine possible target types and JPA attributes of a path expression.
 *
 * <p>Resolution starts at the configured root type and the visitor updates its current {@link PathPosition}
 * for every path element it visits. CASE expressions fork every current position once per WHEN/ELSE branch,
 * which is why {@link #getPossibleTargets()} may contain more than one entry. Expressions that can't be part
 * of a path chain are rejected with an {@link IllegalArgumentException}.</p>
 *
 * <p>The visitor is stateful; call {@link #clear()} before reusing an instance for another expression.</p>
 *
 * @author Christian Beikov
 * @since 1.0.0
 */
public class PathTargetResolvingExpressionVisitor implements Expression.Visitor {

    // Shared "no type arguments" marker used while resolving plural attributes
    private static final Class<?>[] EMPTY = new Class<?>[0];

    protected PathPosition currentPosition;
    protected List<PathPosition> pathPositions;
    protected final EntityMetamodel metamodel;
    protected final Type<?> rootType;
    protected final Attribute<?, ?> rootAttribute;
    protected final Map<String, Type<?>> rootTypes;

    /**
     * The mutable resolution state for a single path: the type the path currently points to, the optional
     * value and key types of a plural attribute and the attribute that was navigated last.
     *
     * @author Christian Beikov
     * @since 1.0.0
     */
    protected static class PathPosition {

        public Type<?> currentClass;
        public Type<?> valueClass;
        public Type<?> keyClass;
        public Attribute<?, ?> attribute;

        /**
         * Creates a position without value and key type.
         *
         * @param currentClass The type the position points to
         * @param attribute The attribute that led to the type, may be <code>null</code>
         */
        public PathPosition(Type<?> currentClass, Attribute<?, ?> attribute) {
            this.currentClass = currentClass;
            this.attribute = attribute;
        }

        // Full state constructor, only used by copy()
        private PathPosition(Type<?> currentClass, Type<?> valueClass, Type<?> keyClass, Attribute<?, ?> attribute) {
            this.currentClass = currentClass;
            this.valueClass = valueClass;
            this.keyClass = keyClass;
            this.attribute = attribute;
        }

        /**
         * Returns the type of the position itself, ignoring any value or key type.
         *
         * @return The current type, may be <code>null</code>
         */
        public Type<?> getRealCurrentType() {
            return currentClass;
        }

        /**
         * Returns the Java type of {@link #getRealCurrentType()}.
         *
         * @return The Java type or <code>null</code> if there is no current type
         */
        public Class<?> getRealCurrentClass() {
            return currentClass == null ? null : currentClass.getJavaType();
        }

        /**
         * Returns the effective type of the position: the value type if set, otherwise the key type if set,
         * otherwise the current type.
         *
         * @return The effective type, may be <code>null</code>
         */
        public Type<?> getCurrentType() {
            if (valueClass != null) {
                return valueClass;
            }
            if (keyClass != null) {
                return keyClass;
            }

            return currentClass;
        }

        /**
         * Returns the Java type of {@link #getCurrentType()}.
         *
         * @return The Java type or <code>null</code> if there is no effective type
         */
        public Class<?> getCurrentClass() {
            Type<?> effectiveType = getCurrentType();
            return effectiveType == null ? null : effectiveType.getJavaType();
        }

        /**
         * Returns the Java type of the key type.
         *
         * @return The Java type or <code>null</code> if there is no key type
         */
        public Class<?> getKeyCurrentClass() {
            return keyClass == null ? null : keyClass.getJavaType();
        }

        /**
         * Sets the current type and resets the value and key type.
         *
         * @param currentClass The new current type
         */
        public void setCurrentType(Type<?> currentClass) {
            this.currentClass = currentClass;
            this.valueClass = null;
            this.keyClass = null;
        }

        public Attribute<?, ?> getAttribute() {
            return attribute;
        }

        public void setAttribute(Attribute<?, ?> attribute) {
            this.attribute = attribute;
        }

        /**
         * Returns whether a value type is set for this position.
         *
         * @return <code>true</code> if a value type is set
         */
        public boolean hasCollectionJoin() {
            return valueClass != null;
        }

        void setValueType(Type<?> valueClass) {
            this.valueClass = valueClass;
        }

        void setKeyType(Type<?> keyClass) {
            this.keyClass = keyClass;
        }

        /**
         * Creates a new position carrying the same types and attribute as this one.
         *
         * @return The copy
         */
        public PathPosition copy() {
            return new PathPosition(currentClass, valueClass, keyClass, attribute);
        }
    }

    /**
     * Creates a visitor without root attribute and without additional root types.
     *
     * @param metamodel The metamodel used for looking up types
     * @param rootType The type at which path resolution starts
     * @param skipBaseNodeAlias An alias for the root type that is skipped as first path element, may be <code>null</code>
     */
    public PathTargetResolvingExpressionVisitor(EntityMetamodel metamodel, Type<?> rootType, String skipBaseNodeAlias) {
        this(metamodel, rootType, null, skipBaseNodeAlias, Collections.<String, Type<?>>emptyMap());
    }

    /**
     * Creates a visitor.
     *
     * @param metamodel The metamodel used for looking up types
     * @param rootType The type at which path resolution starts
     * @param rootAttribute The attribute of the initial position, may be <code>null</code>
     * @param skipBaseNodeAlias An alias for the root type that is skipped as first path element, may be <code>null</code>
     * @param rootTypes Additional aliases with their types that may be used as first path element, may be <code>null</code>
     */
    public PathTargetResolvingExpressionVisitor(EntityMetamodel metamodel, Type<?> rootType, Attribute<?, ?> rootAttribute, String skipBaseNodeAlias, Map<String, Type<?>> rootTypes) {
        this.metamodel = metamodel;
        this.pathPositions = new ArrayList<>();
        this.rootType = rootType;
        this.rootAttribute = rootAttribute;
        this.rootTypes = withRootType(rootTypes, rootType, skipBaseNodeAlias);
        clear();
    }

    /**
     * Builds an unmodifiable alias map from the given root types plus the root type registered under
     * <code>this</code> and, if given, under the root alias. The latter two win over equally named entries
     * of the passed map.
     */
    private static Map<String, Type<?>> withRootType(Map<String, Type<?>> rootTypes, Type<?> rootType, String rootAlias) {
        Map<String, Type<?>> newRootTypes = new HashMap<>();
        if (rootTypes != null) {
            newRootTypes.putAll(rootTypes);
        }
        newRootTypes.put("this", rootType);
        if (rootAlias != null) {
            newRootTypes.put(rootAlias, rootType);
        }
        return Collections.unmodifiableMap(newRootTypes);
    }

    /**
     * Resets the visitor to a single position that points to the root type and root attribute.
     */
    public void clear() {
        pathPositions.clear();
        pathPositions.add(currentPosition = new PathPosition(rootType, rootAttribute));
    }

    /**
     * Determines the type of an attribute as seen from the given base type.
     *
     * For plural attributes this is the type of the collection class itself. For other attributes the field
     * class is resolved against the Java type of the base type and the declared singular attribute type is
     * only used as fallback when that resolution yields nothing.
     */
    private Type<?> getType(Type<?> baseType, Attribute<?, ?> attribute) {
        if (attribute instanceof PluralAttribute<?, ?, ?>) {
            return metamodel.type(((PluralAttribute<?, ?, ?>) attribute).getJavaType());
        }

        Class<?> baseClass = baseType.getJavaType();
        if (baseClass != null) {
            Class<?> clazz = JpaMetamodelUtils.resolveFieldClass(baseClass, attribute);
            if (clazz != null) {
                return metamodel.type(clazz);
            }
        }
        return ((SingularAttribute<?, ?>) attribute).getType();
    }

    /**
     * Returns the attribute and effective type of every position that was resolved so far.
     *
     * Positions that share the same attribute collapse into a single entry, the later position wins.
     *
     * @return A new map from attribute to effective type
     */
    public Map<Attribute<?, ?>, Type<?>> getPossibleTargets() {
        Map<Attribute<?, ?>, Type<?>> possibleTargets = new HashMap<>();

        List<PathPosition> positions = pathPositions;
        int size = positions.size();
        for (int i = 0; i < size; i++) {
            PathPosition position = positions.get(i);
            possibleTargets.put(position.getAttribute(), position.getCurrentType());
        }

        return possibleTargets;
    }

    /**
     * Navigates the current position to the attribute with the name of the property.
     *
     * @throws IllegalArgumentException if the current type is a basic type or has no such attribute
     */
    @Override
    public void visit(PropertyExpression expression) {
        String property = expression.getProperty();
        // The effective type stays the same until setCurrentType is invoked at the end
        Type<?> ownerType = currentPosition.getCurrentType();
        if (ownerType.getPersistenceType() == Type.PersistenceType.BASIC) {
            throw new IllegalArgumentException("Can't access property '" + property + "' on basic type '" + JpaMetamodelUtils.getTypeName(ownerType) + "'. Did you forget to add the embeddable type to your persistence.xml?");
        }
        Attribute<?, ?> attribute = JpaMetamodelUtils.getAttribute((ManagedType<?>) ownerType, property);
        if (attribute == null) {
            throw new IllegalArgumentException("Attribute '" + property + "' not found on type '" + JpaMetamodelUtils.getTypeName(ownerType) + "'");
        }
        currentPosition.setAttribute(attribute);
        Type<?> type = getType(ownerType, attribute);
        Type<?> valueType = null;
        Type<?> keyType = null;

        if (attribute instanceof PluralAttribute<?, ?, ?>) {
            Class<?> ownerClass = currentPosition.getCurrentClass();
            Class<?> javaType = type.getJavaType();
            Class<?>[] typeArguments;
            if (javaType == null) {
                typeArguments = EMPTY;
            } else {
                if (attribute.getJavaMember() instanceof Field) {
                    typeArguments = ReflectionUtils.getResolvedFieldTypeArguments(ownerClass, (Field) attribute.getJavaMember());
                } else if (attribute.getJavaMember() instanceof Method) {
                    typeArguments = ReflectionUtils.getResolvedMethodReturnTypeArguments(ownerClass, (Method) attribute.getJavaMember());
                } else {
                    typeArguments = EMPTY;
                }
            }

            valueType = metamodel.type(JpaMetamodelUtils.resolveFieldClass(ownerClass, attribute));

            if (typeArguments.length == 0) {
                // Raw types
                if (attribute instanceof MapAttribute<?, ?, ?>) {
                    keyType = ((MapAttribute<?, ?, ?>) attribute).getKeyType();
                }
            } else {
                // More than one type argument means the first one is the key type
                if (typeArguments.length > 1) {
                    keyType = metamodel.type(typeArguments[0]);
                }
            }
        }

        // setCurrentType resets value and key type, so these have to be set afterwards
        currentPosition.setCurrentType(type);
        currentPosition.setValueType(valueType);
        currentPosition.setKeyType(keyType);
    }

    /**
     * Resolves every WHEN result and the ELSE expression, if present, against a copy of every current
     * position and replaces the positions with the resolved ones. Afterwards there is no current position.
     */
    @Override
    public void visit(GeneralCaseExpression expression) {
        List<PathPosition> currentPositions = pathPositions;
        List<PathPosition> newPositions = new ArrayList<>();
        List<WhenClauseExpression> expressions = expression.getWhenClauses();
        int size = expressions.size();

        int positionsSize = currentPositions.size();
        for (int j = 0; j < positionsSize; j++) {
            PathPosition origin = currentPositions.get(j);
            for (int i = 0; i < size; i++) {
                startBranch(origin);
                expressions.get(i).accept(this);
                newPositions.addAll(pathPositions);
            }

            if (expression.getDefaultExpr() != null) {
                startBranch(origin);
                expression.getDefaultExpr().accept(this);
                newPositions.addAll(pathPositions);
            }
        }

        currentPosition = null;
        pathPositions = newPositions;
    }

    /**
     * Makes a copy of the given position the current and only position so that a CASE branch can be
     * resolved in isolation.
     */
    private void startBranch(PathPosition origin) {
        pathPositions = new ArrayList<>();
        pathPositions.add(currentPosition = origin.copy());
    }

    /**
     * Resolves the elements of the path one after another.
     *
     * If the first element is a property that names a known root alias, the current position is moved to
     * the type of that alias and the element is skipped. If no current type is known, the type of the path
     * reference is used, if available, and the elements aren't visited. An empty path visits no elements.
     */
    @Override
    public void visit(PathExpression expression) {
        List<PathElementExpression> expressions = expression.getExpressions();
        int size = expressions.size();
        int i = 0;
        // Guard against empty paths which would otherwise fail with an IndexOutOfBoundsException
        if (size > 0) {
            PathElementExpression firstExpression = expressions.get(0);
            if (firstExpression instanceof PropertyExpression) {
                String property = ((PropertyExpression) firstExpression).getProperty();
                // containsKey is needed because an alias may be mapped to a null type
                if (rootTypes.containsKey(property)) {
                    i = 1;
                    currentPosition.setCurrentType(rootTypes.get(property));
                }
            }
        }

        if (currentPosition.getCurrentType() == null) {
            if (expression.getPathReference() != null) {
                currentPosition.setCurrentType(expression.getPathReference().getType());
            }
            return;
        }

        for (; i < size; i++) {
            expressions.get(i).accept(this);
        }
    }

    /**
     * Resolves the path of the index expression and switches the position to the index of the list.
     *
     * @throws IllegalArgumentException if the path does not resolve to a {@link List}
     */
    @Override
    public void visit(ListIndexExpression expression) {
        expression.getPath().accept(this);
        Class<?> type = currentPosition.getRealCurrentClass();

        // Class.isAssignableFrom fails with a NullPointerException for null, so treat unknown types as invalid
        if (type == null || !List.class.isAssignableFrom(type)) {
            invalid(expression, "Does not resolve to java.util.List!");
        } else {
            currentPosition.setAttribute(new ListIndexAttribute<>((ListAttribute<?, ?>) currentPosition.getAttribute()));
            currentPosition.setValueType(null);
            currentPosition.setKeyType(metamodel.type(Integer.class));
        }
    }

    /**
     * Resolves the path of the entry expression and switches the position to {@link Map.Entry}.
     */
    @Override
    public void visit(MapEntryExpression expression) {
        expression.getPath().accept(this);
        // NOTE(review): the cast fails with a ClassCastException if the path does not end in a map attribute
        currentPosition.setAttribute(new MapEntryAttribute<>((MapAttribute<?, Object, ?>) currentPosition.getAttribute()));
        currentPosition.setCurrentType(metamodel.type(Map.Entry.class));
    }

    /**
     * Resolves the path of the key expression and drops the value type so that the key type becomes the
     * effective type.
     */
    @Override
    public void visit(MapKeyExpression expression) {
        expression.getPath().accept(this);
        // NOTE(review): the cast fails with a ClassCastException if the path does not end in a map attribute
        currentPosition.setAttribute(new MapKeyAttribute<>((MapAttribute<?, Object, ?>) currentPosition.getAttribute()));
        currentPosition.setValueType(null);
    }

    /**
     * Resolves the path of the value expression and drops the key type.
     */
    @Override
    public void visit(MapValueExpression expression) {
        expression.getPath().accept(this);
        currentPosition.setKeyType(null);
    }

    /**
     * Resolves the base of the array expression, the index is ignored.
     */
    @Override
    public void visit(ArrayExpression expression) {
        // Only need the base to navigate down the path
        if (expression.getBase() instanceof EntityLiteral) {
            EntityType<?> type = metamodel.getEntity(((EntityLiteral) expression.getBase()).getValue());
            currentPosition.setCurrentType(type);
        } else {
            expression.getBase().accept(this);
        }
    }

    /**
     * Resolves the treated expression, unless it is a naked root alias, and switches the position to the
     * entity type of the treat.
     */
    @Override
    public void visit(TreatExpression expression) {
        boolean handled = false;
        if (expression.getExpression() instanceof PathExpression) {
            PathExpression treatPath = (PathExpression) expression.getExpression();
            if (treatPath.getExpressions().size() == 1 && rootTypes.containsKey(treatPath.getExpressions().get(0).toString())) {
                // When we encounter a naked root treat like "TREAT(alias AS Subtype)" we always skip it
                handled = true;
            }
        }
        if (!handled) {
            expression.getExpression().accept(this);
        }

        EntityType<?> type = metamodel.getEntity(expression.getType());
        // TODO: should we check if the type is actually a sub- or super type?
        currentPosition.setCurrentType(type);
        currentPosition.setValueType(type);
    }

    @Override
    public void visit(ParameterExpression expression) {
        invalid(expression);
    }

    @Override
    public void visit(NullExpression expression) {
        invalid(expression);
    }

    @Override
    public void visit(SubqueryExpression expression) {
        invalid(expression);
    }

    @Override
    public void visit(ArithmeticExpression expression) {
        invalid(expression);
    }

    @Override
    public void visit(ArithmeticFactor expression) {
        invalid(expression);
    }

    @Override
    public void visit(NumericLiteral expression) {
        invalid(expression);
    }

    @Override
    public void visit(BooleanLiteral expression) {
        invalid(expression);
    }

    @Override
    public void visit(StringLiteral expression) {
        invalid(expression);
    }

    @Override
    public void visit(DateLiteral expression) {
        invalid(expression);
    }

    @Override
    public void visit(TimeLiteral expression) {
        invalid(expression);
    }

    @Override
    public void visit(TimestampLiteral expression) {
        invalid(expression);
    }

    @Override
    public void visit(EnumLiteral expression) {
        invalid(expression);
    }

    @Override
    public void visit(EntityLiteral expression) {
        invalid(expression);
    }

    @Override
    public void visit(FunctionExpression expression) {
        invalid(expression);
    }

    @Override
    public void visit(TypeFunctionExpression expression) {
        invalid(expression);
    }

    /**
     * Resolves the trim source, the other parts of the trim expression are ignored.
     */
    @Override
    public void visit(TrimExpression expression) {
        expression.getTrimSource().accept(this);
    }

    /**
     * Handles a simple CASE expression like a general CASE expression.
     */
    @Override
    public void visit(SimpleCaseExpression expression) {
        visit((GeneralCaseExpression) expression);
    }

    /**
     * Resolves the result of the WHEN clause, the condition is ignored.
     */
    @Override
    public void visit(WhenClauseExpression expression) {
        expression.getResult().accept(this);
    }

    @Override
    public void visit(CompoundPredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(EqPredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(IsNullPredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(IsEmptyPredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(MemberOfPredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(LikePredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(BetweenPredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(InPredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(GtPredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(GePredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(LtPredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(LePredicate predicate) {
        invalid(predicate);
    }

    @Override
    public void visit(ExistsPredicate predicate) {
        invalid(predicate);
    }

    /**
     * Rejects an object that must not occur in a path chain.
     *
     * @param o The offending object, used in the message
     * @throws IllegalArgumentException always
     */
    protected final void invalid(Object o) {
        throw new IllegalArgumentException("Illegal occurence of [" + o + "] in path chain resolver!");
    }

    /**
     * Rejects an object that must not occur in a path chain.
     *
     * @param o The offending object, used in the message
     * @param reason Additional information that is appended to the message
     * @throws IllegalArgumentException always
     */
    protected final void invalid(Object o, String reason) {
        throw new IllegalArgumentException("Illegal occurence of [" + o + "] in path chain resolver! " + reason);
    }

}
