/*
 * Decompiled with CFR 0.152.
 */
package com.powsybl.timeseries.ast;

import com.powsybl.timeseries.ast.AbstractBinaryNodeCalc;
import com.powsybl.timeseries.ast.AbstractSingleChildNodeCalc;
import com.powsybl.timeseries.ast.BinaryMaxCalc;
import com.powsybl.timeseries.ast.BinaryMinCalc;
import com.powsybl.timeseries.ast.BinaryOperation;
import com.powsybl.timeseries.ast.CachedNodeCalc;
import com.powsybl.timeseries.ast.MaxNodeCalc;
import com.powsybl.timeseries.ast.MinNodeCalc;
import com.powsybl.timeseries.ast.NodeCalc;
import com.powsybl.timeseries.ast.NodeCalcDuplicateDetector;
import com.powsybl.timeseries.ast.NodeCalcModifier;
import com.powsybl.timeseries.ast.TimeNodeCalc;
import com.powsybl.timeseries.ast.UnaryOperation;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Set;

/**
 * Rewrites a {@link NodeCalc} tree so that a child referenced by several parents is shared
 * through a {@link CachedNodeCalc}.
 *
 * <p>Duplicates are found with {@link NodeCalcDuplicateDetector#detectDuplicates(NodeCalc)}.
 * Every child reported with more than one parent gets a single replacement node:
 * <ul>
 *   <li>one of its parents that is already a {@code CachedNodeCalc}, if there is one;</li>
 *   <li>otherwise a new {@code CachedNodeCalc} wrapping the child.</li>
 * </ul>
 *
 * <p>The tree is then traversed with this modifier. The replacements are written into the
 * existing nodes through {@code setLeft}, {@code setRight} and {@code setChild}, so the input
 * tree is mutated in place.
 *
 * <p>NOTE(review): each {@code visit} returns the replacement for the visited node, or
 * {@code null} when there is none. These values are presumably what {@link NodeCalcModifier}
 * passes back as {@code left}, {@code right} or {@code child} to the parent's visit. Confirm
 * this against {@code NodeCalcModifier}.
 */
public class NodeCalcCacheCreator extends NodeCalcModifier<Map<NodeCalc, NodeCalc>> {

    /**
     * Inserts cache nodes for every sub-tree that is shared by more than one parent in
     * {@code nodeCalc}.
     *
     * @param nodeCalc root of the tree to rewrite. The tree is modified in place.
     * @return the value produced by the traversal for the root. This may be {@code null}: the
     *         visit methods of this class return {@code null} for nodes that have no cached
     *         replacement.
     */
    public static NodeCalc cacheDuplicated(NodeCalc nodeCalc) {
        return new NodeCalcCacheCreator().createCachedNodes(nodeCalc);
    }

    /**
     * Builds the child -> replacement map and runs the rewriting traversal.
     */
    private NodeCalc createCachedNodes(NodeCalc nodeCalc) {
        // For each child node, the set of its parents as reported by the duplicate detector.
        Map<NodeCalc, Set<NodeCalc>> parents = NodeCalcDuplicateDetector.detectDuplicates(nodeCalc);

        // IdentityHashMap compares keys by reference, so only the exact node instances
        // reported by the detector are replaced.
        Map<NodeCalc, NodeCalc> cachedNodes = new IdentityHashMap<>();
        parents.forEach((child, childParents) -> {
            if (childParents.size() > 1) {
                // Reuse an existing CachedNodeCalc parent when there is one, rather than
                // stacking a second cache layer on top of the same child.
                NodeCalc replacement = childParents.stream()
                        .filter(CachedNodeCalc.class::isInstance)
                        .findFirst()
                        .orElseGet(() -> new CachedNodeCalc(child));
                cachedNodes.put(child, replacement);
            }
        });

        // NOTE(review): 0 is presumably the initial traversal depth; confirm against NodeCalc#accept.
        return nodeCalc.accept(this, cachedNodes, 0);
    }

    @Override
    public NodeCalc visit(BinaryOperation nodeCalc, Map<NodeCalc, NodeCalc> cachedNodes, NodeCalc left, NodeCalc right) {
        return visitBinaryNodeCalc(nodeCalc, cachedNodes, left, right);
    }

    @Override
    public NodeCalc visit(BinaryMaxCalc nodeCalc, Map<NodeCalc, NodeCalc> cachedNodes, NodeCalc left, NodeCalc right) {
        return visitBinaryNodeCalc(nodeCalc, cachedNodes, left, right);
    }

    @Override
    public NodeCalc visit(BinaryMinCalc nodeCalc, Map<NodeCalc, NodeCalc> cachedNodes, NodeCalc left, NodeCalc right) {
        return visitBinaryNodeCalc(nodeCalc, cachedNodes, left, right);
    }

    @Override
    public NodeCalc visit(UnaryOperation nodeCalc, Map<NodeCalc, NodeCalc> cachedNodes, NodeCalc child) {
        return visitSingleChildNodeCalc(nodeCalc, cachedNodes, child);
    }

    @Override
    public NodeCalc visit(MinNodeCalc nodeCalc, Map<NodeCalc, NodeCalc> cachedNodes, NodeCalc child) {
        return visitSingleChildNodeCalc(nodeCalc, cachedNodes, child);
    }

    @Override
    public NodeCalc visit(MaxNodeCalc nodeCalc, Map<NodeCalc, NodeCalc> cachedNodes, NodeCalc child) {
        return visitSingleChildNodeCalc(nodeCalc, cachedNodes, child);
    }

    @Override
    public NodeCalc visit(TimeNodeCalc nodeCalc, Map<NodeCalc, NodeCalc> cachedNodes, NodeCalc child) {
        return visitSingleChildNodeCalc(nodeCalc, cachedNodes, child);
    }

    /**
     * Updates the wrapped child of an existing cache node. The cache node itself is never
     * replaced, so this method always returns {@code null}.
     */
    @Override
    public NodeCalc visit(CachedNodeCalc nodeCalc, Map<NodeCalc, NodeCalc> cachedNodes, NodeCalc child) {
        // Never put a CachedNodeCalc inside another one. The child's replacement can be an
        // existing CachedNodeCalc parent (possibly this very node, see createCachedNodes), and
        // setting that as the child would nest caches or create a self-reference.
        if (child != null && !(child instanceof CachedNodeCalc)) {
            nodeCalc.setChild(child);
        }
        return null;
    }

    /**
     * Applies the replacements computed for the two operands and returns the replacement of
     * {@code nodeCalc} itself, or {@code null} when it has none. A {@code null} operand means
     * "keep the current operand".
     */
    private NodeCalc visitBinaryNodeCalc(AbstractBinaryNodeCalc nodeCalc, Map<NodeCalc, NodeCalc> cachedNodes,
                                         NodeCalc left, NodeCalc right) {
        if (left != null) {
            nodeCalc.setLeft(left);
        }
        if (right != null) {
            nodeCalc.setRight(right);
        }
        return cachedNodes.get(nodeCalc);
    }

    /**
     * Applies the replacement computed for the single child and returns the replacement of
     * {@code nodeCalc} itself, or {@code null} when it has none. A {@code null} child means
     * "keep the current child".
     */
    private NodeCalc visitSingleChildNodeCalc(AbstractSingleChildNodeCalc nodeCalc, Map<NodeCalc, NodeCalc> cachedNodes,
                                              NodeCalc child) {
        if (child != null) {
            nodeCalc.setChild(child);
        }
        return cachedNodes.get(nodeCalc);
    }
}

