/*
 * Decompiled with CFR 0.152.
 */
package sootup.interceptors;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import sootup.core.graph.MutableStmtGraph;
import sootup.core.jimple.basic.LValue;
import sootup.core.jimple.basic.Local;
import sootup.core.jimple.basic.Value;
import sootup.core.jimple.common.stmt.AbstractDefinitionStmt;
import sootup.core.jimple.common.stmt.Stmt;
import sootup.core.model.Body;
import sootup.core.transform.BodyInterceptor;
import sootup.core.views.View;

public class LocalSplitter
implements BodyInterceptor {
    public void interceptBody(@Nonnull Body.BodyBuilder builder, @Nonnull View view) {
        MutableStmtGraph graph = builder.getStmtGraph();
        List stmts = graph.getStmts();
        Map<Local, List<Integer>> assignmentsByLocal = this.groupAssignmentsByLocal(stmts);
        HashSet<Local> newLocals = new HashSet<Local>();
        Set locals = builder.getLocals();
        for (Local local : locals) {
            DisjointSetForest<PartialStmt> disjointSet = new DisjointSetForest<PartialStmt>();
            List assignments = assignmentsByLocal.getOrDefault(local, Collections.emptyList()).stream().map(i -> (AbstractDefinitionStmt)stmts.get((int)i)).collect(Collectors.toList());
            if (assignments.size() <= 1) {
                newLocals.add(local);
                continue;
            }
            for (AbstractDefinitionStmt assignment : assignments) {
                PartialStmt defStmt = new PartialStmt((Stmt)assignment, true);
                disjointSet.add(defStmt);
                ArrayDeque stack = new ArrayDeque(graph.getAllSuccessors((Stmt)assignment));
                HashSet<Stmt> visited = new HashSet<Stmt>();
                while (!stack.isEmpty()) {
                    Optional defOpt;
                    Stmt stmt = (Stmt)stack.pop();
                    if (!visited.add(stmt)) continue;
                    if (stmt.getUses().anyMatch(l -> l == local)) {
                        PartialStmt useStmt = new PartialStmt(stmt, false);
                        disjointSet.add(useStmt);
                        disjointSet.union(defStmt, useStmt);
                    }
                    if ((defOpt = stmt.getDef()).isPresent() && defOpt.get() == local) continue;
                    stack.addAll(graph.getAllSuccessors(stmt));
                }
            }
            if (disjointSet.getSetCount() <= 1) {
                newLocals.add(local);
                continue;
            }
            HashMap representativeToNewLocal = new HashMap();
            int[] nextId = new int[]{0};
            Function<PartialStmt, Local> getNewLocal = partialStmt -> representativeToNewLocal.computeIfAbsent(disjointSet.find((PartialStmt)partialStmt), s -> {
                int n;
                Local newLocal;
                do {
                    n = nextId[0];
                    nextId[0] = n + 1;
                } while (locals.contains(newLocal = local.withName(local.getName() + "#" + n)));
                return newLocal;
            });
            for (int i2 = 0; i2 < stmts.size(); ++i2) {
                Local newUseLocal;
                Local newDefLocal;
                Stmt stmt = (Stmt)stmts.get(i2);
                Optional stmtDef = stmt.getDef();
                boolean localIsDef = stmtDef.isPresent() && stmtDef.get() == local;
                boolean localIsUse = stmt.getUses().anyMatch(l -> l == local);
                Stmt oldStmt = stmt;
                if (localIsDef && local != (newDefLocal = getNewLocal.apply(new PartialStmt(oldStmt, true)))) {
                    newLocals.add(newDefLocal);
                    stmt = ((AbstractDefinitionStmt)stmt).withNewDef(newDefLocal);
                }
                if (localIsUse && local != (newUseLocal = getNewLocal.apply(new PartialStmt(oldStmt, false)))) {
                    newLocals.add(newUseLocal);
                    stmt = stmt.withNewUse((Value)local, (Value)newUseLocal);
                }
                if (oldStmt == stmt) continue;
                graph.replaceNode(oldStmt, stmt);
                stmts.set(i2, stmt);
            }
        }
        builder.setLocals(newLocals);
    }

    @Nonnull
    Map<Local, List<Integer>> groupAssignmentsByLocal(List<Stmt> statements) {
        HashMap<Local, List<Integer>> groupings = new HashMap<Local, List<Integer>>();
        for (int i = 0; i < statements.size(); ++i) {
            AbstractDefinitionStmt defStmt;
            LValue leftOp;
            Stmt stmt = statements.get(i);
            if (!(stmt instanceof AbstractDefinitionStmt) || !((leftOp = (defStmt = (AbstractDefinitionStmt)stmt).getLeftOp()) instanceof Local)) continue;
            groupings.computeIfAbsent((Local)leftOp, x -> new ArrayList()).add(i);
        }
        return groupings;
    }

    static class DisjointSetForest<T> {
        @Nonnull
        private final Map<T, T> parent = new HashMap<T, T>();
        @Nonnull
        private final Map<T, Integer> sizes = new HashMap<T, Integer>();

        DisjointSetForest() {
        }

        void add(@Nonnull T node) {
            if (this.parent.containsKey(node)) {
                return;
            }
            this.parent.put(node, node);
            this.sizes.put(node, 1);
        }

        @Nonnull
        T find(T node) {
            T parentNode = this.parent.get(node);
            if (parentNode == null) {
                throw new IllegalArgumentException("The DisjointSetForest does not contain the node.");
            }
            T itNode = node;
            while (parentNode != itNode) {
                T grandparent = this.parent.get(parentNode);
                this.parent.put(itNode, grandparent);
                itNode = grandparent;
                parentNode = this.parent.get(grandparent);
            }
            return itNode;
        }

        void union(@Nonnull T first, @Nonnull T second) {
            T smaller;
            T larger;
            if ((first = this.find(first)) == (second = this.find(second))) {
                return;
            }
            Integer firstSize = this.sizes.get(first);
            Integer secondSize = this.sizes.get(second);
            if (firstSize > secondSize) {
                larger = first;
                smaller = second;
            } else {
                larger = second;
                smaller = first;
            }
            this.parent.put(smaller, larger);
            this.sizes.put(larger, firstSize + secondSize);
            this.sizes.remove(smaller);
        }

        int getSetCount() {
            return this.sizes.size();
        }
    }

    static class PartialStmt {
        @Nonnull
        final Stmt backingStmt;
        final boolean isDef;

        PartialStmt(@Nonnull Stmt backingStmt, boolean isDef) {
            this.backingStmt = backingStmt;
            this.isDef = isDef;
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            PartialStmt that = (PartialStmt)o;
            if (this.isDef != that.isDef) {
                return false;
            }
            return this.backingStmt.equals((Object)that.backingStmt);
        }

        public int hashCode() {
            int result = this.backingStmt.hashCode();
            result = 31 * result + (this.isDef ? 1 : 0);
            return result;
        }
    }
}

