/*
 * Decompiled with CFR 0.152.
 */
package com.antgroup.geaflow.plan.optimizer;

import com.antgroup.geaflow.common.errorcode.RuntimeErrors;
import com.antgroup.geaflow.common.exception.GeaflowRuntimeException;
import com.antgroup.geaflow.common.serialize.SerializerFactory;
import com.antgroup.geaflow.operator.base.AbstractOperator;
import com.antgroup.geaflow.operator.impl.window.FilterOperator;
import com.antgroup.geaflow.operator.impl.window.KeySelectorOperator;
import com.antgroup.geaflow.operator.impl.window.MapOperator;
import com.antgroup.geaflow.operator.impl.window.SinkOperator;
import com.antgroup.geaflow.operator.impl.window.UnionOperator;
import com.antgroup.geaflow.partitioner.IPartitioner;
import com.antgroup.geaflow.plan.graph.PipelineEdge;
import com.antgroup.geaflow.plan.graph.PipelineGraph;
import com.antgroup.geaflow.plan.graph.PipelineVertex;
import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class UnionOptimizer
implements Serializable {
    private static final Logger LOGGER = LoggerFactory.getLogger(UnionOptimizer.class);
    private Map<Integer, PipelineVertex> vertexMap;
    private Map<Integer, Set<PipelineEdge>> outputEdges;
    private Map<Integer, Set<PipelineEdge>> inputEdges;
    private Set<PipelineVertex> visited = new HashSet<PipelineVertex>();
    private Set<PipelineVertex> newVertices = new HashSet<PipelineVertex>();
    private Set<PipelineEdge> newEdges = new LinkedHashSet<PipelineEdge>();
    private Map<Integer, PipelineVertex> newVertexMap = new HashMap<Integer, PipelineVertex>();
    private boolean needOptimize;
    private List statelessOperator = new ArrayList<Class>(Arrays.asList(FilterOperator.class, MapOperator.class, UnionOperator.class, KeySelectorOperator.class));
    private PipelineVertex unionVertex;
    private int removeUnionNum = 0;

    public UnionOptimizer(boolean extraOptimizeSink) {
        LOGGER.info("extraOptimizeSink {}", (Object)extraOptimizeSink);
        if (extraOptimizeSink) {
            this.statelessOperator.add(SinkOperator.class);
        }
    }

    private void init(PipelineGraph plan) {
        this.vertexMap = plan.getVertexMap();
        this.outputEdges = plan.getVertexOutputEdges();
        this.inputEdges = plan.getVertexInputEdges();
        this.needOptimize = false;
    }

    public boolean optimizePlan(PipelineGraph plan) {
        this.init(plan);
        if (!this.pushUpPartitionFunction(plan)) {
            return false;
        }
        try {
            plan.getSourceVertices().forEach(this::dfs);
            if (!this.needOptimize) {
                return false;
            }
            this.kahn(plan.getSourceVertices(), 0);
        }
        catch (Exception ex) {
            LOGGER.warn("Unexpected exception happened while optimizing, thus give up", (Throwable)ex);
            return false;
        }
        if (this.newVertices.size() + this.removeUnionNum < plan.getVertexMap().size()) {
            LOGGER.warn(String.format("vertices number %s plus remove union vertices num %s is smaller after optimization %s, this is not right, thus give up", new Object[0]), new Object[]{this.newVertices.size(), this.removeUnionNum, plan.getVertexMap().size()});
            return false;
        }
        plan.setPipelineEdges(this.newEdges);
        plan.setPipelineVertices(this.newVertices);
        return true;
    }

    private boolean pushUpPartitionFunction(PipelineGraph plan) {
        for (PipelineVertex vertex : plan.getPipelineVertices()) {
            PipelineEdge edge;
            if (!(vertex.getOperator() instanceof UnionOperator) || (edge = this.getOutEdgeFromVertex(vertex)) == null || edge.getPartition().getPartitionType().isEnablePushUp()) continue;
            IPartitioner partitionFunction = edge.getPartition();
            edge.setPartitionType(IPartitioner.PartitionType.forward);
            for (PipelineEdge inEdge : plan.getVertexInputEdges().get(vertex.getVertexId())) {
                if (inEdge.getPartitionType() == IPartitioner.PartitionType.key || inEdge.getPartition() != null) {
                    return false;
                }
                inEdge.setPartition(partitionFunction);
                edge.setPartitionType(IPartitioner.PartitionType.key);
            }
        }
        return true;
    }

    private void dfs(PipelineVertex vertex) {
        if (!this.visited.add(vertex)) {
            return;
        }
        LOGGER.debug("visit vertex {}", (Object)vertex);
        if (this.needPushUp(vertex)) {
            this.pushUp(vertex);
            vertex = this.unionVertex;
            this.needOptimize = true;
        } else if (this.unionVertex != null) {
            this.tryRemoveVertex(this.unionVertex);
            this.unionVertex = null;
        }
        for (PipelineEdge executeEdge : this.outputEdges.get(vertex.getVertexId())) {
            PipelineVertex nextVertex = this.vertexMap.get(executeEdge.getTargetId());
            this.dfs(nextVertex);
        }
    }

    private PipelineEdge getOutEdgeFromVertex(PipelineVertex vertex) {
        Iterator<PipelineEdge> it = this.outputEdges.get(vertex.getVertexId()).iterator();
        if (it.hasNext()) {
            return it.next();
        }
        return null;
    }

    private void pushUp(PipelineVertex sVertex) {
        sVertex.setDuplication();
        if (sVertex.equals(this.unionVertex)) {
            return;
        }
        LOGGER.info("pushUp vertex {}", (Object)sVertex);
        PipelineEdge unionOutEdge = this.getOutEdgeFromVertex(this.unionVertex);
        Preconditions.checkNotNull((Object)unionOutEdge);
        PipelineEdge sOutEdge = this.getOutEdgeFromVertex(sVertex);
        this.inputEdges.get(sVertex.getVertexId()).remove(unionOutEdge);
        for (PipelineEdge unionInEdge : this.inputEdges.get(this.unionVertex.getVertexId())) {
            unionInEdge.setTargetId(sVertex.getVertexId());
            this.inputEdges.get(sVertex.getVertexId()).add(unionInEdge);
        }
        if (sOutEdge != null) {
            this.inputEdges.get(sOutEdge.getTargetId()).remove(sOutEdge);
            unionOutEdge.setTargetId(sOutEdge.getTargetId());
            this.inputEdges.get(sOutEdge.getTargetId()).add(unionOutEdge);
        }
        this.inputEdges.get(this.unionVertex.getVertexId()).clear();
        if (sOutEdge != null) {
            sOutEdge.setTargetId(this.unionVertex.getVertexId());
            this.inputEdges.get(this.unionVertex.getVertexId()).add(sOutEdge);
        }
        if (sOutEdge != null && sOutEdge.getPartitionType().equals((Object)IPartitioner.PartitionType.key)) {
            unionOutEdge.setStreamOrdinal(sOutEdge.getStreamOrdinal());
            unionOutEdge.setPartitionType(IPartitioner.PartitionType.key);
            sOutEdge.setPartitionType(IPartitioner.PartitionType.forward);
        }
        this.tryRemoveVertex(sVertex);
    }

    private boolean needPushUp(PipelineVertex vertex) {
        if (this.outputEdges.get(vertex.getVertexId()).size() > 1) {
            return false;
        }
        for (PipelineEdge edge : this.outputEdges.get(vertex.getVertexId())) {
            if (edge.getPartition() == null || edge.getPartition().getPartitionType().isEnablePushUp()) continue;
            return false;
        }
        if (this.unionVertex != null && this.getOutVertexIdSet(this.unionVertex).contains(vertex.getVertexId()) && this.statelessOperator.contains(vertex.getOperator().getClass())) {
            return true;
        }
        if (vertex.getOperator() instanceof UnionOperator) {
            this.unionVertex = vertex;
            return true;
        }
        return false;
    }

    private void tryRemoveVertex(PipelineVertex vertex) {
        LOGGER.debug("try remove vertex {}", (Object)vertex);
        if (vertex.isDuplication() && vertex.getOperator() instanceof UnionOperator && this.outputEdges.get(vertex.getVertexId()).size() == 1) {
            LOGGER.info("remove vertex {}", (Object)vertex);
            ++this.removeUnionNum;
            PipelineEdge outputEdge = this.outputEdges.get(vertex.getVertexId()).iterator().next();
            int targetId = outputEdge.getTargetId();
            this.inputEdges.get(targetId).remove(outputEdge);
            for (PipelineEdge inEdge : this.inputEdges.get(vertex.getVertexId())) {
                inEdge.setTargetId(targetId);
                this.inputEdges.get(targetId).add(inEdge);
                if (outputEdge.getPartitionType() != IPartitioner.PartitionType.key) continue;
                inEdge.setPartitionType(IPartitioner.PartitionType.key);
            }
        }
    }

    private void kahn(List<PipelineVertex> vertices, int id) throws IOException, ClassNotFoundException {
        ArrayDeque<PipelineVertex> toVisitQueue = new ArrayDeque<PipelineVertex>(vertices);
        HashSet<PipelineEdge> visitedEdge = new HashSet<PipelineEdge>();
        ArrayListMultimap oldIdToNewIdMap = ArrayListMultimap.create();
        while (!toVisitQueue.isEmpty()) {
            PipelineVertex vertex = (PipelineVertex)toVisitQueue.poll();
            if (!vertex.isDuplication()) {
                PipelineVertex newVertex = this.cloneVertex(vertex, ++id, 0);
                this.newVertices.add(newVertex);
                this.newVertexMap.put(id, newVertex);
                oldIdToNewIdMap.put((Object)vertex.getVertexId(), (Object)id);
            }
            int index = 0;
            for (PipelineEdge inEdge : this.inputEdges.get(vertex.getVertexId())) {
                PipelineVertex oriSrcVertex = this.vertexMap.get(inEdge.getSrcId());
                int oriSrcId = oriSrcVertex.getVertexId();
                for (Integer srcId : oldIdToNewIdMap.get((Object)oriSrcId)) {
                    if (vertex.isDuplication()) {
                        PipelineVertex newVertex;
                        if (!((newVertex = this.cloneVertex(vertex, ++id, this.newVertexMap.get(srcId).getParallelism())).getOperator() instanceof SinkOperator)) {
                            this.changeOperatorName(newVertex, index++);
                        }
                        this.newVertices.add(newVertex);
                        this.newVertexMap.put(id, newVertex);
                        oldIdToNewIdMap.put((Object)vertex.getVertexId(), (Object)id);
                    }
                    PipelineEdge newEdge = new PipelineEdge((int)srcId, (int)srcId, id, inEdge.getPartition(), inEdge.getStreamOrdinal(), inEdge.getEncoder());
                    if (this.vertexMap.get(oriSrcId).isDuplication()) {
                        IPartitioner partitionFunction = newEdge.getPartition();
                        newEdge.setEdgeName(String.format("union-%d-%s-%s-%s", new Object[]{id, newEdge.getPartitionType(), PipelineEdge.JoinStream.values()[newEdge.getStreamOrdinal()], partitionFunction != null ? partitionFunction.getClass().getSimpleName() : "none"}));
                    }
                    this.newEdges.add(newEdge);
                }
            }
            for (PipelineEdge executeEdge : this.outputEdges.get(vertex.getVertexId())) {
                PipelineVertex nextVertex = this.vertexMap.get(executeEdge.getTargetId());
                visitedEdge.add(executeEdge);
                if (!visitedEdge.containsAll((Collection)this.inputEdges.get(nextVertex.getVertexId()))) continue;
                toVisitQueue.add(nextVertex);
            }
        }
    }

    private byte[] toByteArray(Object obj) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        ObjectOutputStream oos = new ObjectOutputStream(bos);
        oos.writeObject(obj);
        oos.flush();
        byte[] bytes = bos.toByteArray();
        oos.close();
        bos.close();
        return bytes;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private Object toObject(byte[] bytes) throws IOException, ClassNotFoundException {
        ByteArrayInputStream bis = new ByteArrayInputStream(bytes);
        bis.mark(0);
        Object obj = null;
        try (ObjectInputStream ois = new ObjectInputStream(bis);){
            obj = ois.readObject();
        }
        bis.close();
        return obj;
    }

    private PipelineVertex cloneVertex(PipelineVertex vertex, int id, int parallelism) throws IOException, ClassNotFoundException {
        PipelineVertex cloned;
        block3: {
            LOGGER.debug("clone Vertex {}", (Object)vertex);
            try {
                byte[] out = SerializerFactory.getKryoSerializer().serialize((Object)vertex);
                cloned = (PipelineVertex)SerializerFactory.getKryoSerializer().deserialize(out);
            }
            catch (Exception ex) {
                LOGGER.warn("vertex {} kryo fail, try java serde, ex: {}", (Object)vertex, (Object)Arrays.toString(ex.getStackTrace()));
                cloned = (PipelineVertex)this.toObject(this.toByteArray(vertex));
                if (cloned != null) break block3;
                throw new GeaflowRuntimeException(RuntimeErrors.INST.undefinedError(vertex.getVertexString() + " is not Serializable"), (Throwable)ex);
            }
        }
        cloned.setVertexId(id);
        AbstractOperator abstractOperator = (AbstractOperator)cloned.getOperator();
        abstractOperator.getOpArgs().setOpId(id);
        abstractOperator.setFunction(((AbstractOperator)vertex.getOperator()).getFunction());
        if (parallelism != 0) {
            cloned.setParallelism(parallelism);
            abstractOperator.getOpArgs().setParallelism(parallelism);
        }
        return cloned;
    }

    private void changeOperatorName(PipelineVertex vertex, int index) {
        if (StringUtils.isNotEmpty((CharSequence)((AbstractOperator)vertex.getOperator()).getOpArgs().getOpName())) {
            ((AbstractOperator)vertex.getOperator()).getOpArgs().setOpName(String.format("%s-%d", ((AbstractOperator)vertex.getOperator()).getOpArgs().getOpName(), index));
        }
    }

    private Set<Integer> getOutVertexIdSet(PipelineVertex vertex) {
        return this.outputEdges.get(vertex.getVertexId()).stream().map(PipelineEdge::getTargetId).collect(Collectors.toSet());
    }
}

