package com.vaadin.copilot.javarewriter;

import java.util.HashSet;
import java.util.Optional;
import java.util.Set;

import com.vaadin.copilot.IdentityHashSet;

import com.github.javaparser.Range;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
import com.github.javaparser.ast.body.FieldDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.comments.Comment;
import com.github.javaparser.ast.observer.AstObserver;
import com.github.javaparser.ast.observer.ObservableProperty;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.Statement;

/**
 * Observer for the JavaRewriter that tracks which parts have been deleted and
 * which nodes have potentially been added or modified.
 *
 * <p>
 * For deleted code we must store the ranges and not the nodes themselves, as
 * the nodes are no longer part of the AST and we are unable to find the related
 * statements later on.
 *
 * <p>
 * This class performs no synchronization of its own.
 */
public class JavaRewriterObserver implements AstObserver {
    /** Source ranges of code that was removed or has to be regenerated. */
    private final Set<Range> removedRanges = new HashSet<>();

    /**
     * Nodes that were added or modified. Kept in an identity-based set so that
     * distinct node instances are tracked separately even if they compare as
     * equal via {@code equals}.
     */
    private final Set<Node> addedOrModifiedNodes = new IdentityHashSet<>();

    /**
     * Reacts to a property change on a single node.
     *
     * <ul>
     * <li>{@code COMMENT} set to a {@link Comment}: the new comment node is
     * recorded as added. A comment being cleared (new value not a
     * {@code Comment}) is not recorded.</li>
     * <li>{@code NAME} or {@code TYPE_ARGUMENTS}: the code enclosing the node is
     * marked as removed and the node itself is recorded as modified, so it can
     * be regenerated.</li>
     * <li>{@code ARGUMENTS}: the node is only recorded as modified.</li>
     * </ul>
     * All other properties are ignored.
     */
    @Override
    public void propertyChange(Node observedNode, ObservableProperty property, Object oldValue, Object newValue) {
        if (property == ObservableProperty.COMMENT && newValue instanceof Comment comment) {
            addedOrModifiedNodes.add(comment);
        } else if (property == ObservableProperty.NAME || property == ObservableProperty.TYPE_ARGUMENTS) {
            markRemoved(observedNode);
            addedOrModifiedNodes.add(observedNode);
        } else if (property == ObservableProperty.ARGUMENTS) {
            addedOrModifiedNodes.add(observedNode);
        }
    }

    /**
     * Reacts to a node being attached to or detached from a parent.
     *
     * <p>
     * A node whose new parent is {@code null} (i.e. detached) has its enclosing
     * code marked as removed. A node attached to a new parent is recorded as
     * added or modified.
     */
    @Override
    public void parentChange(Node observedNode, Node previousParent, Node newParent) {
        if (newParent == null) {
            markRemoved(observedNode);
        } else {
            addedOrModifiedNodes.add(observedNode);
        }
    }

    /**
     * Reacts to an element being added to or removed from a {@link NodeList}.
     *
     * <p>
     * If the list belongs to a node that is not a compilation unit, a
     * class/interface declaration, a block statement or a constructor (for
     * example an argument or parameter list), the owning node is marked as
     * removed and re-recorded as modified so it is regenerated as a whole.
     * A {@link FieldDeclaration} removed from a class/interface declaration has
     * its own range recorded as removed. Otherwise, only additions are recorded.
     * Other removals are not handled here.
     * NOTE(review): presumably those are covered by
     * {@link #parentChange} when the removed node is detached — confirm.
     */
    @Override
    public void listChange(NodeList<?> observedNode, AstObserver.ListChangeType type, int index,
            Node nodeAddedOrRemoved) {
        Node parentNode = observedNode.getParentNode().orElse(null);
        if (parentNode != null) {
            if (!(parentNode instanceof CompilationUnit) && !(parentNode instanceof ClassOrInterfaceDeclaration)
                    && !(parentNode instanceof BlockStmt) && !(parentNode instanceof ConstructorDeclaration)) {
                // We are interested in changes happening inside statements only.
                // We want to avoid rewriting the whole class or a whole block
                // (like a whole constructor).
                markRemoved(parentNode);
                // The parent must also be processed as "added" later, e.g. when
                // only one method parameter was removed.
                addedOrModifiedNodes.add(parentNode);
                return;
            }
            if (nodeAddedOrRemoved instanceof FieldDeclaration && parentNode instanceof ClassOrInterfaceDeclaration
                    && type == ListChangeType.REMOVAL) {
                nodeAddedOrRemoved.getRange().ifPresent(removedRanges::add);
                return;
            }
        }

        if (type == ListChangeType.ADDITION) {
            addedOrModifiedNodes.add(nodeAddedOrRemoved);
        }
    }

    /**
     * Reacts to an element of a {@link NodeList} being replaced. The replaced
     * node's enclosing code is marked as removed and the replacement is
     * recorded as added.
     */
    @Override
    public void listReplacement(NodeList<?> observedNode, int index, Node oldNode, Node newNode) {
        markRemoved(oldNode);
        addedOrModifiedNodes.add(newNode);
    }

    /**
     * Records the source ranges affected by removing or rewriting {@code node}.
     *
     * <p>
     * The ranges recorded are:
     * <ul>
     * <li>the {@link Statement} found by {@code JavaRewriterUtil.findAncestor},
     * if there is one and it has a range;</li>
     * <li>the {@link FieldDeclaration} found the same way, if there is one and
     * it has a range;</li>
     * <li>the node's own range when the node is a {@link MethodDeclaration}.</li>
     * </ul>
     * Nodes without a range (e.g. newly created ones) contribute nothing.
     * {@code findAncestor} may return {@code null}, which is why its result is
     * wrapped in {@code Optional.ofNullable}.
     */
    private void markRemoved(Node node) {
        Optional.ofNullable(JavaRewriterUtil.findAncestor(node, Statement.class)).flatMap(Node::getRange)
                .ifPresent(removedRanges::add);
        Optional.ofNullable(JavaRewriterUtil.findAncestor(node, FieldDeclaration.class)).flatMap(Node::getRange)
                .ifPresent(removedRanges::add);
        if (node instanceof MethodDeclaration) {
            // If the node is a method, we also want to remove the whole
            // method declaration.
            node.getRange().ifPresent(removedRanges::add);
        }
    }

    /**
     * Returns the nodes recorded as added or modified.
     *
     * @return the live, mutable identity-based set backing this observer (not
     *         a copy)
     */
    public Set<Node> getAddedOrModifiedNodes() {
        return addedOrModifiedNodes;
    }

    /**
     * Returns the source ranges recorded as removed.
     *
     * @return the live, mutable set backing this observer (not a copy)
     */
    public Set<Range> getRemovedRanges() {
        return removedRanges;
    }
}
