/*
 * Copyright 2023 Salesforce, Inc. All rights reserved.
 * The software in this package is published under the terms of the CPAL v1.0
 * license, a copy of which has been included with this distribution in the
 * LICENSE.txt file.
 */
package org.mule.runtime.ast.test.api.util;

import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static java.util.Optional.of;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Stream.empty;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.sameInstance;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mule.runtime.api.functional.Either.right;
import static org.mule.runtime.ast.api.util.AstTraversalDirection.BOTTOM_UP;
import static org.mule.runtime.ast.api.util.AstTraversalDirection.TOP_DOWN;
import static org.mule.runtime.ast.test.AllureConstants.ArtifactAst.ARTIFACT_AST;
import static org.mule.runtime.ast.test.AllureConstants.ArtifactAst.AstTraversal.AST_TRAVERSAL;

import org.mule.metadata.api.model.ArrayType;
import org.mule.metadata.api.visitor.MetadataTypeVisitor;
import org.mule.runtime.api.meta.model.parameter.ParameterModel;
import org.mule.runtime.api.meta.model.parameter.ParameterizedModel;
import org.mule.runtime.ast.api.ComponentAst;
import org.mule.runtime.ast.api.ComponentParameterAst;
import org.mule.runtime.ast.api.util.AstTraversalDirection;
import org.mule.runtime.ast.api.util.BaseComponentAst;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.hamcrest.Matcher;
import org.junit.Test;

import io.qameta.allure.Feature;
import io.qameta.allure.Issue;
import io.qameta.allure.Story;

@Feature(ARTIFACT_AST)
@Story(AST_TRAVERSAL)
public class AstTraversalDirectionTestCase {

  private ComponentAst createComponentAst(ComponentAst... innerComponents) {
    ComponentAst component = mock(BaseComponentAst.class);
    when(component.getModel(ParameterizedModel.class)).thenReturn(of(mock(ParameterizedModel.class)));
    when(component.directChildrenStream()).thenAnswer(inv -> Stream.of(innerComponents));
    when(component.recursiveStream()).thenCallRealMethod();
    when(component.recursiveStream(any(AstTraversalDirection.class))).thenAnswer(inv -> {
      AstTraversalDirection direction = inv.getArgument(0);
      return StreamSupport.stream(component.recursiveSpliterator(direction), false);
    });
    when(component.recursiveSpliterator()).thenCallRealMethod();
    when(component.recursiveSpliterator(any(AstTraversalDirection.class))).thenAnswer(inv -> {
      AstTraversalDirection direction = inv.getArgument(0);
      return direction.recursiveSpliterator(component);
    });
    return component;
  }

  private ComponentAst createComponentAstWithComplexParam(ComponentAst... innerComponents) {
    ComponentAst component = mock(BaseComponentAst.class);
    when(component.getModel(ParameterizedModel.class)).thenReturn(of(mock(ParameterizedModel.class)));
    when(component.directChildrenStream()).thenReturn(empty());
    when(component.recursiveStream()).thenCallRealMethod();
    when(component.recursiveStream(any(AstTraversalDirection.class))).thenAnswer(inv -> {
      AstTraversalDirection direction = inv.getArgument(0);
      return StreamSupport.stream(component.recursiveSpliterator(direction), false);
    });
    when(component.recursiveSpliterator()).thenCallRealMethod();
    when(component.recursiveSpliterator(any(AstTraversalDirection.class))).thenAnswer(inv -> {
      AstTraversalDirection direction = inv.getArgument(0);
      return direction.recursiveSpliterator(component);
    });

    ArrayType arrayType = mock(ArrayType.class);
    doAnswer(inv -> {
      inv.getArgument(0, MetadataTypeVisitor.class).visitArrayType((ArrayType) inv.getMock());
      return null;
    }).when(arrayType).accept(any());

    ParameterModel complexParamModel = mock(ParameterModel.class);
    when(complexParamModel.getType()).thenReturn(arrayType);
    when(complexParamModel.getName()).thenReturn("complex");

    ComponentParameterAst param = mock(ComponentParameterAst.class);
    when(param.getModel()).thenReturn(complexParamModel);
    when(param.getValue()).thenReturn(right(asList(innerComponents)));

    when(component.getParameters()).thenReturn(singletonList(param));
    when(component.getParameter(anyString(), anyString())).thenReturn(param);

    return component;
  }

  @Test
  public void singleLevel() {
    ComponentAst child1 = createComponentAst();
    ComponentAst child2 = createComponentAst();

    ComponentAst root = createComponentAst(child1, child2);

    final List<ComponentAst> visitedComponents = new ArrayList<>();
    root.recursiveStream().forEach(visitedComponents::add);

    assertThat(visitedComponents, contains(sameInstance(root), sameInstance(child1), sameInstance(child2)));

    visitedComponents.clear();

    root.recursiveStream(TOP_DOWN).forEach(visitedComponents::add);

    assertThat(visitedComponents, contains(sameInstance(root), sameInstance(child1), sameInstance(child2)));

    visitedComponents.clear();

    root.recursiveStream(BOTTOM_UP).forEach(visitedComponents::add);

    assertThat(visitedComponents, contains(sameInstance(child1), sameInstance(child2), sameInstance(root)));
  }

  @Test
  public void twoLevels() {
    ComponentAst child11 = createComponentAst();
    ComponentAst child12 = createComponentAst();
    ComponentAst child1 = createComponentAst(child11, child12);

    ComponentAst child21 = createComponentAst();
    ComponentAst child22 = createComponentAst();
    ComponentAst child2 = createComponentAst(child21, child22);

    ComponentAst root = createComponentAst(child1, child2);

    final List<ComponentAst> visitedComponents = new ArrayList<>();
    root.recursiveStream().forEach(visitedComponents::add);

    assertThat(visitedComponents, contains(sameInstance(root), sameInstance(child1), sameInstance(child11), sameInstance(child12),
                                           sameInstance(child2), sameInstance(child21), sameInstance(child22)));

    visitedComponents.clear();

    root.recursiveStream(TOP_DOWN).forEach(visitedComponents::add);

    assertThat(visitedComponents, contains(sameInstance(root), sameInstance(child1), sameInstance(child11), sameInstance(child12),
                                           sameInstance(child2), sameInstance(child21), sameInstance(child22)));

    visitedComponents.clear();

    root.recursiveStream(BOTTOM_UP).forEach(visitedComponents::add);

    assertThat(visitedComponents, contains(sameInstance(child11), sameInstance(child12), sameInstance(child1),
                                           sameInstance(child21), sameInstance(child22), sameInstance(child2),
                                           sameInstance(root)));
  }

  @Test
  public void multipleLevels() {
    ComponentAst child11 = createComponentAst();

    ComponentAst child121 = createComponentAst();

    ComponentAst child1221 = createComponentAst();
    ComponentAst child1222 = createComponentAst();
    ComponentAst child1223 = createComponentAst();
    ComponentAst child122 = createComponentAst(child1221, child1222, child1223);

    ComponentAst child12 = createComponentAst(child121, child122);

    ComponentAst child1 = createComponentAst(child11, child12);

    ComponentAst child21 = createComponentAst();

    ComponentAst child2 = createComponentAst(child21);

    ComponentAst child3 = createComponentAst();

    ComponentAst root = createComponentAst(child1, child2, child3);

    final List<Object> visitedComponents = new ArrayList<>();

    root.recursiveStream().forEach(visitedComponents::add);

    assertThat(visitedComponents, contains(sameInstance(root), sameInstance(child1), sameInstance(child11), sameInstance(child12),
                                           sameInstance(child121), sameInstance(child122),
                                           sameInstance(child1221), sameInstance(child1222), sameInstance(child1223),
                                           sameInstance(child2), sameInstance(child21), sameInstance(child3)));

    visitedComponents.clear();

    root.recursiveStream(TOP_DOWN).forEach(visitedComponents::add);

    assertThat(visitedComponents, contains(sameInstance(root), sameInstance(child1), sameInstance(child11), sameInstance(child12),
                                           sameInstance(child121), sameInstance(child122),
                                           sameInstance(child1221), sameInstance(child1222), sameInstance(child1223),
                                           sameInstance(child2), sameInstance(child21), sameInstance(child3)));

    visitedComponents.clear();

    root.recursiveStream(BOTTOM_UP).forEach(visitedComponents::add);

    assertThat(visitedComponents, contains(sameInstance(child11),
                                           sameInstance(child121),
                                           sameInstance(child1221), sameInstance(child1222), sameInstance(child1223),
                                           sameInstance(child122),
                                           sameInstance(child12),
                                           sameInstance(child1),
                                           sameInstance(child21), sameInstance(child2),
                                           sameInstance(child3), sameInstance(root)));
  }

  @Test
  @Issue("MULE-19636")
  public void recursiveWithHierarchy() {
    ComponentAst child1 = createComponentAst();
    ComponentAst child2 = createComponentAst();

    ComponentAst root = createComponentAstWithComplexParam(child1, child2);

    final List<ComponentAst> visitedComponents = new ArrayList<>();

    TOP_DOWN.recursiveStreamWithHierarchy(Stream.of(root), false)
        .forEach(comp -> visitedComponents.add(comp.getFirst()));

    assertThat(visitedComponents, contains(sameInstance(root), sameInstance(child1), sameInstance(child2)));
  }

  @Test
  @Issue("MULE-19636")
  public void recursiveWithHierarchyIgnoreComplex() {
    ComponentAst child1 = createComponentAst();
    ComponentAst child2 = createComponentAst();

    ComponentAst root = createComponentAstWithComplexParam(child1, child2);

    final List<ComponentAst> visitedComponents = new ArrayList<>();

    TOP_DOWN.recursiveStreamWithHierarchy(Stream.of(root), true)
        .forEach(comp -> visitedComponents.add(comp.getFirst()));

    assertThat(visitedComponents, contains(sameInstance(root)));
  }
}
