package com.avioconsulting.mule.opentelemetry.internal.store;

import com.avioconsulting.mule.opentelemetry.api.sdk.SemanticAttributes;
import com.avioconsulting.mule.opentelemetry.api.store.SpanMeta;
import com.avioconsulting.mule.opentelemetry.api.store.TransactionMeta;
import com.avioconsulting.mule.opentelemetry.api.traces.TraceComponent;
import com.avioconsulting.mule.opentelemetry.internal.processor.service.ComponentRegistryService;
import com.avioconsulting.mule.opentelemetry.internal.processor.util.TraceComponentManager;
import com.avioconsulting.mule.opentelemetry.internal.util.ComponentsUtil;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanBuilder;
import io.opentelemetry.api.trace.SpanKind;
import io.opentelemetry.context.Context;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;

import static com.avioconsulting.mule.opentelemetry.api.sdk.SemanticAttributes.*;
import static com.avioconsulting.mule.opentelemetry.internal.util.ComponentsUtil.BATCH_AGGREGATOR;
import static com.avioconsulting.mule.opentelemetry.internal.util.ComponentsUtil.*;
import static com.avioconsulting.mule.opentelemetry.internal.util.BatchHelperUtil.copyBatchTags;
import static com.avioconsulting.mule.opentelemetry.internal.util.OpenTelemetryUtil.tagsToAttributes;

/**
 * Span bookkeeping for a single Mule batch job instance.
 * <p>
 * Holds the batch job root span, one {@link ContainerSpan} per batch step (or
 * on-complete block) keyed by that step's location, and the root
 * {@link ProcessorSpan} of each container keyed by its span name (the step
 * name, or {@code BATCH_ON_COMPLETE_TAG} for the on-complete block). Container
 * spans are created lazily when the first processor inside a step is seen.
 * <p>
 * Containers are published through {@link AtomicReference#compareAndSet} so
 * that exactly one container is kept per step location even when several
 * threads request it at once.
 */
public class BatchTransaction extends AbstractTransaction {

  private static final Logger LOGGER = LoggerFactory.getLogger(BatchTransaction.class);

  // Step / on-complete location -> container span. Each reference is created
  // empty by computeIfAbsent and then filled exactly once via compareAndSet.
  private final ConcurrentHashMap<String, AtomicReference<ContainerSpan>> stepSpans = new ConcurrentHashMap<>();
  // Container span name (step name or BATCH_ON_COMPLETE_TAG) -> root processor
  // span of that container.
  private final Map<String, ProcessorSpan> stepProcessorSpans = new ConcurrentHashMap<>();
  private final Span rootSpan;
  private final Function<String, SpanBuilder> spanBuilderFunction;

  private final ComponentRegistryService componentRegistryService;
  // volatile so hasEnded() observes the write made in endRootSpan() regardless
  // of which thread calls it.
  private volatile boolean rootSpanEnded = false;
  // Filled from the batch trace component in the constructor.
  private final Map<String, String> batchTags = new HashMap<>();
  // Step location -> step name, parsed from MULE_BATCH_JOB_STEPS in the constructor.
  private final Map<String, String> stepLocationNames = new HashMap<>();
  private final Context rootContext;

  public BatchTransaction(String jobInstanceId, String traceId, String batchJobName,
      Span rootSpan, TraceComponent batchTraceComponent, Function<String, SpanBuilder> spanBuilderFunction,
      ComponentRegistryService componentRegistryService) {
    super(jobInstanceId, traceId, batchJobName, batchTraceComponent.getStartTime());
    batchTraceComponent.copyTagsTo(this.batchTags);
    this.rootSpan = rootSpan;
    this.rootContext = rootSpan.storeInContext(Context.current());
    this.spanBuilderFunction = spanBuilderFunction;
    this.componentRegistryService = componentRegistryService;
    tagsToAttributes(batchTraceComponent, rootSpan);
    extractStepLocations(batchTraceComponent);
    setTransactionContext();
  }

  /**
   * Parses the job's list of steps and their locations.
   * <p>
   * The list is read while processing the {@code batch:job} element, and it
   * lets later processors be mapped back to the step that contains them. The
   * {@code MULE_BATCH_JOB_STEPS} tag is a comma-separated list of
   * {@code stepName|stepLocation} entries. Blank entries and entries without a
   * {@code |} separator are skipped.
   *
   * @param batchTraceComponent
   *            {@link TraceComponent} representing the batch job, optionally
   *            tagged with {@code MULE_BATCH_JOB_STEPS}
   */
  private void extractStepLocations(TraceComponent batchTraceComponent) {
    String jobSteps = batchTraceComponent.getTag(MULE_BATCH_JOB_STEPS.getKey());
    if (jobSteps == null) {
      return;
    }
    StringTokenizer tokenizer = new StringTokenizer(jobSteps, ",");
    while (tokenizer.hasMoreTokens()) {
      String step = tokenizer.nextToken().trim();
      int separator = step.indexOf('|');
      if (separator < 0) {
        // Blank or malformed entry; substring(0, -1) would throw here.
        if (!step.isEmpty()) {
          LOGGER.debug("Ignoring malformed batch step entry '{}' (expected name|location)", step);
        }
        continue;
      }
      String stepName = step.substring(0, separator);
      String stepLocation = step.substring(separator + 1);
      stepLocationNames.put(stepLocation, stepName);
    }
  }

  /**
   * Returns the container span (batch step or on-complete block) for a
   * location, creating and publishing it if none exists yet.
   * <p>
   * If two threads race to create the container, only the one whose
   * {@code compareAndSet} succeeds registers its span. The other thread
   * returns the winner's container. Its own freshly started span is left
   * un-ended and discarded.
   *
   * @param location
   *            {@link String} location of the step / on-complete block
   * @param stepName
   *            {@link String} step name; may be {@code null}
   * @param processorTrace
   *            {@link TraceComponent} of the processor that triggered creation
   * @return the newly created or already existing {@link ContainerSpan} for
   *         this location, never {@code null}
   */
  private ContainerSpan addOrGetContainerSpan(final String location, String stepName, TraceComponent processorTrace) {
    AtomicReference<ContainerSpan> containerSpanRef = stepSpans.computeIfAbsent(location,
        s -> new AtomicReference<>());
    ContainerSpan existing = containerSpanRef.get();
    if (existing != null) {
      return existing;
    }
    String name;
    String spanName;
    if (ComponentsUtil.isBatchOnComplete(location, componentRegistryService)) {
      name = BATCH_ON_COMPLETE_TAG;
      spanName = BATCH_ON_COMPLETE_TAG;
    } else {
      name = BATCH_STEP_TAG + ":" + stepName;
      spanName = stepName;
    }
    try (TraceComponent stepTraceComponent = TraceComponentManager.getInstance()
        .createTraceComponent(processorTrace.getTransactionId(), name)) {
      stepTraceComponent
          .withSpanName(spanName)
          .withSpanKind(SpanKind.INTERNAL)
          .withLocation(location)
          .withEventContextId(processorTrace.getEventContextId())
          .withStartTime(processorTrace.getStartTime());
      processorTrace.copyTagsTo(stepTraceComponent, key -> key.startsWith("mule.app.processor"));
      if (stepName != null) {
        stepTraceComponent.addTag(MULE_BATCH_JOB_STEP_NAME.getKey(), stepName);
      }
      SpanBuilder spanBuilder = spanBuilderFunction.apply(name)
          .setParent(rootContext);
      ContainerSpan newContainerSpan = new ContainerSpan(location, spanBuilder.startSpan(),
          stepTraceComponent);
      if (containerSpanRef.compareAndSet(null, newContainerSpan)) {
        stepProcessorSpans.putIfAbsent(stepTraceComponent.getSpanName(),
            newContainerSpan.getRootProcessorSpan());
        return newContainerSpan;
      }
    }
    // Another thread published its container first; use it instead of null.
    return containerSpanRef.get();
  }

  /**
   * Adds a span for a processor that runs inside a batch step or on-complete
   * block.
   * <p>
   * The step container is created on demand. Processors located directly
   * under the step are routed through {@link #processContainerChild}. Any
   * other processor is added to the container, under a synthetic aggregator
   * span when {@code containerPath} ends with {@code /aggregator}.
   *
   * @param containerPath
   *            {@link String} location of the processor's container
   * @param traceComponent
   *            {@link TraceComponent} of the processor
   * @param spanBuilder
   *            {@link SpanBuilder} for the processor span
   * @return the created span
   */
  public SpanMeta addProcessorSpan(String containerPath, TraceComponent traceComponent, SpanBuilder spanBuilder) {
    String stepName = traceComponent.getTag(SemanticAttributes.MULE_BATCH_JOB_STEP_NAME.getKey());
    if (isBatchOnComplete(containerPath, componentRegistryService)) {
      stepName = BATCH_ON_COMPLETE_TAG;
    }
    ProcessorSpan containerProcessorSpan = getStepProcessorSpan(traceComponent, stepName);

    ContainerSpan containerSpan;
    if (containerProcessorSpan == null) {
      containerSpan = addOrGetContainerSpan(containerPath, stepLocationNames.get(containerPath), traceComponent);
      containerProcessorSpan = getStepProcessorSpan(traceComponent, stepName);
      if (containerProcessorSpan == null) {
        // The winning creator may not have registered its root processor span
        // in stepProcessorSpans yet; use the container's own root span.
        containerProcessorSpan = containerSpan.getRootProcessorSpan();
      }
    } else {
      containerSpan = stepSpans.get(containerProcessorSpan.getLocation()).get();
    }

    if (containerProcessorSpan.getLocation().equalsIgnoreCase(containerPath)) {
      return processContainerChild(containerPath, traceComponent, spanBuilder, containerSpan,
          containerProcessorSpan);
    }
    SpanMeta aggrSpan = addAggregatorSpanIfNeeded(containerPath, traceComponent, containerSpan,
        containerProcessorSpan);
    if (aggrSpan != null) {
      traceComponent.withContext(aggrSpan.getContext());
    }
    return containerSpan.addProcessorSpan(containerPath, traceComponent, spanBuilder);
  }

  /**
   * Creates the synthetic {@code BATCH_AGGREGATOR} span the first time a
   * processor inside a step's {@code /aggregator} container is seen.
   *
   * @param containerPath
   *            {@link String} container location of the processor
   * @param traceComponent
   *            {@link TraceComponent} of the processor
   * @param containerSpan
   *            {@link ContainerSpan} of the enclosing step
   * @param containerProcessorSpan
   *            {@link ProcessorSpan} root span of the enclosing step, used as parent
   * @return the new aggregator span, or {@code null} if the path is not an
   *         aggregator or the aggregator span already exists
   */
  private SpanMeta addAggregatorSpanIfNeeded(String containerPath, TraceComponent traceComponent,
      ContainerSpan containerSpan,
      ProcessorSpan containerProcessorSpan) {
    if (!containerPath.endsWith("/aggregator")
        || containerSpan.findSpan(traceComponent.contextScopedPath(containerPath)) != null) {
      return null;
    }
    SpanBuilder aggrSpanBuilder = spanBuilderFunction.apply(BATCH_AGGREGATOR)
        .setParent(containerProcessorSpan.getContext())
        .setSpanKind(SpanKind.INTERNAL)
        .setStartTimestamp(traceComponent.getStartTime());
    try (TraceComponent aggrTraceComponent = TraceComponentManager.getInstance()
        .createTraceComponent(traceComponent.getTransactionId(), BATCH_AGGREGATOR)) {
      aggrTraceComponent
          .withLocation(containerPath)
          .withSpanName(BATCH_AGGREGATOR)
          .withContext(containerProcessorSpan.getContext())
          .withSpanKind(SpanKind.INTERNAL)
          .withStartTime(traceComponent.getStartTime())
          .withEventContextId(traceComponent.getEventContextId())
          .withSiblings(traceComponent.getSiblings());
      copyBatchTags(traceComponent, aggrTraceComponent);
      aggrTraceComponent.addTag(MULE_APP_PROCESSOR_NAMESPACE.getKey(),
          "batch");
      aggrTraceComponent.addTag(MULE_APP_PROCESSOR_NAME.getKey(),
          "aggregator");
      // The aggregator span itself lives in the aggregator's parent container.
      return containerSpan.addProcessorSpan(
          containerPath.substring(0, containerPath.lastIndexOf("/")),
          aggrTraceComponent,
          aggrSpanBuilder);
    }
  }

  /**
   * Adds a span for a processor located directly under a step or on-complete
   * block.
   * <p>
   * In a step, processor 0 starts a new {@code BATCH_STEP_RECORD_TAG} child
   * container at {@code containerPath + "/record"}, and the processor span is
   * parented to it. Later processors are added under that record container.
   * In an on-complete block, the span is added directly to the block.
   *
   * @param containerPath
   *            {@link String} must be a step location
   * @param traceComponent
   *            {@link TraceComponent}
   * @param spanBuilder
   *            {@link SpanBuilder}
   * @param stepSpan
   *            {@link ContainerSpan}
   * @param processorSpan
   *            {@link ProcessorSpan} root span of the step, parent of the record span
   * @return Newly created span for the target component
   */
  private SpanMeta processContainerChild(String containerPath, TraceComponent traceComponent,
      SpanBuilder spanBuilder,
      ContainerSpan stepSpan, ProcessorSpan processorSpan) {
    if (isBatchOnComplete(containerPath, componentRegistryService)) {
      return stepSpan.addProcessorSpan(containerPath, traceComponent, spanBuilder);
    }
    String recordPath = containerPath + "/record";
    if (!traceComponent.getLocation().equalsIgnoreCase(containerPath + "/processors/0")) {
      return stepSpan.addProcessorSpan(recordPath, traceComponent, spanBuilder);
    }
    // First processor of the step: open a new record span.
    try (TraceComponent recordTrace = TraceComponentManager.getInstance()
        .createTraceComponent(traceComponent.getTransactionId(), BATCH_STEP_RECORD_TAG)) {
      recordTrace
          .withLocation(recordPath)
          .withStartTime(traceComponent.getStartTime())
          .withSpanName(BATCH_STEP_RECORD_TAG)
          .withEventContextId(traceComponent.getEventContextId());
      traceComponent.copyTagsTo(recordTrace);
      SpanBuilder record = spanBuilderFunction.apply(recordTrace.getName());

      SpanMeta recordSpanMeta = stepSpan.addChildContainer(recordTrace,
          record.setParent(processorSpan.getContext()));
      return stepSpan.addProcessorSpan(recordTrace.getLocation(), traceComponent,
          spanBuilder.setParent(recordSpanMeta.getContext()));
    }
  }

  /**
   * Looks up the root processor span of the container named
   * {@code containerName} and, if found, merges the tags of
   * {@code traceComponent} into it.
   *
   * @return the processor span, or {@code null} if the name is {@code null}
   *         or no such container has been registered
   */
  private ProcessorSpan getStepProcessorSpan(TraceComponent traceComponent, String containerName) {
    if (containerName == null) {
      return null;
    }
    ProcessorSpan processorSpan = stepProcessorSpans.get(containerName);
    if (processorSpan != null) {
      traceComponent.copyTagsTo(processorSpan.getTags());
    }
    return processorSpan;
  }

  /**
   * Same as {@link #getStepProcessorSpan(TraceComponent, String)}, using the
   * {@code MULE_BATCH_JOB_STEP_NAME} tag of {@code traceComponent} as the name.
   */
  private ProcessorSpan getStepProcessorSpan(TraceComponent traceComponent) {
    String containerName = traceComponent.getTag(SemanticAttributes.MULE_BATCH_JOB_STEP_NAME.getKey());
    return getStepProcessorSpan(traceComponent, containerName);
  }

  /**
   * Ends a span within this batch job.
   * <p>
   * Step and on-complete components end their container span. Any other
   * processor is ended within its step container. If it lives inside the
   * step's aggregator, the synthetic aggregator span may be ended as well.
   *
   * @return the ended span, or {@code null} if no matching step container
   *         could be found
   */
  @Override
  public SpanMeta endProcessorSpan(TraceComponent traceComponent, Consumer<Span> spanUpdater, Instant endTime) {
    if (BATCH_STEP_TAG.equalsIgnoreCase(traceComponent.getName())
        || BATCH_ON_COMPLETE_TAG.equalsIgnoreCase(traceComponent.getName())) {
      return endContainerSpan(traceComponent, stepProcessorSpans.get(traceComponent.getSpanName()));
    }
    String spanName = traceComponent.getTag(SemanticAttributes.MULE_BATCH_JOB_STEP_NAME.getKey());
    if (spanName == null) {
      String locationParent = getLocationParent(traceComponent.getLocation());
      if (isBatchOnComplete(locationParent, componentRegistryService)) {
        spanName = BATCH_ON_COMPLETE_TAG;
      }
    }
    if (spanName == null) {
      return null;
    }
    ProcessorSpan stepProcessorSpan = stepProcessorSpans.get(spanName);
    if (stepProcessorSpan == null) {
      // No container was ever registered for this step; nothing to end.
      return null;
    }
    String stepLocation = stepProcessorSpan.getLocation();
    ContainerSpan containerSpan = stepSpans.get(stepLocation).get();
    SpanMeta spanMeta = containerSpan.endProcessorSpan(traceComponent, spanUpdater, endTime);
    endAggregatorSpanIfNoSiblings(traceComponent, spanUpdater, endTime, stepLocation, containerSpan);
    return spanMeta;
  }

  /**
   * Ends the synthetic {@code BATCH_AGGREGATOR} span of a step. This happens
   * when the processor just ended is located under the step's
   * {@code /aggregator} and the aggregator span reports zero siblings.
   */
  private void endAggregatorSpanIfNoSiblings(TraceComponent traceComponent, Consumer<Span> spanUpdater,
      Instant endTime, String stepLocation, ContainerSpan containerSpan) {
    String aggregatorLocation = stepLocation + "/aggregator";
    String location = traceComponent.getLocation();
    if (location == null || !location.startsWith(aggregatorLocation)) {
      return;
    }
    ProcessorSpan aggrSpan = containerSpan.findSpan(traceComponent.contextScopedPath(aggregatorLocation));
    if (aggrSpan == null || aggrSpan.getSiblings() != 0) {
      return;
    }
    try (TraceComponent aggrTraceComponent = TraceComponentManager.getInstance()
        .createTraceComponent(traceComponent.getTransactionId(), BATCH_AGGREGATOR)) {
      aggrTraceComponent
          .withLocation(aggregatorLocation)
          .withSpanName(BATCH_AGGREGATOR)
          .withSpanKind(SpanKind.INTERNAL)
          .withEventContextId(traceComponent.getEventContextId())
          .withEndTime(traceComponent.getEndTime());
      traceComponent.copyTagsTo(aggrTraceComponent);
      containerSpan.endProcessorSpan(aggrTraceComponent, spanUpdater, endTime);
    }
  }

  /**
   * Ends the container span that owns {@code processorSpan}. The end time and
   * tags of {@code traceComponent} are recorded on the processor span.
   *
   * @return {@code processorSpan}, or {@code null} when it is {@code null}
   */
  private ProcessorSpan endContainerSpan(TraceComponent traceComponent, ProcessorSpan processorSpan) {
    if (processorSpan == null) {
      return null;
    }
    ContainerSpan stepSpan = stepSpans.get(processorSpan.getLocation()).get();
    processorSpan.setEndTime(traceComponent.getEndTime());
    traceComponent.copyTagsTo(processorSpan.getTags());
    stepSpan.getSpan().end(traceComponent.getEndTime());
    return processorSpan;
  }

  @Override
  public Span getTransactionSpan() {
    return rootSpan;
  }

  /**
   * Ends every registered step container, then the job root span.
   * <p>
   * Containers already ended through {@link #endProcessorSpan} are ended
   * again here. The OpenTelemetry API only honours the first {@code end()} on
   * a span.
   * NOTE(review): this still overwrites those processor spans' end time and
   * tags with the root's values. Confirm that is intended.
   */
  @Override
  public void endRootSpan(TraceComponent traceComponent, Consumer<Span> endSpan) {
    stepProcessorSpans.values()
        .forEach(processorSpan -> endContainerSpan(traceComponent, processorSpan));
    super.endRootSpan(traceComponent, endSpan);
    rootSpanEnded = true;
  }

  @Override
  public boolean hasEnded() {
    return rootSpanEnded;
  }

  /**
   * Adds a child container (child transaction) to the step identified by the
   * {@code MULE_BATCH_JOB_STEP_NAME} tag of {@code traceComponent}. This is a
   * logged no-op when that step has no registered container.
   */
  @Override
  public void addChildTransaction(TraceComponent traceComponent, SpanBuilder spanBuilder) {
    ProcessorSpan processorSpan = getStepProcessorSpan(traceComponent);
    if (processorSpan == null) {
      LOGGER.warn("No batch step span found for '{}' at '{}'; child transaction not added",
          traceComponent.getName(), traceComponent.getLocation());
      return;
    }
    ContainerSpan stepSpan = stepSpans.get(processorSpan.getLocation()).get();
    stepSpan.addChildContainer(traceComponent, spanBuilder);
  }

  /**
   * Ends a child container (child transaction) of the step identified by the
   * {@code MULE_BATCH_JOB_STEP_NAME} tag of {@code traceComponent}.
   *
   * @return the ended child's metadata, or {@code null} when the step has no
   *         registered container
   */
  @Override
  public TransactionMeta endChildTransaction(TraceComponent traceComponent, Consumer<Span> endSpan) {
    ProcessorSpan processorSpan = getStepProcessorSpan(traceComponent);
    if (processorSpan == null) {
      LOGGER.warn("No batch step span found for '{}' at '{}'; child transaction not ended",
          traceComponent.getName(), traceComponent.getLocation());
      return null;
    }
    ContainerSpan stepSpan = stepSpans.get(processorSpan.getLocation()).get();
    return stepSpan.endChildContainer(traceComponent, endSpan);
  }

  /**
   * Searches all step containers for a span at {@code location}.
   * <p>
   * Containers still being created (empty reference between
   * {@code computeIfAbsent} and {@code compareAndSet}) are skipped.
   *
   * @return the first matching span, or {@code null}
   */
  @Override
  public ProcessorSpan findSpan(String location) {
    for (AtomicReference<ContainerSpan> containerSpanRef : stepSpans.values()) {
      ContainerSpan containerSpan = containerSpanRef.get();
      if (containerSpan == null) {
        continue;
      }
      ProcessorSpan processorSpan = containerSpan.findSpan(location);
      if (processorSpan != null) {
        return processorSpan;
      }
    }
    return null;
  }

  @Override
  public Span getSpan() {
    return rootSpan;
  }

  @Override
  public Map<String, String> getTags() {
    return batchTags;
  }
}
