001    /*
002     * Copyright 2010-2015 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.kotlin.codegen.optimization.boxing;
018    
019    import com.google.common.base.Predicate;
020    import com.google.common.collect.Collections2;
021    import com.intellij.openapi.util.Pair;
022    import org.jetbrains.annotations.NotNull;
023    import org.jetbrains.kotlin.codegen.optimization.transformer.MethodTransformer;
024    import org.jetbrains.org.objectweb.asm.Opcodes;
025    import org.jetbrains.org.objectweb.asm.Type;
026    import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter;
027    import org.jetbrains.org.objectweb.asm.tree.*;
028    import org.jetbrains.org.objectweb.asm.tree.analysis.BasicValue;
029    import org.jetbrains.org.objectweb.asm.tree.analysis.Frame;
030    
031    import java.util.*;
032    
033    public class RedundantBoxingMethodTransformer extends MethodTransformer {
034    
035        @Override
036        public void transform(@NotNull String internalClassName, @NotNull MethodNode node) {
037            RedundantBoxingInterpreter interpreter = new RedundantBoxingInterpreter(node.instructions);
038            Frame<BasicValue>[] frames = analyze(
039                    internalClassName, node, interpreter
040            );
041            interpretPopInstructionsForBoxedValues(interpreter, node, frames);
042    
043            RedundantBoxedValuesCollection valuesToOptimize = interpreter.getCandidatesBoxedValues();
044    
045            if (!valuesToOptimize.isEmpty()) {
046                // has side effect on valuesToOptimize and frames, containing BoxedBasicValues that are unsafe to remove
047                removeValuesClashingWithVariables(valuesToOptimize, node, frames);
048    
049                adaptLocalVariableTableForBoxedValues(node, frames);
050    
051                applyVariablesRemapping(node, buildVariablesRemapping(valuesToOptimize, node));
052    
053                adaptInstructionsForBoxedValues(node, valuesToOptimize);
054            }
055        }
056    
057        private static void interpretPopInstructionsForBoxedValues(
058                @NotNull RedundantBoxingInterpreter interpreter,
059                @NotNull MethodNode node,
060                @NotNull Frame<BasicValue>[] frames
061        ) {
062            for (int i = 0; i < node.instructions.size(); i++) {
063                AbstractInsnNode insn = node.instructions.get(i);
064                if ((insn.getOpcode() != Opcodes.POP && insn.getOpcode() != Opcodes.POP2) || frames[i] == null) {
065                    continue;
066                }
067    
068                BasicValue top = frames[i].getStack(frames[i].getStackSize() - 1);
069                interpreter.processPopInstruction(insn, top);
070    
071                if (top.getSize() == 1 && insn.getOpcode() == Opcodes.POP2) {
072                    interpreter.processPopInstruction(insn, frames[i].getStack(frames[i].getStackSize() - 2));
073                }
074            }
075        }
076    
077        private static void removeValuesClashingWithVariables(
078                @NotNull RedundantBoxedValuesCollection values,
079                @NotNull MethodNode node,
080                @NotNull Frame<BasicValue>[] frames
081        ) {
082            while (removeValuesClashingWithVariablesPass(values, node, frames)) {
083                // do nothing
084            }
085        }
086    
087        private static boolean removeValuesClashingWithVariablesPass(
088                @NotNull RedundantBoxedValuesCollection values,
089                @NotNull MethodNode node,
090                @NotNull Frame<BasicValue>[] frames
091        ) {
092            boolean needToRepeat = false;
093    
094            for (LocalVariableNode localVariableNode : node.localVariables) {
095                if (Type.getType(localVariableNode.desc).getSort() != Type.OBJECT) {
096                    continue;
097                }
098    
099                List<BasicValue> usedValues = getValuesStoredOrLoadedToVariable(localVariableNode, node, frames);
100    
101                Collection<BasicValue> boxed = Collections2.filter(usedValues, new Predicate<BasicValue>() {
102                    @Override
103                    public boolean apply(BasicValue input) {
104                        return input instanceof BoxedBasicValue;
105                    }
106                });
107    
108                if (boxed.isEmpty()) continue;
109    
110                final BoxedBasicValue firstBoxed = (BoxedBasicValue) boxed.iterator().next();
111    
112                if (!Collections2.filter(usedValues, new Predicate<BasicValue>() {
113                    @Override
114                    public boolean apply(BasicValue input) {
115                        return input == null ||
116                               !(input instanceof BoxedBasicValue) ||
117                               !((BoxedBasicValue) input).isSafeToRemove() ||
118                               !((BoxedBasicValue) input).getPrimitiveType().equals(firstBoxed.getPrimitiveType());
119                    }
120                }).isEmpty()) {
121                    for (BasicValue value : usedValues) {
122                        if (value instanceof BoxedBasicValue && ((BoxedBasicValue) value).isSafeToRemove()) {
123                            values.remove((BoxedBasicValue) value);
124                            needToRepeat = true;
125                        }
126                    }
127                }
128            }
129    
130            return needToRepeat;
131        }
132    
133        private static void adaptLocalVariableTableForBoxedValues(@NotNull MethodNode node, @NotNull Frame<BasicValue>[] frames) {
134            for (LocalVariableNode localVariableNode : node.localVariables) {
135                if (Type.getType(localVariableNode.desc).getSort() != Type.OBJECT) {
136                    continue;
137                }
138    
139                for (BasicValue value : getValuesStoredOrLoadedToVariable(localVariableNode, node, frames)) {
140                    if (value == null || !(value instanceof BoxedBasicValue) || !((BoxedBasicValue) value).isSafeToRemove()) continue;
141                    localVariableNode.desc = ((BoxedBasicValue) value).getPrimitiveType().getDescriptor();
142                }
143            }
144        }
145    
146        @NotNull
147        private static List<BasicValue> getValuesStoredOrLoadedToVariable(
148                @NotNull LocalVariableNode localVariableNode,
149                @NotNull MethodNode node,
150                @NotNull Frame<BasicValue>[] frames
151        ) {
152            List<BasicValue> values = new ArrayList<BasicValue>();
153            InsnList insnList = node.instructions;
154            int from = insnList.indexOf(localVariableNode.start) + 1;
155            int to = insnList.indexOf(localVariableNode.end) - 1;
156    
157            Frame<BasicValue> frameForFromInstr = frames[from];
158            if (frameForFromInstr != null) {
159                BasicValue localVarValue = frameForFromInstr.getLocal(localVariableNode.index);
160                if (localVarValue != null) {
161                    values.add(localVarValue);
162                }
163            }
164    
165            for (int i = from; i <= to; i++) {
166                if (i < 0 || i >= insnList.size()) continue;
167    
168                AbstractInsnNode insn = insnList.get(i);
169                if ((insn.getOpcode() == Opcodes.ASTORE || insn.getOpcode() == Opcodes.ALOAD) &&
170                    ((VarInsnNode) insn).var == localVariableNode.index) {
171    
172                    if (frames[i] == null) {
173                        //unreachable code
174                        continue;
175                    }
176    
177                    if (insn.getOpcode() == Opcodes.ASTORE) {
178                        values.add(frames[i].getStack(frames[i].getStackSize() - 1));
179                    }
180                    else {
181                        values.add(frames[i].getLocal(((VarInsnNode) insn).var));
182                    }
183                }
184            }
185    
186            return values;
187        }
188    
189        @NotNull
190        private static int[] buildVariablesRemapping(@NotNull RedundantBoxedValuesCollection values, @NotNull MethodNode node) {
191            Set<Integer> doubleSizedVars = new HashSet<Integer>();
192            for (BoxedBasicValue value : values) {
193                if (value.getPrimitiveType().getSize() == 2) {
194                    doubleSizedVars.addAll(value.getVariablesIndexes());
195                }
196            }
197    
198            node.maxLocals += doubleSizedVars.size();
199            int[] remapping = new int[node.maxLocals];
200            for (int i = 0; i < remapping.length; i++) {
201                remapping[i] = i;
202            }
203    
204            for (int varIndex : doubleSizedVars) {
205                for (int i = varIndex + 1; i < remapping.length; i++) {
206                    remapping[i]++;
207                }
208            }
209    
210            return remapping;
211        }
212    
213        private static void applyVariablesRemapping(@NotNull MethodNode node, @NotNull int[] remapping) {
214            for (AbstractInsnNode insn : node.instructions.toArray()) {
215                if (insn instanceof VarInsnNode) {
216                    ((VarInsnNode) insn).var = remapping[((VarInsnNode) insn).var];
217                }
218                if (insn instanceof IincInsnNode) {
219                    ((IincInsnNode) insn).var = remapping[((IincInsnNode) insn).var];
220                }
221            }
222    
223            for (LocalVariableNode localVariableNode : node.localVariables) {
224                localVariableNode.index = remapping[localVariableNode.index];
225            }
226        }
227    
228        private static void adaptInstructionsForBoxedValues(
229                @NotNull MethodNode node,
230                @NotNull RedundantBoxedValuesCollection values
231        ) {
232            for (BoxedBasicValue value : values) {
233                adaptInstructionsForBoxedValue(node, value);
234            }
235        }
236    
237        private static void adaptInstructionsForBoxedValue(@NotNull MethodNode node, @NotNull BoxedBasicValue value) {
238            adaptBoxingInstruction(node, value);
239    
240            for (Pair<AbstractInsnNode, Type> cast : value.getUnboxingWithCastInsns()) {
241                adaptCastInstruction(node, value, cast);
242            }
243    
244            for (AbstractInsnNode insn : value.getAssociatedInsns()) {
245                adaptInstruction(node, insn, value);
246            }
247        }
248    
249        private static void adaptBoxingInstruction(@NotNull MethodNode node, @NotNull BoxedBasicValue value) {
250            if (!value.isFromProgressionIterator()) {
251                node.instructions.remove(value.getBoxingInsn());
252            }
253            else {
254                ProgressionIteratorBasicValue iterator = value.getProgressionIterator();
255                assert iterator != null : "iterator should not be null because isFromProgressionIterator returns true";
256    
257                //add checkcast to kotlin/<T>Iterator before next() call
258                node.instructions.insertBefore(
259                        value.getBoxingInsn(),
260                        new TypeInsnNode(Opcodes.CHECKCAST, iterator.getType().getInternalName())
261                );
262    
263                //invoke concrete method (kotlin/<T>iterator.next<T>())
264                node.instructions.set(
265                        value.getBoxingInsn(),
266                        new MethodInsnNode(
267                                Opcodes.INVOKEVIRTUAL,
268                                iterator.getType().getInternalName(),
269                                iterator.getNextMethodName(),
270                                iterator.getNextMethodDesc(),
271                                false
272                        )
273                );
274            }
275        }
276    
277        private static void adaptCastInstruction(
278                @NotNull MethodNode node,
279                @NotNull BoxedBasicValue value,
280                @NotNull Pair<AbstractInsnNode, Type> castWithType
281        ) {
282            AbstractInsnNode castInsn = castWithType.getFirst();
283            MethodNode castInsnsListener = new MethodNode(Opcodes.ASM5);
284            new InstructionAdapter(castInsnsListener).cast(value.getPrimitiveType(), castWithType.getSecond());
285    
286            for (AbstractInsnNode insn : castInsnsListener.instructions.toArray()) {
287                node.instructions.insertBefore(castInsn, insn);
288            }
289    
290            node.instructions.remove(castInsn);
291        }
292    
293        private static void adaptInstruction(
294                @NotNull MethodNode node, @NotNull AbstractInsnNode insn, @NotNull BoxedBasicValue value
295        ) {
296            boolean isDoubleSize = value.isDoubleSize();
297    
298            switch (insn.getOpcode()) {
299                case Opcodes.POP:
300                    if (isDoubleSize) {
301                        node.instructions.set(
302                                insn,
303                                new InsnNode(Opcodes.POP2)
304                        );
305                    }
306                    break;
307                case Opcodes.DUP:
308                    if (isDoubleSize) {
309                        node.instructions.set(
310                                insn,
311                                new InsnNode(Opcodes.DUP2)
312                        );
313                    }
314                    break;
315                case Opcodes.ASTORE:
316                case Opcodes.ALOAD:
317                    int intVarOpcode = insn.getOpcode() == Opcodes.ASTORE ? Opcodes.ISTORE : Opcodes.ILOAD;
318                    node.instructions.set(
319                            insn,
320                            new VarInsnNode(
321                                    value.getPrimitiveType().getOpcode(intVarOpcode),
322                                    ((VarInsnNode) insn).var
323                            )
324                    );
325                    break;
326                case Opcodes.INSTANCEOF:
327                    node.instructions.insertBefore(
328                            insn,
329                            new InsnNode(isDoubleSize ? Opcodes.POP2 : Opcodes.POP)
330                    );
331                    node.instructions.set(insn, new InsnNode(Opcodes.ICONST_1));
332                    break;
333                default:
334                    // CHECKCAST or unboxing-method call
335                    node.instructions.remove(insn);
336            }
337        }
338    }