001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.lang.types.expressions;
018    
019    import com.google.common.collect.Lists;
020    import com.intellij.psi.PsiElement;
021    import com.intellij.util.Function;
022    import com.intellij.util.containers.ContainerUtil;
023    import kotlin.Function0;
024    import org.jetbrains.annotations.NotNull;
025    import org.jetbrains.annotations.Nullable;
026    import org.jetbrains.jet.lang.descriptors.*;
027    import org.jetbrains.jet.lang.descriptors.annotations.Annotations;
028    import org.jetbrains.jet.lang.descriptors.impl.*;
029    import org.jetbrains.jet.lang.psi.*;
030    import org.jetbrains.jet.lang.resolve.*;
031    import org.jetbrains.jet.lang.resolve.name.LabelName;
032    import org.jetbrains.jet.lang.resolve.name.Name;
033    import org.jetbrains.jet.lang.resolve.scopes.JetScope;
034    import org.jetbrains.jet.lang.types.*;
035    import org.jetbrains.jet.lang.types.checker.JetTypeChecker;
036    import org.jetbrains.jet.lang.types.lang.KotlinBuiltIns;
037    import org.jetbrains.jet.util.slicedmap.WritableSlice;
038    
039    import java.util.Collection;
040    import java.util.Collections;
041    import java.util.List;
042    
043    import static org.jetbrains.jet.lang.diagnostics.Errors.*;
044    import static org.jetbrains.jet.lang.resolve.BindingContext.*;
045    import static org.jetbrains.jet.lang.resolve.calls.context.ContextDependency.INDEPENDENT;
046    import static org.jetbrains.jet.lang.types.TypeUtils.*;
047    import static org.jetbrains.jet.lang.types.expressions.CoercionStrategy.COERCION_TO_UNIT;
048    
049    public class ClosureExpressionsTypingVisitor extends ExpressionTypingVisitor {
050    
051        protected ClosureExpressionsTypingVisitor(@NotNull ExpressionTypingInternals facade) {
052            super(facade);
053        }
054    
055        @Override
056        public JetTypeInfo visitObjectLiteralExpression(@NotNull final JetObjectLiteralExpression expression, final ExpressionTypingContext context) {
057            DelegatingBindingTrace delegatingBindingTrace = context.trace.get(TRACE_DELTAS_CACHE, expression.getObjectDeclaration());
058            if (delegatingBindingTrace != null) {
059                delegatingBindingTrace.addAllMyDataTo(context.trace);
060                JetType type = context.trace.get(EXPRESSION_TYPE, expression);
061                return DataFlowUtils.checkType(type, expression, context, context.dataFlowInfo);
062            }
063            final JetType[] result = new JetType[1];
064            final TemporaryBindingTrace temporaryTrace = TemporaryBindingTrace.create(context.trace, "trace to resolve object literal expression", expression);
065            ObservableBindingTrace.RecordHandler<PsiElement, ClassDescriptor> handler = new ObservableBindingTrace.RecordHandler<PsiElement, ClassDescriptor>() {
066    
067                @Override
068                public void handleRecord(WritableSlice<PsiElement, ClassDescriptor> slice, PsiElement declaration, final ClassDescriptor descriptor) {
069                    if (slice == CLASS && declaration == expression.getObjectDeclaration()) {
070                        JetType defaultType = DeferredType.createRecursionIntolerant(components.globalContext.getStorageManager(),
071                                                                                     context.trace,
072                                                                                     new Function0<JetType>() {
073                                                                                         @Override
074                                                                                         public JetType invoke() {
075                                                                                             return descriptor.getDefaultType();
076                                                                                         }
077                                                                                     });
078                        result[0] = defaultType;
079                        if (!context.trace.get(PROCESSED, expression)) {
080                            temporaryTrace.record(EXPRESSION_TYPE, expression, defaultType);
081                            temporaryTrace.record(PROCESSED, expression);
082                        }
083                    }
084                }
085            };
086            ObservableBindingTrace traceAdapter = new ObservableBindingTrace(temporaryTrace);
087            traceAdapter.addHandler(CLASS, handler);
088            TopDownAnalyzer.processClassOrObject(components.globalContext,
089                                                 null, // don't need to add classifier of object literal to any scope
090                                                 context.replaceBindingTrace(traceAdapter).replaceContextDependency(INDEPENDENT),
091                                                 context.scope.getContainingDeclaration(),
092                                                 expression.getObjectDeclaration());
093    
094            DelegatingBindingTrace cloneDelta = new DelegatingBindingTrace(
095                    new BindingTraceContext().getBindingContext(), "cached delta trace for object literal expression resolve", expression);
096            temporaryTrace.addAllMyDataTo(cloneDelta);
097            context.trace.record(TRACE_DELTAS_CACHE, expression.getObjectDeclaration(), cloneDelta);
098            temporaryTrace.commit();
099            return DataFlowUtils.checkType(result[0], expression, context, context.dataFlowInfo);
100        }
101    
102        @Override
103        public JetTypeInfo visitFunctionLiteralExpression(@NotNull JetFunctionLiteralExpression expression, ExpressionTypingContext context) {
104            JetBlockExpression bodyExpression = expression.getFunctionLiteral().getBodyExpression();
105            if (bodyExpression == null) return null;
106    
107            Name callerName = getCallerName(expression);
108            if (callerName != null) {
109                context.labelResolver.enterLabeledElement(new LabelName(callerName.asString()), expression);
110            }
111    
112            JetType expectedType = context.expectedType;
113            boolean functionTypeExpected = !noExpectedType(expectedType) && KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(
114                    expectedType);
115    
116            AnonymousFunctionDescriptor functionDescriptor = createFunctionDescriptor(expression, context, functionTypeExpected);
117            JetType safeReturnType = computeReturnType(expression, context, functionDescriptor, functionTypeExpected);
118            functionDescriptor.setReturnType(safeReturnType);
119    
120            JetType receiver = DescriptorUtils.getReceiverParameterType(functionDescriptor.getReceiverParameter());
121            List<JetType> valueParametersTypes = DescriptorUtils.getValueParametersTypes(functionDescriptor.getValueParameters());
122            JetType resultType = KotlinBuiltIns.getInstance().getFunctionType(
123                    Annotations.EMPTY, receiver, valueParametersTypes, safeReturnType);
124            if (!noExpectedType(expectedType) && KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(expectedType)) {
125                // all checks were done before
126                return JetTypeInfo.create(resultType, context.dataFlowInfo);
127            }
128    
129            if (callerName != null) {
130                context.labelResolver.exitLabeledElement(expression);
131            }
132    
133            return DataFlowUtils.checkType(resultType, expression, context, context.dataFlowInfo);
134        }
135    
136        @Nullable
137        private static Name getCallerName(@NotNull JetFunctionLiteralExpression expression) {
138            JetCallExpression callExpression = getContainingCallExpression(expression);
139            if (callExpression == null) return null;
140    
141            JetExpression calleeExpression = callExpression.getCalleeExpression();
142            if (calleeExpression instanceof JetSimpleNameExpression) {
143                JetSimpleNameExpression nameExpression = (JetSimpleNameExpression) calleeExpression;
144                return nameExpression.getReferencedNameAsName();
145            }
146    
147            return null;
148        }
149    
150        @Nullable
151        private static JetCallExpression getContainingCallExpression(JetFunctionLiteralExpression expression) {
152            PsiElement parent = expression.getParent();
153            if (parent instanceof JetCallExpression) {
154                // f {}
155                return (JetCallExpression) parent;
156            }
157    
158            if (parent instanceof JetValueArgument) {
159                // f ({}) or f(p = {})
160                JetValueArgument argument = (JetValueArgument) parent;
161                PsiElement argList = argument.getParent();
162                if (argList == null) return null;
163                PsiElement call = argList.getParent();
164                if (call instanceof JetCallExpression) {
165                    return (JetCallExpression) call;
166                }
167            }
168            return null;
169        }
170    
171        @NotNull
172        private AnonymousFunctionDescriptor createFunctionDescriptor(
173                @NotNull JetFunctionLiteralExpression expression,
174                @NotNull ExpressionTypingContext context,
175                boolean functionTypeExpected
176        ) {
177            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
178            JetTypeReference receiverTypeRef = functionLiteral.getReceiverTypeRef();
179            AnonymousFunctionDescriptor functionDescriptor = new AnonymousFunctionDescriptor(
180                    context.scope.getContainingDeclaration(), Annotations.EMPTY, CallableMemberDescriptor.Kind.DECLARATION);
181    
182            List<ValueParameterDescriptor> valueParameterDescriptors = createValueParameterDescriptors(context, functionLiteral,
183                                                                                                       functionDescriptor, functionTypeExpected);
184    
185            JetType effectiveReceiverType;
186            if (receiverTypeRef == null) {
187                if (functionTypeExpected) {
188                    effectiveReceiverType = KotlinBuiltIns.getInstance().getReceiverType(context.expectedType);
189                }
190                else {
191                    effectiveReceiverType = null;
192                }
193            }
194            else {
195                effectiveReceiverType = components.expressionTypingServices.getTypeResolver().resolveType(context.scope, receiverTypeRef, context.trace, true);
196            }
197            functionDescriptor.initialize(effectiveReceiverType,
198                                          ReceiverParameterDescriptor.NO_RECEIVER_PARAMETER,
199                                          Collections.<TypeParameterDescriptorImpl>emptyList(),
200                                          valueParameterDescriptors,
201                                          /*unsubstitutedReturnType = */ null,
202                                          Modality.FINAL,
203                                          Visibilities.LOCAL
204            );
205            BindingContextUtils.recordFunctionDeclarationToDescriptor(context.trace, functionLiteral, functionDescriptor);
206            return functionDescriptor;
207        }
208    
209        @NotNull
210        private List<ValueParameterDescriptor> createValueParameterDescriptors(
211                @NotNull ExpressionTypingContext context,
212                @NotNull JetFunctionLiteral functionLiteral,
213                @NotNull FunctionDescriptorImpl functionDescriptor,
214                boolean functionTypeExpected
215        ) {
216            List<ValueParameterDescriptor> valueParameterDescriptors = Lists.newArrayList();
217            List<JetParameter> declaredValueParameters = functionLiteral.getValueParameters();
218    
219            List<ValueParameterDescriptor> expectedValueParameters =  (functionTypeExpected)
220                                                              ? KotlinBuiltIns.getInstance().getValueParameters(functionDescriptor, context.expectedType)
221                                                              : null;
222    
223            JetParameterList valueParameterList = functionLiteral.getValueParameterList();
224            boolean hasDeclaredValueParameters = valueParameterList != null;
225            if (functionTypeExpected && !hasDeclaredValueParameters && expectedValueParameters.size() == 1) {
226                ValueParameterDescriptor valueParameterDescriptor = expectedValueParameters.get(0);
227                ValueParameterDescriptor it = new ValueParameterDescriptorImpl(
228                        functionDescriptor, null, 0, Annotations.EMPTY, Name.identifier("it"),
229                        valueParameterDescriptor.getType(), valueParameterDescriptor.hasDefaultValue(), valueParameterDescriptor.getVarargElementType()
230                );
231                valueParameterDescriptors.add(it);
232                context.trace.record(AUTO_CREATED_IT, it);
233            }
234            else {
235                if (expectedValueParameters != null && declaredValueParameters.size() != expectedValueParameters.size()) {
236                    List<JetType> expectedParameterTypes = DescriptorUtils.getValueParametersTypes(expectedValueParameters);
237                    context.trace.report(EXPECTED_PARAMETERS_NUMBER_MISMATCH.on(functionLiteral, expectedParameterTypes.size(), expectedParameterTypes));
238                }
239                for (int i = 0; i < declaredValueParameters.size(); i++) {
240                    ValueParameterDescriptor valueParameterDescriptor = createValueParameterDescriptor(
241                            context, functionDescriptor, declaredValueParameters, expectedValueParameters, i);
242                    valueParameterDescriptors.add(valueParameterDescriptor);
243                }
244            }
245            return valueParameterDescriptors;
246        }
247    
248        @NotNull
249        private ValueParameterDescriptor createValueParameterDescriptor(
250                @NotNull ExpressionTypingContext context,
251                @NotNull FunctionDescriptorImpl functionDescriptor,
252                @NotNull List<JetParameter> declaredValueParameters,
253                @Nullable List<ValueParameterDescriptor> expectedValueParameters,
254                int index
255        ) {
256            JetParameter declaredParameter = declaredValueParameters.get(index);
257            JetTypeReference typeReference = declaredParameter.getTypeReference();
258    
259            JetType expectedType;
260            if (expectedValueParameters != null && index < expectedValueParameters.size()) {
261                expectedType = expectedValueParameters.get(index).getType();
262            }
263            else {
264                expectedType = null;
265            }
266            JetType type;
267            if (typeReference != null) {
268                type = components.expressionTypingServices.getTypeResolver().resolveType(context.scope, typeReference, context.trace, true);
269                if (expectedType != null) {
270                    if (!JetTypeChecker.INSTANCE.isSubtypeOf(expectedType, type)) {
271                        context.trace.report(EXPECTED_PARAMETER_TYPE_MISMATCH.on(declaredParameter, expectedType));
272                    }
273                }
274            }
275            else {
276                if (expectedType == null || expectedType == DONT_CARE || ErrorUtils.isUninferredParameter(expectedType)) {
277                    context.trace.report(CANNOT_INFER_PARAMETER_TYPE.on(declaredParameter));
278                }
279                if (expectedType != null) {
280                    type = expectedType;
281                }
282                else {
283                    type = CANT_INFER_LAMBDA_PARAM_TYPE;
284                }
285            }
286            return components.expressionTypingServices.getDescriptorResolver().resolveValueParameterDescriptorWithAnnotationArguments(
287                    context.scope, functionDescriptor, declaredParameter, index, type, context.trace);
288        }
289    
290        @NotNull
291        private JetType computeReturnType(
292                @NotNull JetFunctionLiteralExpression expression,
293                @NotNull ExpressionTypingContext context,
294                @NotNull SimpleFunctionDescriptorImpl functionDescriptor,
295                boolean functionTypeExpected
296        ) {
297            JetType expectedReturnType = functionTypeExpected ? KotlinBuiltIns.getInstance().getReturnTypeFromFunctionType(context.expectedType) : null;
298            JetType returnType = computeUnsafeReturnType(expression, context, functionDescriptor, expectedReturnType);
299    
300            if (!expression.getFunctionLiteral().hasDeclaredReturnType() && functionTypeExpected) {
301                if (KotlinBuiltIns.getInstance().isUnit(expectedReturnType)) {
302                    return KotlinBuiltIns.getInstance().getUnitType();
303                }
304            }
305            return returnType == null ? CANT_INFER_LAMBDA_PARAM_TYPE : returnType;
306        }
307    
308        @Nullable
309        private JetType computeUnsafeReturnType(
310                @NotNull JetFunctionLiteralExpression expression,
311                @NotNull ExpressionTypingContext context,
312                @NotNull SimpleFunctionDescriptorImpl functionDescriptor,
313                @Nullable JetType expectedReturnType
314        ) {
315            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
316            JetBlockExpression bodyExpression = functionLiteral.getBodyExpression();
317            assert bodyExpression != null;
318    
319            JetScope functionInnerScope = FunctionDescriptorUtil.getFunctionInnerScope(context.scope, functionDescriptor, context.trace);
320            JetTypeReference returnTypeRef = functionLiteral.getReturnTypeRef();
321            JetType declaredReturnType = null;
322            if (returnTypeRef != null) {
323                declaredReturnType = components.expressionTypingServices.getTypeResolver().resolveType(context.scope, returnTypeRef, context.trace, true);
324                // This is needed for ControlStructureTypingVisitor#visitReturnExpression() to properly type-check returned expressions
325                functionDescriptor.setReturnType(declaredReturnType);
326                if (expectedReturnType != null) {
327                    if (!JetTypeChecker.INSTANCE.isSubtypeOf(declaredReturnType, expectedReturnType)) {
328                        context.trace.report(EXPECTED_RETURN_TYPE_MISMATCH.on(returnTypeRef, expectedReturnType));
329                    }
330                }
331            }
332    
333            // Type-check the body
334            ExpressionTypingContext newContext = context.replaceScope(functionInnerScope)
335                    .replaceExpectedType(declaredReturnType != null
336                                         ? declaredReturnType
337                                         : (expectedReturnType != null ? expectedReturnType : NO_EXPECTED_TYPE));
338    
339            JetType typeOfBodyExpression = components.expressionTypingServices.getBlockReturnedType(bodyExpression, COERCION_TO_UNIT, newContext).getType();
340    
341            List<JetType> returnedExpressionTypes = Lists.newArrayList(getTypesOfLocallyReturnedExpressions(
342                    functionLiteral, context.trace, collectReturns(bodyExpression)));
343            ContainerUtil.addIfNotNull(returnedExpressionTypes, typeOfBodyExpression);
344    
345            if (declaredReturnType != null) return declaredReturnType;
346            if (returnedExpressionTypes.isEmpty()) return null;
347            return CommonSupertypes.commonSupertype(returnedExpressionTypes);
348        }
349    
350        private static List<JetType> getTypesOfLocallyReturnedExpressions(
351                final JetFunctionLiteral functionLiteral,
352                final BindingTrace trace,
353                Collection<JetReturnExpression> returnExpressions
354        ) {
355            return ContainerUtil.mapNotNull(returnExpressions, new Function<JetReturnExpression, JetType>() {
356                @Override
357                public JetType fun(JetReturnExpression returnExpression) {
358                    JetSimpleNameExpression label = returnExpression.getTargetLabel();
359                    if (label == null) {
360                        // No label => non-local return
361                        return null;
362                    }
363    
364                    PsiElement labelTarget = trace.get(BindingContext.LABEL_TARGET, label);
365                    if (labelTarget != functionLiteral) {
366                        // Either a local return of inner lambda/function or a non-local return
367                        return null;
368                    }
369    
370                    JetExpression returnedExpression = returnExpression.getReturnedExpression();
371                    if (returnedExpression == null) {
372                        return KotlinBuiltIns.getInstance().getUnitType();
373                    }
374                    JetType returnedType = trace.get(EXPRESSION_TYPE, returnedExpression);
375                    assert returnedType != null : "No type for returned expression: " + returnedExpression + ",\n" +
376                                                  "the type should have been computed by getBlockReturnedType() above";
377                    return returnedType;
378                }
379            });
380        }
381    
382        public static Collection<JetReturnExpression> collectReturns(@NotNull JetExpression expression) {
383            Collection<JetReturnExpression> result = Lists.newArrayList();
384            expression.accept(
385                    new JetTreeVisitor<Collection<JetReturnExpression>>() {
386                        @Override
387                        public Void visitReturnExpression(
388                                @NotNull JetReturnExpression expression, Collection<JetReturnExpression> data
389                        ) {
390                            data.add(expression);
391                            return null;
392                        }
393                    },
394                    result
395            );
396            return result;
397        }
398    }