/*
 * Decompiled with CFR 0.152.
 */
package com.boozallen.aiops.mda.generator;

import com.boozallen.aiops.mda.generator.TargetedPipelinePyProjectGenerator;
import com.boozallen.aiops.mda.generator.util.PipelineUtils;
import com.boozallen.aiops.mda.metamodel.element.Pipeline;
import com.boozallen.aiops.mda.metamodel.element.PostAction;
import com.boozallen.aiops.mda.metamodel.element.Step;
import com.boozallen.aiops.mda.metamodel.element.python.MachineLearningPipeline;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.velocity.VelocityContext;
import org.technologybrewery.fermenter.mda.generator.GenerationContext;

/**
 * Python project generator targeted at the training step of a machine-learning pipeline.
 *
 * <p>Besides rendering the template, it collects extra dependency declarations required by
 * the training step's post actions. It exposes them to the template and records a manual-action
 * notice so the user knows to add them.
 */
public class MlTrainingPyProjectGenerator extends TargetedPipelinePyProjectGenerator {

    /** Dependency line contributed by any post action that performs ONNX model conversion. */
    private static final String ONNX_ML_TOOLS_DEPENDENCY = "onnxmltools = \"^1.11.1\"";

    /** Additional dependency line contributed when an ONNX-converting post action's model source is {@code keras}. */
    private static final String ONNX_KERAS_DEPENDENCY = "tf2onnx = \"^1.12.1\"";

    /**
     * Wraps the pipeline as a {@link MachineLearningPipeline}, publishes it to the template
     * context and renders the file.
     *
     * <p>When the training step declares post actions, their dependency lines are put into the
     * context under {@code postActionRequirements}. After generation, a notice asking the user to
     * add those Python dependencies is registered if the set is non-empty.
     *
     * @param generationContext the current generation context
     * @param velocityContext   the template context to populate
     * @param pipeline          the pipeline being generated
     */
    @Override
    protected void doGenerateFile(GenerationContext generationContext, VelocityContext velocityContext, Pipeline pipeline) {
        MachineLearningPipeline machineLearningPipeline = new MachineLearningPipeline(pipeline);
        Step training = machineLearningPipeline.getTrainingStep();
        velocityContext.put("pipeline", machineLearningPipeline);

        List<PostAction> trainingPostActions = training.getPostActions();
        // Stays null when there are no post actions; CollectionUtils treats null as empty below.
        Set<String> requirements = null;
        boolean hasPostActions = CollectionUtils.isNotEmpty(trainingPostActions);
        if (hasPostActions) {
            requirements = getPostActionDependencies(trainingPostActions);
            velocityContext.put("postActionRequirements", requirements);
        }

        generateFile(generationContext, velocityContext);

        boolean needsNotice = CollectionUtils.isNotEmpty(requirements);
        if (needsNotice) {
            manualActionNotificationService.addNoticeToAddPythonDependencies(
                    generationContext, requirements, "post action support");
        }
    }

    /**
     * Collects the dependency lines needed by the given post actions.
     *
     * <p>Every post action for which {@code PipelineUtils.forOnnxModelConversion} is true
     * contributes {@code onnxmltools}. Such an action whose model source is exactly
     * {@code "keras"} also contributes {@code tf2onnx}. The result keeps first-insertion order
     * and contains no duplicates.
     *
     * @param postActions the post actions to inspect; iterated directly, so it must not be null
     * @return the (possibly empty) set of dependency lines
     */
    protected Set<String> getPostActionDependencies(List<PostAction> postActions) {
        Set<String> requirements = new LinkedHashSet<>();
        for (PostAction action : postActions) {
            if (PipelineUtils.forOnnxModelConversion(action)) {
                requirements.add(ONNX_ML_TOOLS_DEPENDENCY);
                if ("keras".equals(action.getModelSource())) {
                    requirements.add(ONNX_KERAS_DEPENDENCY);
                }
            }
        }
        return requirements;
    }
}

