001    /*
002     * Copyright 2010-2014 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.codegen.optimization.boxing;
018    
019    import com.google.common.base.Predicate;
020    import com.google.common.collect.Collections2;
021    import com.intellij.openapi.util.Pair;
022    import org.jetbrains.annotations.NotNull;
023    import org.jetbrains.jet.codegen.optimization.OptimizationUtils;
024    import org.jetbrains.jet.codegen.optimization.transformer.MethodTransformer;
025    import org.jetbrains.org.objectweb.asm.Opcodes;
026    import org.jetbrains.org.objectweb.asm.Type;
027    import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter;
028    import org.jetbrains.org.objectweb.asm.tree.*;
029    import org.jetbrains.org.objectweb.asm.tree.analysis.BasicValue;
030    import org.jetbrains.org.objectweb.asm.tree.analysis.Frame;
031    
032    import java.util.*;
033    
034    public class RedundantBoxingMethodTransformer extends MethodTransformer {
035        public RedundantBoxingMethodTransformer(MethodTransformer methodTransformer) {
036            super(methodTransformer);
037        }
038    
039        @Override
040        public void transform(@NotNull String internalClassName, @NotNull MethodNode node) {
041            RedundantBoxingInterpreter interpreter = new RedundantBoxingInterpreter(node.instructions);
042            Frame<BasicValue>[] frames = analyze(
043                    internalClassName, node, interpreter
044            );
045            interpretPopInstructionsForBoxedValues(interpreter, node, frames);
046    
047            RedundantBoxedValuesCollection valuesToOptimize = interpreter.getCandidatesBoxedValues();
048    
049            if (!valuesToOptimize.isEmpty()) {
050                // has side effect on valuesToOptimize and frames, containing BoxedBasicValues that are unsafe to remove
051                removeValuesClashingWithVariables(valuesToOptimize, node, frames);
052    
053                adaptLocalVariableTableForBoxedValues(node, frames);
054    
055                applyVariablesRemapping(node, buildVariablesRemapping(valuesToOptimize, node));
056    
057                adaptInstructionsForBoxedValues(node, valuesToOptimize);
058            }
059    
060            super.transform(internalClassName, node);
061        }
062    
063        private static void interpretPopInstructionsForBoxedValues(
064                @NotNull RedundantBoxingInterpreter interpreter,
065                @NotNull MethodNode node,
066                @NotNull Frame<BasicValue>[] frames
067        ) {
068            for (int i = 0; i < node.instructions.size(); i++) {
069                AbstractInsnNode insn = node.instructions.get(i);
070                if ((insn.getOpcode() != Opcodes.POP && insn.getOpcode() != Opcodes.POP2) || frames[i] == null) {
071                    continue;
072                }
073    
074                BasicValue top = frames[i].getStack(frames[i].getStackSize() - 1);
075                interpreter.processPopInstruction(insn, top);
076    
077                if (top.getSize() == 1 && insn.getOpcode() == Opcodes.POP2) {
078                    interpreter.processPopInstruction(insn, frames[i].getStack(frames[i].getStackSize() - 2));
079                }
080            }
081        }
082    
083        private static void removeValuesClashingWithVariables(
084                @NotNull RedundantBoxedValuesCollection values,
085                @NotNull MethodNode node,
086                @NotNull Frame<BasicValue>[] frames
087        ) {
088            while (removeValuesClashingWithVariablesPass(values, node, frames)) {
089                // do nothing
090            }
091        }
092    
093        private static boolean removeValuesClashingWithVariablesPass(
094                @NotNull RedundantBoxedValuesCollection values,
095                @NotNull MethodNode node,
096                @NotNull Frame<BasicValue>[] frames
097        ) {
098            boolean needToRepeat = false;
099    
100            for (LocalVariableNode localVariableNode : node.localVariables) {
101                if (Type.getType(localVariableNode.desc).getSort() != Type.OBJECT) {
102                    continue;
103                }
104    
105                List<BasicValue> usedValues = getValuesStoredOrLoadedToVariable(localVariableNode, node, frames);
106    
107                Collection<BasicValue> boxed = Collections2.filter(usedValues, new Predicate<BasicValue>() {
108                    @Override
109                    public boolean apply(BasicValue input) {
110                        return input instanceof BoxedBasicValue;
111                    }
112                });
113    
114                if (boxed.isEmpty()) continue;
115    
116                final BoxedBasicValue firstBoxed = (BoxedBasicValue) boxed.iterator().next();
117    
118                if (!Collections2.filter(usedValues, new Predicate<BasicValue>() {
119                    @Override
120                    public boolean apply(BasicValue input) {
121                        return input == null ||
122                               !(input instanceof BoxedBasicValue) ||
123                               !((BoxedBasicValue) input).isSafeToRemove() ||
124                               !((BoxedBasicValue) input).getPrimitiveType().equals(firstBoxed.getPrimitiveType());
125                    }
126                }).isEmpty()) {
127                    for (BasicValue value : usedValues) {
128                        if (value instanceof BoxedBasicValue && ((BoxedBasicValue) value).isSafeToRemove()) {
129                            values.remove((BoxedBasicValue) value);
130                            needToRepeat = true;
131                        }
132                    }
133                }
134            }
135    
136            return needToRepeat;
137        }
138    
139        private static void adaptLocalVariableTableForBoxedValues(@NotNull MethodNode node, @NotNull Frame<BasicValue>[] frames) {
140            for (LocalVariableNode localVariableNode : node.localVariables) {
141                if (Type.getType(localVariableNode.desc).getSort() != Type.OBJECT) {
142                    continue;
143                }
144    
145                for (BasicValue value : getValuesStoredOrLoadedToVariable(localVariableNode, node, frames)) {
146                    if (value == null || !(value instanceof BoxedBasicValue) || !((BoxedBasicValue) value).isSafeToRemove()) continue;
147                    localVariableNode.desc = ((BoxedBasicValue) value).getPrimitiveType().getDescriptor();
148                }
149            }
150        }
151    
152        @NotNull
153        private static List<BasicValue> getValuesStoredOrLoadedToVariable(
154                @NotNull LocalVariableNode localVariableNode,
155                @NotNull MethodNode node,
156                @NotNull Frame<BasicValue>[] frames
157        ) {
158            List<BasicValue> values = new ArrayList<BasicValue>();
159            InsnList insnList = node.instructions;
160            int from = insnList.indexOf(localVariableNode.start) + 1;
161            int to = insnList.indexOf(localVariableNode.end) - 1;
162    
163            for (int i = from; i <= to; i++) {
164                if (i < 0 || i >= insnList.size()) continue;
165    
166                AbstractInsnNode insn = insnList.get(i);
167                if ((insn.getOpcode() == Opcodes.ASTORE || insn.getOpcode() == Opcodes.ALOAD) &&
168                    ((VarInsnNode) insn).var == localVariableNode.index) {
169    
170                    // frames[i] can be null in case of exception handlers
171                    if (frames[i] == null) {
172                        values.add(null);
173                        continue;
174                    }
175    
176                    if (insn.getOpcode() == Opcodes.ASTORE) {
177                        values.add(frames[i].getStack(frames[i].getStackSize() - 1));
178                    }
179                    else {
180                        values.add(frames[i].getLocal(((VarInsnNode) insn).var));
181                    }
182                }
183            }
184    
185            return values;
186        }
187    
188        @NotNull
189        private static int[] buildVariablesRemapping(@NotNull RedundantBoxedValuesCollection values, @NotNull MethodNode node) {
190            Set<Integer> doubleSizedVars = new HashSet<Integer>();
191            for (BoxedBasicValue value : values) {
192                if (value.getPrimitiveType().getSize() == 2) {
193                    doubleSizedVars.addAll(value.getVariablesIndexes());
194                }
195            }
196    
197            node.maxLocals += doubleSizedVars.size();
198            int[] remapping = new int[node.maxLocals];
199            for (int i = 0; i < remapping.length; i++) {
200                remapping[i] = i;
201            }
202    
203            for (int varIndex : doubleSizedVars) {
204                for (int i = varIndex + 1; i < remapping.length; i++) {
205                    remapping[i]++;
206                }
207            }
208    
209            return remapping;
210        }
211    
212        private static void applyVariablesRemapping(@NotNull MethodNode node, @NotNull int[] remapping) {
213            for (AbstractInsnNode insn : node.instructions.toArray()) {
214                if (insn instanceof VarInsnNode) {
215                    ((VarInsnNode) insn).var = remapping[((VarInsnNode) insn).var];
216                }
217                if (insn instanceof IincInsnNode) {
218                    ((IincInsnNode) insn).var = remapping[((IincInsnNode) insn).var];
219                }
220            }
221    
222            for (LocalVariableNode localVariableNode : node.localVariables) {
223                localVariableNode.index = remapping[localVariableNode.index];
224            }
225        }
226    
227        private static void adaptInstructionsForBoxedValues(
228                @NotNull MethodNode node,
229                @NotNull RedundantBoxedValuesCollection values
230        ) {
231            for (BoxedBasicValue value : values) {
232                adaptInstructionsForBoxedValue(node, value);
233            }
234        }
235    
236        private static void adaptInstructionsForBoxedValue(@NotNull MethodNode node, @NotNull BoxedBasicValue value) {
237            adaptBoxingInstruction(node, value);
238    
239            for (Pair<AbstractInsnNode, Type> cast : value.getUnboxingWithCastInsns()) {
240                adaptCastInstruction(node, value, cast);
241            }
242    
243            for (AbstractInsnNode insn : value.getAssociatedInsns()) {
244                adaptInstruction(node, insn, value);
245            }
246        }
247    
248        private static void adaptBoxingInstruction(@NotNull MethodNode node, @NotNull BoxedBasicValue value) {
249            if (!value.isFromProgressionIterator()) {
250                node.instructions.remove(value.getBoxingInsn());
251            }
252            else {
253                ProgressionIteratorBasicValue iterator = value.getProgressionIterator();
254                assert iterator != null : "iterator should not be null because isFromProgressionIterator returns true";
255    
256                //add checkcast to kotlin/<T>Iterator before next() call
257                node.instructions.insertBefore(
258                        value.getBoxingInsn(),
259                        new TypeInsnNode(Opcodes.CHECKCAST, iterator.getType().getInternalName())
260                );
261    
262                //invoke concrete method (kotlin/<T>iteraror.next<T>())
263                node.instructions.set(
264                        value.getBoxingInsn(),
265                        new MethodInsnNode(
266                                Opcodes.INVOKEVIRTUAL,
267                                iterator.getType().getInternalName(),
268                                iterator.getNextMethodName(),
269                                iterator.getNextMethodDesc(),
270                                false
271                        )
272                );
273            }
274        }
275    
276        private static void adaptCastInstruction(
277                @NotNull MethodNode node,
278                @NotNull BoxedBasicValue value,
279                @NotNull Pair<AbstractInsnNode, Type> castWithType
280        ) {
281            AbstractInsnNode castInsn = castWithType.getFirst();
282            MethodNode castInsnsListener = new MethodNode(OptimizationUtils.API);
283            new InstructionAdapter(castInsnsListener).cast(value.getPrimitiveType(), castWithType.getSecond());
284    
285            for (AbstractInsnNode insn : castInsnsListener.instructions.toArray()) {
286                node.instructions.insertBefore(castInsn, insn);
287            }
288    
289            node.instructions.remove(castInsn);
290        }
291    
292        private static void adaptInstruction(
293                @NotNull MethodNode node, @NotNull AbstractInsnNode insn, @NotNull BoxedBasicValue value
294        ) {
295            boolean isDoubleSize = value.isDoubleSize();
296    
297            switch (insn.getOpcode()) {
298                case Opcodes.POP:
299                    if (isDoubleSize) {
300                        node.instructions.set(
301                                insn,
302                                new InsnNode(Opcodes.POP2)
303                        );
304                    }
305                    break;
306                case Opcodes.DUP:
307                    if (isDoubleSize) {
308                        node.instructions.set(
309                                insn,
310                                new InsnNode(Opcodes.DUP2)
311                        );
312                    }
313                    break;
314                case Opcodes.ASTORE:
315                case Opcodes.ALOAD:
316                    int intVarOpcode = insn.getOpcode() == Opcodes.ASTORE ? Opcodes.ISTORE : Opcodes.ILOAD;
317                    node.instructions.set(
318                            insn,
319                            new VarInsnNode(
320                                    value.getPrimitiveType().getOpcode(intVarOpcode),
321                                    ((VarInsnNode) insn).var
322                            )
323                    );
324                    break;
325                case Opcodes.INSTANCEOF:
326                    node.instructions.insertBefore(
327                            insn,
328                            new InsnNode(isDoubleSize ? Opcodes.POP2 : Opcodes.POP)
329                    );
330                    node.instructions.set(insn, new InsnNode(Opcodes.ICONST_1));
331                    break;
332                default:
333                    // CHECKCAST or unboxing-method call
334                    node.instructions.remove(insn);
335            }
336        }
337    }