001 /*
002 * Copyright 2010-2013 JetBrains s.r.o.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017 package org.jetbrains.jet.codegen.inline;
018
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.PsiElement;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.jet.codegen.*;
import org.jetbrains.jet.codegen.context.CodegenContext;
import org.jetbrains.jet.codegen.context.MethodContext;
import org.jetbrains.jet.codegen.context.PackageContext;
import org.jetbrains.jet.codegen.signature.JvmMethodParameterKind;
import org.jetbrains.jet.codegen.signature.JvmMethodParameterSignature;
import org.jetbrains.jet.codegen.signature.JvmMethodSignature;
import org.jetbrains.jet.codegen.state.GenerationState;
import org.jetbrains.jet.codegen.state.JetTypeMapper;
import org.jetbrains.jet.descriptors.serialization.descriptors.DeserializedSimpleFunctionDescriptor;
import org.jetbrains.jet.lang.descriptors.*;
import org.jetbrains.jet.lang.descriptors.impl.AnonymousFunctionDescriptor;
import org.jetbrains.jet.lang.psi.*;
import org.jetbrains.jet.lang.resolve.BindingContext;
import org.jetbrains.jet.lang.resolve.BindingContextUtils;
import org.jetbrains.jet.lang.resolve.DescriptorUtils;
import org.jetbrains.jet.lang.resolve.calls.model.ResolvedCall;
import org.jetbrains.jet.lang.resolve.java.AsmTypeConstants;
import org.jetbrains.jet.lang.types.lang.InlineStrategy;
import org.jetbrains.jet.lang.types.lang.InlineUtil;
import org.jetbrains.jet.renderer.DescriptorRenderer;
import org.jetbrains.org.objectweb.asm.MethodVisitor;
import org.jetbrains.org.objectweb.asm.Opcodes;
import org.jetbrains.org.objectweb.asm.Type;
import org.jetbrains.org.objectweb.asm.commons.Method;
import org.jetbrains.org.objectweb.asm.tree.MethodNode;
import org.jetbrains.org.objectweb.asm.util.Textifier;
import org.jetbrains.org.objectweb.asm.util.TraceMethodVisitor;

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.*;

import static org.jetbrains.jet.codegen.AsmUtil.getMethodAsmFlags;
import static org.jetbrains.jet.codegen.AsmUtil.isPrimitive;
059
060 public class InlineCodegen implements CallGenerator {
061 private final GenerationState state;
062 private final JetTypeMapper typeMapper;
063 private final BindingContext bindingContext;
064
065 private final SimpleFunctionDescriptor functionDescriptor;
066 private final JvmMethodSignature jvmSignature;
067 private final Call call;
068 private final MethodContext context;
069 private final ExpressionCodegen codegen;
070 private final FrameMap originalFunctionFrame;
071 private final boolean asFunctionInline;
072 private final int initialFrameSize;
073 private final boolean isSameModule;
074
075 protected final List<ParameterInfo> actualParameters = new ArrayList<ParameterInfo>();
076 protected final Map<Integer, LambdaInfo> expressionMap = new HashMap<Integer, LambdaInfo>();
077
078 private LambdaInfo activeLambda;
079
080 public InlineCodegen(
081 @NotNull ExpressionCodegen codegen,
082 @NotNull GenerationState state,
083 @NotNull SimpleFunctionDescriptor functionDescriptor,
084 @NotNull Call call
085 ) {
086 assert functionDescriptor.getInlineStrategy().isInline() : "InlineCodegen could inline only inline function but " + functionDescriptor;
087
088 this.state = state;
089 this.typeMapper = state.getTypeMapper();
090 this.codegen = codegen;
091 this.call = call;
092 this.functionDescriptor = functionDescriptor.getOriginal();
093 bindingContext = codegen.getBindingContext();
094 initialFrameSize = codegen.getFrameMap().getCurrentSize();
095
096 context = (MethodContext) getContext(functionDescriptor, state);
097 originalFunctionFrame = context.prepareFrame(typeMapper);
098 jvmSignature = typeMapper.mapSignature(functionDescriptor, context.getContextKind());
099
100 InlineStrategy inlineStrategy =
101 codegen.getContext().isInlineFunction() ? InlineStrategy.IN_PLACE : functionDescriptor.getInlineStrategy();
102 this.asFunctionInline = false;
103
104 isSameModule = !(functionDescriptor instanceof DeserializedSimpleFunctionDescriptor) /*not compiled library*/ &&
105 JvmCodegenUtil.isCallInsideSameModuleAsDeclared(functionDescriptor, codegen.getContext());
106 }
107
108
109 @Override
110 public void genCall(CallableMethod callableMethod, ResolvedCall<?> resolvedCall, int mask, ExpressionCodegen codegen) {
111 assert mask == 0 : "Default method invocation couldn't be inlined " + resolvedCall;
112
113 MethodNode node = null;
114
115 try {
116 node = createMethodNode(callableMethod);
117 endCall(inlineCall(node));
118 }
119 catch (CompilationException e) {
120 throw e;
121 }
122 catch (Exception e) {
123 boolean generateNodeText = !(e instanceof InlineException);
124 PsiElement element = BindingContextUtils.descriptorToDeclaration(bindingContext, this.codegen.getContext().getContextDescriptor());
125 throw new CompilationException("Couldn't inline method call '" +
126 functionDescriptor.getName() +
127 "' into \n" + (element != null ? element.getText() : "null psi element " + this.codegen.getContext().getContextDescriptor()) +
128 (generateNodeText ? ("\ncause: " + getNodeText(node)) : ""),
129 e, call.getCallElement());
130 }
131
132
133 }
134
135 private void endCall(@NotNull InlineResult result) {
136 leaveTemps();
137
138 state.getFactory().removeInlinedClasses(result.getClassesToRemove());
139 }
140
141 @NotNull
142 private MethodNode createMethodNode(CallableMethod callableMethod)
143 throws ClassNotFoundException, IOException {
144 MethodNode node;
145 if (functionDescriptor instanceof DeserializedSimpleFunctionDescriptor) {
146 VirtualFile file = InlineCodegenUtil.getVirtualFileForCallable((DeserializedSimpleFunctionDescriptor) functionDescriptor, state);
147 String methodDesc = callableMethod.getAsmMethod().getDescriptor();
148 DeclarationDescriptor parentDescriptor = functionDescriptor.getContainingDeclaration();
149 if (DescriptorUtils.isTrait(parentDescriptor)) {
150 methodDesc = "(" + typeMapper.mapType((ClassDescriptor) parentDescriptor).getDescriptor() + methodDesc.substring(1);
151 }
152 node = InlineCodegenUtil.getMethodNode(file.getInputStream(), functionDescriptor.getName().asString(), methodDesc);
153
154 if (node == null) {
155 throw new RuntimeException("Couldn't obtain compiled function body for " + descriptorName(functionDescriptor));
156 }
157 }
158 else {
159 PsiElement element = BindingContextUtils.descriptorToDeclaration(bindingContext, functionDescriptor);
160
161 if (element == null) {
162 throw new RuntimeException("Couldn't find declaration for function " + descriptorName(functionDescriptor));
163 }
164
165 JvmMethodSignature jvmSignature = typeMapper.mapSignature(functionDescriptor, context.getContextKind());
166 Method asmMethod = jvmSignature.getAsmMethod();
167 node = new MethodNode(InlineCodegenUtil.API,
168 getMethodAsmFlags(functionDescriptor, context.getContextKind()),
169 asmMethod.getName(),
170 asmMethod.getDescriptor(),
171 jvmSignature.getGenericsSignature(),
172 null);
173
174 //for maxLocals calculation
175 MethodVisitor adapter = InlineCodegenUtil.wrapWithMaxLocalCalc(node);
176 FunctionCodegen.generateMethodBody(adapter, functionDescriptor, context.getParentContext().intoFunction(functionDescriptor),
177 jvmSignature,
178 new FunctionGenerationStrategy.FunctionDefault(state,
179 functionDescriptor,
180 (JetDeclarationWithBody) element),
181 codegen.getParentCodegen());
182 adapter.visitMaxs(-1, -1);
183 adapter.visitEnd();
184 }
185 return node;
186 }
187
188 private InlineResult inlineCall(MethodNode node) {
189 generateClosuresBodies();
190
191 List<ParameterInfo> realParams = new ArrayList<ParameterInfo>(actualParameters);
192
193 putClosureParametersOnStack();
194
195 List<CapturedParamInfo> captured = getAllCaptured();
196
197 Parameters parameters = new Parameters(realParams, Parameters.shiftAndAddStubs(captured, realParams.size()));
198
199 InliningContext info = new RootInliningContext(expressionMap,
200 state,
201 codegen.getInlineNameGenerator()
202 .subGenerator(functionDescriptor.getName().asString()),
203 codegen.getContext(),
204 call,
205 codegen.getParentCodegen().getClassName());
206
207 MethodInliner inliner = new MethodInliner(node, parameters, info, new FieldRemapper(null, null, parameters), isSameModule, "Method inlining " + call.getCallElement().getText()); //with captured
208
209 LocalVarRemapper remapper = new LocalVarRemapper(parameters, initialFrameSize);
210
211 return inliner.doInline(codegen.v, remapper);
212 }
213
214 private void generateClosuresBodies() {
215 for (LambdaInfo info : expressionMap.values()) {
216 info.setNode(generateLambdaBody(info));
217 }
218 }
219
220 private MethodNode generateLambdaBody(LambdaInfo info) {
221 JetFunctionLiteral declaration = info.getFunctionLiteral();
222 FunctionDescriptor descriptor = info.getFunctionDescriptor();
223
224 MethodContext parentContext = codegen.getContext();
225
226 MethodContext context = parentContext.intoClosure(descriptor, codegen, typeMapper).intoInlinedLambda(descriptor);
227
228 JvmMethodSignature jvmMethodSignature = typeMapper.mapSignature(descriptor);
229 Method asmMethod = jvmMethodSignature.getAsmMethod();
230 MethodNode methodNode = new MethodNode(InlineCodegenUtil.API, getMethodAsmFlags(descriptor, context.getContextKind()), asmMethod.getName(), asmMethod.getDescriptor(), jvmMethodSignature.getGenericsSignature(), null);
231
232 MethodVisitor adapter = InlineCodegenUtil.wrapWithMaxLocalCalc(methodNode);
233
234 FunctionCodegen.generateMethodBody(adapter, descriptor, context, jvmMethodSignature, new FunctionGenerationStrategy.FunctionDefault(state, descriptor, declaration), codegen.getParentCodegen());
235 adapter.visitMaxs(-1, -1);
236
237 return methodNode;
238 }
239
240
241
242 @Override
243 public void afterParameterPut(@NotNull Type type, @Nullable StackValue stackValue, ValueParameterDescriptor valueParameterDescriptor) {
244 putCapturedInLocal(type, stackValue, valueParameterDescriptor, -1);
245 }
246
247 public void putCapturedInLocal(
248 @NotNull Type type, @Nullable StackValue stackValue, @Nullable ValueParameterDescriptor valueParameterDescriptor, int capturedParamIndex
249 ) {
250 if (!asFunctionInline && Type.VOID_TYPE != type) {
251 //TODO remap only inlinable closure => otherwise we could get a lot of problem
252 boolean couldBeRemapped = !shouldPutValue(type, stackValue, valueParameterDescriptor);
253 StackValue remappedIndex = couldBeRemapped ? stackValue : null;
254
255 ParameterInfo info = new ParameterInfo(type, false, couldBeRemapped ? -1 : codegen.getFrameMap().enterTemp(type), remappedIndex);
256
257 if (capturedParamIndex >= 0 && couldBeRemapped) {
258 CapturedParamInfo capturedParamInfo = activeLambda.getCapturedVars().get(capturedParamIndex);
259 capturedParamInfo.setRemapValue(remappedIndex != null ? remappedIndex : StackValue.local(info.getIndex(), info.getType()));
260 }
261
262 doWithParameter(info);
263 }
264 }
265
266 /*descriptor is null for captured vars*/
267 public boolean shouldPutValue(
268 @NotNull Type type,
269 @Nullable StackValue stackValue,
270 @Nullable ValueParameterDescriptor descriptor
271 ) {
272
273 if (stackValue == null) {
274 //default or vararg
275 return true;
276 }
277
278 //remap only inline functions (and maybe non primitives)
279 //TODO - clean asserion and remapping logic
280 if (isPrimitive(type) != isPrimitive(stackValue.type)) {
281 //don't remap boxing/unboxing primitives - lost identity and perfomance
282 return true;
283 }
284
285 if (stackValue instanceof StackValue.Local) {
286 return false;
287 }
288
289 if (stackValue instanceof StackValue.Composed) {
290 //see: Method.isSpecialStackValue: go through aload 0
291 if (codegen.getContext().isInliningLambda() && codegen.getContext().getContextDescriptor() instanceof AnonymousFunctionDescriptor) {
292 if (descriptor != null && !InlineUtil.hasNoinlineAnnotation(descriptor)) {
293 //TODO: check type of context
294 return false;
295 }
296 }
297 }
298 return true;
299 }
300
301 private void doWithParameter(ParameterInfo info) {
302 recordParamInfo(info, true);
303 putParameterOnStack(info);
304 }
305
306 private int recordParamInfo(ParameterInfo info, boolean addToFrame) {
307 Type type = info.type;
308 actualParameters.add(info);
309 if (info.getType().getSize() == 2) {
310 actualParameters.add(ParameterInfo.STUB);
311 }
312 if (addToFrame) {
313 return originalFunctionFrame.enterTemp(type);
314 }
315 return -1;
316 }
317
318 private void putParameterOnStack(ParameterInfo info) {
319 if (!info.isSkippedOrRemapped()) {
320 int index = info.getIndex();
321 Type type = info.type;
322 StackValue.local(index, type).store(type, codegen.v);
323 }
324 }
325
326 @Override
327 public void putHiddenParams() {
328 List<JvmMethodParameterSignature> types = jvmSignature.getValueParameters();
329
330 if (!isStaticMethod(functionDescriptor, context)) {
331 Type type = AsmTypeConstants.OBJECT_TYPE;
332 ParameterInfo info = new ParameterInfo(type, false, codegen.getFrameMap().enterTemp(type), -1);
333 recordParamInfo(info, false);
334 }
335
336 for (JvmMethodParameterSignature param : types) {
337 if (param.getKind() == JvmMethodParameterKind.VALUE) {
338 break;
339 }
340 Type type = param.getAsmType();
341 ParameterInfo info = new ParameterInfo(type, false, codegen.getFrameMap().enterTemp(type), -1);
342 recordParamInfo(info, false);
343 }
344
345 for (ListIterator<? extends ParameterInfo> iterator = actualParameters.listIterator(actualParameters.size()); iterator.hasPrevious(); ) {
346 ParameterInfo param = iterator.previous();
347 putParameterOnStack(param);
348 }
349 }
350
351 public void leaveTemps() {
352 FrameMap frameMap = codegen.getFrameMap();
353 for (ListIterator<? extends ParameterInfo> iterator = actualParameters.listIterator(actualParameters.size()); iterator.hasPrevious(); ) {
354 ParameterInfo param = iterator.previous();
355 if (!param.isSkippedOrRemapped()) {
356 frameMap.leaveTemp(param.type);
357 }
358 }
359 }
360
361 public static boolean isInliningClosure(JetExpression expression, ValueParameterDescriptor valueParameterDescriptora) {
362 //TODO deparenthisise
363 return expression instanceof JetFunctionLiteralExpression &&
364 !InlineUtil.hasNoinlineAnnotation(valueParameterDescriptora);
365 }
366
367 public void rememberClosure(JetFunctionLiteralExpression expression, Type type) {
368 ParameterInfo closureInfo = new ParameterInfo(type, true, -1, -1);
369 int index = recordParamInfo(closureInfo, true);
370
371 LambdaInfo info = new LambdaInfo(expression, typeMapper);
372 expressionMap.put(index, info);
373
374 closureInfo.setLambda(info);
375 }
376
377 private void putClosureParametersOnStack() {
378 //TODO: SORT
379 int currentSize = actualParameters.size();
380 for (LambdaInfo next : expressionMap.values()) {
381 if (next.closure != null) {
382 activeLambda = next;
383 next.setParamOffset(currentSize);
384 codegen.pushClosureOnStack(next.closure, false, this);
385 currentSize += next.getCapturedVarsSize();
386 }
387 }
388 activeLambda = null;
389 }
390
391 private List<CapturedParamInfo> getAllCaptured() {
392 //TODO: SORT
393 List<CapturedParamInfo> result = new ArrayList<CapturedParamInfo>();
394 for (LambdaInfo next : expressionMap.values()) {
395 if (next.closure != null) {
396 result.addAll(next.getCapturedVars());
397 }
398 }
399 return result;
400 }
401
402 public static CodegenContext getContext(DeclarationDescriptor descriptor, GenerationState state) {
403 if (descriptor instanceof PackageFragmentDescriptor) {
404 return new PackageContext((PackageFragmentDescriptor) descriptor, null, null);
405 }
406
407 CodegenContext parent = getContext(descriptor.getContainingDeclaration(), state);
408
409 if (descriptor instanceof ClassDescriptor) {
410 OwnerKind kind = DescriptorUtils.isTrait(descriptor) ? OwnerKind.TRAIT_IMPL : OwnerKind.IMPLEMENTATION;
411 return parent.intoClass((ClassDescriptor) descriptor, kind, state);
412 }
413 else if (descriptor instanceof FunctionDescriptor) {
414 return parent.intoFunction((FunctionDescriptor) descriptor);
415 }
416
417 throw new IllegalStateException("Couldn't build context for " + descriptorName(descriptor));
418 }
419
420 private static boolean isStaticMethod(FunctionDescriptor functionDescriptor, MethodContext context) {
421 return (getMethodAsmFlags(functionDescriptor, context.getContextKind()) & Opcodes.ACC_STATIC) != 0;
422 }
423
424 @NotNull
425 public static String getNodeText(@Nullable MethodNode node) {
426 if (node == null) {
427 return "Not generated";
428 }
429 Textifier p = new Textifier();
430 node.accept(new TraceMethodVisitor(p));
431 StringWriter sw = new StringWriter();
432 p.print(new PrintWriter(sw));
433 sw.flush();
434 return node.name + " " + node.desc + ": \n " + sw.getBuffer().toString();
435 }
436
437 private static String descriptorName(DeclarationDescriptor descriptor) {
438 return DescriptorRenderer.SHORT_NAMES_IN_TYPES.render(descriptor);
439 }
440
441 @Override
442 public void genValueAndPut(
443 @NotNull ValueParameterDescriptor valueParameterDescriptor,
444 @NotNull JetExpression argumentExpression,
445 @NotNull Type parameterType
446 ) {
447 //TODO deparenthisise
448 if (isInliningClosure(argumentExpression, valueParameterDescriptor)) {
449 rememberClosure((JetFunctionLiteralExpression) argumentExpression, parameterType);
450 } else {
451 StackValue value = codegen.gen(argumentExpression);
452 if (shouldPutValue(parameterType, value, valueParameterDescriptor)) {
453 value.put(parameterType, codegen.v);
454 }
455 afterParameterPut(parameterType, value, valueParameterDescriptor);
456 }
457 }
458
459 @Override
460 public void putCapturedValueOnStack(
461 @NotNull StackValue stackValue, @NotNull Type valueType, int paramIndex
462 ) {
463 if (shouldPutValue(stackValue.type, stackValue, null)) {
464 stackValue.put(stackValue.type, codegen.v);
465 }
466 putCapturedInLocal(stackValue.type, stackValue, null, paramIndex);
467 }
468 }