001 /*
002 * Copyright 2010-2014 JetBrains s.r.o.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017 package org.jetbrains.jet.codegen.inline;
018
019 import com.google.common.collect.Lists;
020 import com.intellij.util.ArrayUtil;
021 import org.jetbrains.annotations.NotNull;
022 import org.jetbrains.org.objectweb.asm.Label;
023 import org.jetbrains.org.objectweb.asm.MethodVisitor;
024 import org.jetbrains.org.objectweb.asm.Opcodes;
025 import org.jetbrains.org.objectweb.asm.Type;
026 import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter;
027 import org.jetbrains.org.objectweb.asm.commons.Method;
028 import org.jetbrains.org.objectweb.asm.commons.RemappingMethodAdapter;
029 import org.jetbrains.org.objectweb.asm.tree.*;
030 import org.jetbrains.org.objectweb.asm.tree.analysis.*;
031 import org.jetbrains.jet.codegen.ClosureCodegen;
032 import org.jetbrains.jet.codegen.StackValue;
033 import org.jetbrains.jet.codegen.state.JetTypeMapper;
034
035 import java.util.*;
036
037 import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isInvokeOnLambda;
038 import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isLambdaConstructorCall;
039
040 public class MethodInliner {
041
042 private final MethodNode node;
043
044 private final Parameters parameters;
045
046 private final InliningContext inliningContext;
047
048 private final FieldRemapper nodeRemapper;
049
050 private final boolean isSameModule;
051
052 private final String errorPrefix;
053
054 private final JetTypeMapper typeMapper;
055
056 private final List<InvokeCall> invokeCalls = new ArrayList<InvokeCall>();
057
058 //keeps order
059 private final List<ConstructorInvocation> constructorInvocations = new ArrayList<ConstructorInvocation>();
060 //current state
061 private final Map<String, String> currentTypeMapping = new HashMap<String, String>();
062
063 private final InlineResult result;
064
065 /*
066 *
067 * @param node
068 * @param parameters
069 * @param inliningContext
070 * @param lambdaType - in case on lambda 'invoke' inlining
071 */
072 public MethodInliner(
073 @NotNull MethodNode node,
074 @NotNull Parameters parameters,
075 @NotNull InliningContext parent,
076 @NotNull FieldRemapper nodeRemapper,
077 boolean isSameModule,
078 @NotNull String errorPrefix
079 ) {
080 this.node = node;
081 this.parameters = parameters;
082 this.inliningContext = parent;
083 this.nodeRemapper = nodeRemapper;
084 this.isSameModule = isSameModule;
085 this.errorPrefix = errorPrefix;
086 this.typeMapper = parent.state.getTypeMapper();
087 this.result = InlineResult.create();
088 }
089
090
091 public InlineResult doInline(MethodVisitor adapter, LocalVarRemapper remapper) {
092 return doInline(adapter, remapper, true);
093 }
094
095 public InlineResult doInline(
096 MethodVisitor adapter,
097 LocalVarRemapper remapper,
098 boolean remapReturn
099 ) {
100 //analyze body
101 MethodNode transformedNode = markPlacesForInlineAndRemoveInlinable(node);
102
103 transformedNode = doInline(transformedNode);
104 removeClosureAssertions(transformedNode);
105 transformedNode.instructions.resetLabels();
106
107 Label end = new Label();
108 RemapVisitor visitor = new RemapVisitor(adapter, end, remapper, remapReturn, nodeRemapper);
109 try {
110 transformedNode.accept(visitor);
111 }
112 catch (Exception e) {
113 throw wrapException(e, transformedNode, "couldn't inline method call");
114 }
115
116 visitor.visitLabel(end);
117
118 return result;
119 }
120
121 private MethodNode doInline(MethodNode node) {
122
123 final Deque<InvokeCall> currentInvokes = new LinkedList<InvokeCall>(invokeCalls);
124
125 MethodNode resultNode = new MethodNode(node.access, node.name, node.desc, node.signature, null);
126
127 final Iterator<ConstructorInvocation> iterator = constructorInvocations.iterator();
128
129 RemappingMethodAdapter remappingMethodAdapter = new RemappingMethodAdapter(resultNode.access, resultNode.desc, resultNode,
130 new TypeRemapper(currentTypeMapping));
131
132 InlineAdapter inliner = new InlineAdapter(remappingMethodAdapter, parameters.totalSize()) {
133
134 private ConstructorInvocation invocation;
135 @Override
136 public void anew(Type type) {
137 if (isLambdaConstructorCall(type.getInternalName(), "<init>")) {
138 invocation = iterator.next();
139
140 if (invocation.shouldRegenerate()) {
141 //TODO: need poping of type but what to do with local funs???
142 Type newLambdaType = Type.getObjectType(inliningContext.nameGenerator.genLambdaClassName());
143 currentTypeMapping.put(invocation.getOwnerInternalName(), newLambdaType.getInternalName());
144 AnonymousObjectTransformer transformer =
145 new AnonymousObjectTransformer(invocation.getOwnerInternalName(),
146 inliningContext
147 .subInlineWithClassRegeneration(
148 inliningContext.nameGenerator,
149 currentTypeMapping,
150 invocation),
151 isSameModule, newLambdaType
152 );
153
154 InlineResult transformResult = transformer.doTransform(invocation, nodeRemapper);
155 result.addAllClassesToRemove(transformResult);
156
157 if (inliningContext.isInliningLambda) {
158 //this class is transformed and original not used so we should remove original one after inlining
159 result.addClassToRemove(invocation.getOwnerInternalName());
160 }
161 }
162 }
163
164 //in case of regenerated invocation type would be remapped to new one via remappingMethodAdapter
165 super.anew(type);
166 }
167
168 @Override
169 public void visitMethodInsn(int opcode, String owner, String name, String desc, boolean itf) {
170 if (/*INLINE_RUNTIME.equals(owner) &&*/ isInvokeOnLambda(owner, name)) { //TODO add method
171 assert !currentInvokes.isEmpty();
172 InvokeCall invokeCall = currentInvokes.remove();
173 LambdaInfo info = invokeCall.lambdaInfo;
174
175 if (info == null) {
176 //noninlinable lambda
177 super.visitMethodInsn(opcode, owner, name, desc, itf);
178 return;
179 }
180
181 int valueParamShift = getNextLocalIndex();//NB: don't inline cause it changes
182 putStackValuesIntoLocals(info.getParamsWithoutCapturedValOrVar(), valueParamShift, this, desc);
183
184 Parameters lambdaParameters = info.addAllParameters();
185
186 InlinedLambdaRemapper newCapturedRemapper =
187 new InlinedLambdaRemapper(info.getLambdaClassType().getInternalName(), nodeRemapper, lambdaParameters);
188
189 setInlining(true);
190 MethodInliner inliner = new MethodInliner(info.getNode(), lambdaParameters,
191 inliningContext.subInlineLambda(info),
192 newCapturedRemapper, true /*cause all calls in same module as lambda*/,
193 "Lambda inlining " + info.getLambdaClassType().getInternalName());
194
195 LocalVarRemapper remapper = new LocalVarRemapper(lambdaParameters, valueParamShift);
196 InlineResult lambdaResult = inliner.doInline(this.mv, remapper);//TODO add skipped this and receiver
197 result.addAllClassesToRemove(lambdaResult);
198
199 //return value boxing/unboxing
200 Method bridge = typeMapper.mapSignature(ClosureCodegen.getInvokeFunction(info.getFunctionDescriptor())).getAsmMethod();
201 Method delegate = typeMapper.mapSignature(info.getFunctionDescriptor()).getAsmMethod();
202 StackValue.onStack(delegate.getReturnType()).put(bridge.getReturnType(), this);
203 setInlining(false);
204 }
205 else if (isLambdaConstructorCall(owner, name)) { //TODO add method
206 assert invocation != null : "<init> call not corresponds to new call" + owner + " " + name;
207 if (invocation.shouldRegenerate()) {
208 //put additional captured parameters on stack
209 for (CapturedParamInfo capturedParamInfo : invocation.getAllRecapturedParameters()) {
210 visitFieldInsn(Opcodes.GETSTATIC, capturedParamInfo.getContainingLambdaName(), "$$$" + capturedParamInfo.getOriginalFieldName(), capturedParamInfo.getType().getDescriptor());
211 }
212 super.visitMethodInsn(opcode, invocation.getNewLambdaType().getInternalName(), name, invocation.getNewConstructorDescriptor(), itf);
213 invocation = null;
214 } else {
215 super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc, itf);
216 }
217 }
218 else {
219 super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc, itf);
220 }
221 }
222
223 };
224
225 node.accept(inliner);
226
227 return resultNode;
228 }
229
230 @NotNull
231 public static CapturedParamInfo findCapturedField(FieldInsnNode node, FieldRemapper fieldRemapper) {
232 assert node.name.startsWith("$$$") : "Captured field template should start with $$$ prefix";
233 FieldInsnNode fin = new FieldInsnNode(node.getOpcode(), node.owner, node.name.substring(3), node.desc);
234 CapturedParamInfo field = fieldRemapper.findField(fin);
235 if (field == null) {
236 throw new IllegalStateException("Couldn't find captured field " + node.owner + "." + node.name + " in " + fieldRemapper.getLambdaInternalName());
237 }
238 return field;
239 }
240
241 @NotNull
242 public MethodNode prepareNode(@NotNull MethodNode node) {
243 final int capturedParamsSize = parameters.getCaptured().size();
244 final int realParametersSize = parameters.getReal().size();
245 Type[] types = Type.getArgumentTypes(node.desc);
246 Type returnType = Type.getReturnType(node.desc);
247
248 ArrayList<Type> capturedTypes = parameters.getCapturedTypes();
249 Type[] allTypes = ArrayUtil.mergeArrays(types, capturedTypes.toArray(new Type[capturedTypes.size()]));
250
251 node.instructions.resetLabels();
252 MethodNode transformedNode = new MethodNode(InlineCodegenUtil.API, node.access, node.name, Type.getMethodDescriptor(returnType, allTypes), node.signature, null) {
253
254 private final boolean isInliningLambda = nodeRemapper.isInsideInliningLambda();
255
256 private int getNewIndex(int var) {
257 return var + (var < realParametersSize ? 0 : capturedParamsSize);
258 }
259
260 @Override
261 public void visitVarInsn(int opcode, int var) {
262 super.visitVarInsn(opcode, getNewIndex(var));
263 }
264
265 @Override
266 public void visitIincInsn(int var, int increment) {
267 super.visitIincInsn(getNewIndex(var), increment);
268 }
269
270 @Override
271 public void visitMaxs(int maxStack, int maxLocals) {
272 super.visitMaxs(maxStack, maxLocals + capturedParamsSize);
273 }
274
275 @Override
276 public void visitLineNumber(int line, Label start) {
277 if(isInliningLambda) {
278 super.visitLineNumber(line, start);
279 }
280 }
281
282 @Override
283 public void visitLocalVariable(
284 String name, String desc, String signature, Label start, Label end, int index
285 ) {
286 if (isInliningLambda) {
287 super.visitLocalVariable(name, desc, signature, start, end, getNewIndex(index));
288 }
289 }
290 };
291
292 node.accept(transformedNode);
293
294 transformCaptured(transformedNode);
295
296 return transformedNode;
297 }
298
299 @NotNull
300 protected MethodNode markPlacesForInlineAndRemoveInlinable(@NotNull MethodNode node) {
301 node = prepareNode(node);
302
303 Analyzer<SourceValue> analyzer = new Analyzer<SourceValue>(new SourceInterpreter());
304 Frame<SourceValue>[] sources;
305 try {
306 sources = analyzer.analyze("fake", node);
307 }
308 catch (AnalyzerException e) {
309 throw wrapException(e, node, "couldn't inline method call");
310 }
311
312 AbstractInsnNode cur = node.instructions.getFirst();
313 int index = 0;
314 Set<LabelNode> deadLabels = new HashSet<LabelNode>();
315
316 while (cur != null) {
317 Frame<SourceValue> frame = sources[index];
318
319 if (frame != null) {
320 if (cur.getType() == AbstractInsnNode.METHOD_INSN) {
321 MethodInsnNode methodInsnNode = (MethodInsnNode) cur;
322 String owner = methodInsnNode.owner;
323 String desc = methodInsnNode.desc;
324 String name = methodInsnNode.name;
325 //TODO check closure
326 int paramLength = Type.getArgumentTypes(desc).length + 1;//non static
327 if (isInvokeOnLambda(owner, name) /*&& methodInsnNode.owner.equals(INLINE_RUNTIME)*/) {
328 SourceValue sourceValue = frame.getStack(frame.getStackSize() - paramLength);
329
330 LambdaInfo lambdaInfo = null;
331 int varIndex = -1;
332
333 if (sourceValue.insns.size() == 1) {
334 AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
335
336 lambdaInfo = getLambdaIfExists(insnNode);
337 if (lambdaInfo != null) {
338 //remove inlinable access
339 node.instructions.remove(insnNode);
340 }
341 }
342
343 invokeCalls.add(new InvokeCall(varIndex, lambdaInfo));
344 }
345 else if (isLambdaConstructorCall(owner, name)) {
346 Map<Integer, LambdaInfo> lambdaMapping = new HashMap<Integer, LambdaInfo>();
347 int paramStart = frame.getStackSize() - paramLength;
348
349 for (int i = 0; i < paramLength; i++) {
350 SourceValue sourceValue = frame.getStack(paramStart + i);
351 if (sourceValue.insns.size() == 1) {
352 AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
353 LambdaInfo lambdaInfo = getLambdaIfExists(insnNode);
354 if (lambdaInfo != null) {
355 lambdaMapping.put(i, lambdaInfo);
356 node.instructions.remove(insnNode);
357 }
358 }
359 }
360
361 constructorInvocations.add(new ConstructorInvocation(owner, desc, lambdaMapping, isSameModule, inliningContext.classRegeneration));
362 }
363 }
364 }
365
366 AbstractInsnNode prevNode = cur;
367 cur = cur.getNext();
368 index++;
369
370 //given frame is <tt>null</tt> if and only if the corresponding instruction cannot be reached (dead code).
371 if (frame == null) {
372 //clean dead code otherwise there is problems in unreachable finally block, don't touch label it cause try/catch/finally problems
373 if (prevNode.getType() == AbstractInsnNode.LABEL) {
374 deadLabels.add((LabelNode) prevNode);
375 } else {
376 node.instructions.remove(prevNode);
377 }
378 }
379 }
380
381 //clean dead try/catch blocks
382 List<TryCatchBlockNode> blocks = node.tryCatchBlocks;
383 for (Iterator<TryCatchBlockNode> iterator = blocks.iterator(); iterator.hasNext(); ) {
384 TryCatchBlockNode block = iterator.next();
385 if (deadLabels.contains(block.start) && deadLabels.contains(block.end)) {
386 iterator.remove();
387 }
388 }
389
390 return node;
391 }
392
393 public LambdaInfo getLambdaIfExists(AbstractInsnNode insnNode) {
394 if (insnNode.getOpcode() == Opcodes.ALOAD) {
395 int varIndex = ((VarInsnNode) insnNode).var;
396 if (varIndex < parameters.totalSize()) {
397 return parameters.get(varIndex).getLambda();
398 }
399 }
400 else if (insnNode instanceof FieldInsnNode) {
401 FieldInsnNode fieldInsnNode = (FieldInsnNode) insnNode;
402 if (fieldInsnNode.name.startsWith("$$$")) {
403 return findCapturedField(fieldInsnNode, nodeRemapper).getLambda();
404 }
405 }
406
407 return null;
408 }
409
410 private static void removeClosureAssertions(MethodNode node) {
411 AbstractInsnNode cur = node.instructions.getFirst();
412 while (cur != null && cur.getNext() != null) {
413 AbstractInsnNode next = cur.getNext();
414 if (next.getType() == AbstractInsnNode.METHOD_INSN) {
415 MethodInsnNode methodInsnNode = (MethodInsnNode) next;
416 if (methodInsnNode.name.equals("checkParameterIsNotNull") && methodInsnNode.owner.equals("kotlin/jvm/internal/Intrinsics")) {
417 AbstractInsnNode prev = cur.getPrevious();
418
419 assert cur.getOpcode() == Opcodes.LDC : "checkParameterIsNotNull should go after LDC but " + cur;
420 assert prev.getOpcode() == Opcodes.ALOAD : "checkParameterIsNotNull should be invoked on local var but " + prev;
421
422 node.instructions.remove(prev);
423 node.instructions.remove(cur);
424 cur = next.getNext();
425 node.instructions.remove(next);
426 next = cur;
427 }
428 }
429 cur = next;
430 }
431 }
432
433 private void transformCaptured(@NotNull MethodNode node) {
434 if (nodeRemapper.isRoot()) {
435 return;
436 }
437
438 //Fold all captured variable chain - ALOAD 0 ALOAD this$0 GETFIELD $captured - to GETFIELD $$$$captured
439 //On future decoding this field could be inline or unfolded in another field access chain (it can differ in some missed this$0)
440 AbstractInsnNode cur = node.instructions.getFirst();
441 while (cur != null) {
442 if (cur instanceof VarInsnNode && cur.getOpcode() == Opcodes.ALOAD) {
443 if (((VarInsnNode) cur).var == 0) {
444 List<AbstractInsnNode> accessChain = getCapturedFieldAccessChain((VarInsnNode) cur);
445 AbstractInsnNode insnNode = nodeRemapper.foldFieldAccessChainIfNeeded(accessChain, node);
446 if (insnNode != null) {
447 cur = insnNode;
448 }
449 }
450 }
451 cur = cur.getNext();
452 }
453 }
454
455 @NotNull
456 public static List<AbstractInsnNode> getCapturedFieldAccessChain(@NotNull VarInsnNode aload0) {
457 List<AbstractInsnNode> fieldAccessChain = new ArrayList<AbstractInsnNode>();
458 fieldAccessChain.add(aload0);
459 AbstractInsnNode next = aload0.getNext();
460 while (next != null && next instanceof FieldInsnNode || next instanceof LabelNode) {
461 if (next instanceof LabelNode) {
462 next = next.getNext();
463 continue; //it will be delete on transformation
464 }
465 fieldAccessChain.add(next);
466 if ("this$0".equals(((FieldInsnNode) next).name)) {
467 next = next.getNext();
468 }
469 else {
470 break;
471 }
472 }
473
474 return fieldAccessChain;
475 }
476
477 public static void putStackValuesIntoLocals(List<Type> directOrder, int shift, InstructionAdapter iv, String descriptor) {
478 Type[] actualParams = Type.getArgumentTypes(descriptor);
479 assert actualParams.length == directOrder.size() : "Number of expected and actual params should be equals!";
480
481 int size = 0;
482 for (Type next : directOrder) {
483 size += next.getSize();
484 }
485
486 shift += size;
487 int index = directOrder.size();
488
489 for (Type next : Lists.reverse(directOrder)) {
490 shift -= next.getSize();
491 Type typeOnStack = actualParams[--index];
492 if (!typeOnStack.equals(next)) {
493 StackValue.onStack(typeOnStack).put(next, iv);
494 }
495 iv.store(shift, next);
496 }
497 }
498
499 //TODO: check annotation on class - it's package part
500 //TODO: check it's external module
501 //TODO?: assert method exists in facade?
502 public String changeOwnerForExternalPackage(String type, int opcode) {
503 if (isSameModule || (opcode & Opcodes.INVOKESTATIC) == 0) {
504 return type;
505 }
506
507 int i = type.indexOf('-');
508 if (i >= 0) {
509 return type.substring(0, i);
510 }
511 return type;
512 }
513
514
515 public RuntimeException wrapException(@NotNull Exception originalException, @NotNull MethodNode node, @NotNull String errorSuffix) {
516 if (originalException instanceof InlineException) {
517 return new InlineException(errorPrefix + ": " + errorSuffix, originalException);
518 } else {
519 return new InlineException(errorPrefix + ": " + errorSuffix + "\ncause: " +
520 InlineCodegen.getNodeText(node), originalException);
521 }
522 }
523 }