001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.lang.types.expressions;
018    
019    import com.google.common.collect.Lists;
020    import com.intellij.psi.PsiElement;
021    import com.intellij.util.Function;
022    import com.intellij.util.containers.ContainerUtil;
023    import kotlin.Function0;
024    import org.jetbrains.annotations.NotNull;
025    import org.jetbrains.annotations.Nullable;
026    import org.jetbrains.jet.lang.descriptors.*;
027    import org.jetbrains.jet.lang.descriptors.annotations.Annotations;
028    import org.jetbrains.jet.lang.descriptors.impl.*;
029    import org.jetbrains.jet.lang.psi.*;
030    import org.jetbrains.jet.lang.resolve.*;
031    import org.jetbrains.jet.lang.resolve.name.LabelName;
032    import org.jetbrains.jet.lang.resolve.name.Name;
033    import org.jetbrains.jet.lang.resolve.scopes.JetScope;
034    import org.jetbrains.jet.lang.types.CommonSupertypes;
035    import org.jetbrains.jet.lang.types.DeferredType;
036    import org.jetbrains.jet.lang.types.JetType;
037    import org.jetbrains.jet.lang.types.JetTypeInfo;
038    import org.jetbrains.jet.lang.types.checker.JetTypeChecker;
039    import org.jetbrains.jet.lang.types.lang.KotlinBuiltIns;
040    import org.jetbrains.jet.util.slicedmap.WritableSlice;
041    
042    import java.util.Collection;
043    import java.util.Collections;
044    import java.util.List;
045    
046    import static org.jetbrains.jet.lang.diagnostics.Errors.*;
047    import static org.jetbrains.jet.lang.resolve.BindingContext.*;
048    import static org.jetbrains.jet.lang.resolve.calls.context.ContextDependency.INDEPENDENT;
049    import static org.jetbrains.jet.lang.types.TypeUtils.*;
050    import static org.jetbrains.jet.lang.types.expressions.CoercionStrategy.COERCION_TO_UNIT;
051    
052    public class ClosureExpressionsTypingVisitor extends ExpressionTypingVisitor {
053    
054        protected ClosureExpressionsTypingVisitor(@NotNull ExpressionTypingInternals facade) {
055            super(facade);
056        }
057    
058        @Override
059        public JetTypeInfo visitObjectLiteralExpression(@NotNull final JetObjectLiteralExpression expression, final ExpressionTypingContext context) {
060            DelegatingBindingTrace delegatingBindingTrace = context.trace.get(TRACE_DELTAS_CACHE, expression.getObjectDeclaration());
061            if (delegatingBindingTrace != null) {
062                delegatingBindingTrace.addAllMyDataTo(context.trace);
063                JetType type = context.trace.get(EXPRESSION_TYPE, expression);
064                return DataFlowUtils.checkType(type, expression, context, context.dataFlowInfo);
065            }
066            final JetType[] result = new JetType[1];
067            final TemporaryBindingTrace temporaryTrace = TemporaryBindingTrace.create(context.trace, "trace to resolve object literal expression", expression);
068            ObservableBindingTrace.RecordHandler<PsiElement, ClassDescriptor> handler = new ObservableBindingTrace.RecordHandler<PsiElement, ClassDescriptor>() {
069    
070                @Override
071                public void handleRecord(WritableSlice<PsiElement, ClassDescriptor> slice, PsiElement declaration, final ClassDescriptor descriptor) {
072                    if (slice == CLASS && declaration == expression.getObjectDeclaration()) {
073                        JetType defaultType = DeferredType.createRecursionIntolerant(components.globalContext.getStorageManager(),
074                                                                                     context.trace,
075                                                                                     new Function0<JetType>() {
076                                                                                         @Override
077                                                                                         public JetType invoke() {
078                                                                                             return descriptor.getDefaultType();
079                                                                                         }
080                                                                                     });
081                        result[0] = defaultType;
082                        if (!context.trace.get(PROCESSED, expression)) {
083                            temporaryTrace.record(EXPRESSION_TYPE, expression, defaultType);
084                            temporaryTrace.record(PROCESSED, expression);
085                        }
086                    }
087                }
088            };
089            ObservableBindingTrace traceAdapter = new ObservableBindingTrace(temporaryTrace);
090            traceAdapter.addHandler(CLASS, handler);
091            TopDownAnalyzer.processClassOrObject(components.globalContext,
092                                                 null, // don't need to add classifier of object literal to any scope
093                                                 context.replaceBindingTrace(traceAdapter).replaceContextDependency(INDEPENDENT),
094                                                 context.scope.getContainingDeclaration(),
095                                                 expression.getObjectDeclaration());
096    
097            DelegatingBindingTrace cloneDelta = new DelegatingBindingTrace(
098                    new BindingTraceContext().getBindingContext(), "cached delta trace for object literal expression resolve", expression);
099            temporaryTrace.addAllMyDataTo(cloneDelta);
100            context.trace.record(TRACE_DELTAS_CACHE, expression.getObjectDeclaration(), cloneDelta);
101            temporaryTrace.commit();
102            return DataFlowUtils.checkType(result[0], expression, context, context.dataFlowInfo);
103        }
104    
105        @Override
106        public JetTypeInfo visitFunctionLiteralExpression(@NotNull JetFunctionLiteralExpression expression, ExpressionTypingContext context) {
107            JetBlockExpression bodyExpression = expression.getFunctionLiteral().getBodyExpression();
108            if (bodyExpression == null) return null;
109    
110            Name callerName = getCallerName(expression);
111            if (callerName != null) {
112                context.labelResolver.enterLabeledElement(new LabelName(callerName.asString()), expression);
113            }
114    
115            JetType expectedType = context.expectedType;
116            boolean functionTypeExpected = !noExpectedType(expectedType) && KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(
117                    expectedType);
118    
119            AnonymousFunctionDescriptor functionDescriptor = createFunctionDescriptor(expression, context, functionTypeExpected);
120            JetType safeReturnType = computeReturnType(expression, context, functionDescriptor, functionTypeExpected);
121            functionDescriptor.setReturnType(safeReturnType);
122    
123            JetType receiver = DescriptorUtils.getReceiverParameterType(functionDescriptor.getReceiverParameter());
124            List<JetType> valueParametersTypes = DescriptorUtils.getValueParametersTypes(functionDescriptor.getValueParameters());
125            JetType resultType = KotlinBuiltIns.getInstance().getFunctionType(
126                    Annotations.EMPTY, receiver, valueParametersTypes, safeReturnType);
127            if (!noExpectedType(expectedType) && KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(expectedType)) {
128                // all checks were done before
129                return JetTypeInfo.create(resultType, context.dataFlowInfo);
130            }
131    
132            if (callerName != null) {
133                context.labelResolver.exitLabeledElement(expression);
134            }
135    
136            return DataFlowUtils.checkType(resultType, expression, context, context.dataFlowInfo);
137        }
138    
139        @Nullable
140        private static Name getCallerName(@NotNull JetFunctionLiteralExpression expression) {
141            JetCallExpression callExpression = getContainingCallExpression(expression);
142            if (callExpression == null) return null;
143    
144            JetExpression calleeExpression = callExpression.getCalleeExpression();
145            if (calleeExpression instanceof JetSimpleNameExpression) {
146                JetSimpleNameExpression nameExpression = (JetSimpleNameExpression) calleeExpression;
147                return nameExpression.getReferencedNameAsName();
148            }
149    
150            return null;
151        }
152    
153        @Nullable
154        private static JetCallExpression getContainingCallExpression(JetFunctionLiteralExpression expression) {
155            PsiElement parent = expression.getParent();
156            if (parent instanceof JetCallExpression) {
157                // f {}
158                return (JetCallExpression) parent;
159            }
160    
161            if (parent instanceof JetValueArgument) {
162                // f ({}) or f(p = {})
163                JetValueArgument argument = (JetValueArgument) parent;
164                PsiElement argList = argument.getParent();
165                if (argList == null) return null;
166                PsiElement call = argList.getParent();
167                if (call instanceof JetCallExpression) {
168                    return (JetCallExpression) call;
169                }
170            }
171            return null;
172        }
173    
174        @NotNull
175        private AnonymousFunctionDescriptor createFunctionDescriptor(
176                @NotNull JetFunctionLiteralExpression expression,
177                @NotNull ExpressionTypingContext context,
178                boolean functionTypeExpected
179        ) {
180            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
181            JetTypeReference receiverTypeRef = functionLiteral.getReceiverTypeRef();
182            AnonymousFunctionDescriptor functionDescriptor = new AnonymousFunctionDescriptor(
183                    context.scope.getContainingDeclaration(), Annotations.EMPTY, CallableMemberDescriptor.Kind.DECLARATION);
184    
185            List<ValueParameterDescriptor> valueParameterDescriptors = createValueParameterDescriptors(context, functionLiteral,
186                                                                                                       functionDescriptor, functionTypeExpected);
187    
188            JetType effectiveReceiverType;
189            if (receiverTypeRef == null) {
190                if (functionTypeExpected) {
191                    effectiveReceiverType = KotlinBuiltIns.getInstance().getReceiverType(context.expectedType);
192                }
193                else {
194                    effectiveReceiverType = null;
195                }
196            }
197            else {
198                effectiveReceiverType = components.expressionTypingServices.getTypeResolver().resolveType(context.scope, receiverTypeRef, context.trace, true);
199            }
200            functionDescriptor.initialize(effectiveReceiverType,
201                                          ReceiverParameterDescriptor.NO_RECEIVER_PARAMETER,
202                                          Collections.<TypeParameterDescriptorImpl>emptyList(),
203                                          valueParameterDescriptors,
204                                          /*unsubstitutedReturnType = */ null,
205                                          Modality.FINAL,
206                                          Visibilities.LOCAL
207            );
208            BindingContextUtils.recordFunctionDeclarationToDescriptor(context.trace, functionLiteral, functionDescriptor);
209            return functionDescriptor;
210        }
211    
212        @NotNull
213        private List<ValueParameterDescriptor> createValueParameterDescriptors(
214                @NotNull ExpressionTypingContext context,
215                @NotNull JetFunctionLiteral functionLiteral,
216                @NotNull FunctionDescriptorImpl functionDescriptor,
217                boolean functionTypeExpected
218        ) {
219            List<ValueParameterDescriptor> valueParameterDescriptors = Lists.newArrayList();
220            List<JetParameter> declaredValueParameters = functionLiteral.getValueParameters();
221    
222            List<ValueParameterDescriptor> expectedValueParameters =  (functionTypeExpected)
223                                                              ? KotlinBuiltIns.getInstance().getValueParameters(functionDescriptor, context.expectedType)
224                                                              : null;
225    
226            JetParameterList valueParameterList = functionLiteral.getValueParameterList();
227            boolean hasDeclaredValueParameters = valueParameterList != null;
228            if (functionTypeExpected && !hasDeclaredValueParameters && expectedValueParameters.size() == 1) {
229                ValueParameterDescriptor valueParameterDescriptor = expectedValueParameters.get(0);
230                ValueParameterDescriptor it = new ValueParameterDescriptorImpl(
231                        functionDescriptor, 0, Annotations.EMPTY, Name.identifier("it"),
232                        valueParameterDescriptor.getType(), valueParameterDescriptor.hasDefaultValue(), valueParameterDescriptor.getVarargElementType()
233                );
234                valueParameterDescriptors.add(it);
235                context.trace.record(AUTO_CREATED_IT, it);
236            }
237            else {
238                if (expectedValueParameters != null && declaredValueParameters.size() != expectedValueParameters.size()) {
239                    List<JetType> expectedParameterTypes = DescriptorUtils.getValueParametersTypes(expectedValueParameters);
240                    context.trace.report(EXPECTED_PARAMETERS_NUMBER_MISMATCH.on(functionLiteral, expectedParameterTypes.size(), expectedParameterTypes));
241                }
242                for (int i = 0; i < declaredValueParameters.size(); i++) {
243                    ValueParameterDescriptor valueParameterDescriptor = createValueParameterDescriptor(
244                            context, functionDescriptor, declaredValueParameters, expectedValueParameters, i);
245                    valueParameterDescriptors.add(valueParameterDescriptor);
246                }
247            }
248            return valueParameterDescriptors;
249        }
250    
251        @NotNull
252        private ValueParameterDescriptor createValueParameterDescriptor(
253                @NotNull ExpressionTypingContext context,
254                @NotNull FunctionDescriptorImpl functionDescriptor,
255                @NotNull List<JetParameter> declaredValueParameters,
256                @Nullable List<ValueParameterDescriptor> expectedValueParameters,
257                int index
258        ) {
259            JetParameter declaredParameter = declaredValueParameters.get(index);
260            JetTypeReference typeReference = declaredParameter.getTypeReference();
261    
262            JetType expectedType;
263            if (expectedValueParameters != null && index < expectedValueParameters.size()) {
264                expectedType = expectedValueParameters.get(index).getType();
265            }
266            else {
267                expectedType = null;
268            }
269            JetType type;
270            if (typeReference != null) {
271                type = components.expressionTypingServices.getTypeResolver().resolveType(context.scope, typeReference, context.trace, true);
272                if (expectedType != null) {
273                    if (!JetTypeChecker.INSTANCE.isSubtypeOf(expectedType, type)) {
274                        context.trace.report(EXPECTED_PARAMETER_TYPE_MISMATCH.on(declaredParameter, expectedType));
275                    }
276                }
277            }
278            else {
279                if (expectedType == null || expectedType == DONT_CARE || expectedType == CANT_INFER_TYPE_PARAMETER) {
280                    context.trace.report(CANNOT_INFER_PARAMETER_TYPE.on(declaredParameter));
281                }
282                if (expectedType != null) {
283                    type = expectedType;
284                }
285                else {
286                    type = CANT_INFER_LAMBDA_PARAM_TYPE;
287                }
288            }
289            return components.expressionTypingServices.getDescriptorResolver().resolveValueParameterDescriptorWithAnnotationArguments(
290                    context.scope, functionDescriptor, declaredParameter, index, type, context.trace);
291        }
292    
293        @NotNull
294        private JetType computeReturnType(
295                @NotNull JetFunctionLiteralExpression expression,
296                @NotNull ExpressionTypingContext context,
297                @NotNull SimpleFunctionDescriptorImpl functionDescriptor,
298                boolean functionTypeExpected
299        ) {
300            JetType expectedReturnType = functionTypeExpected ? KotlinBuiltIns.getInstance().getReturnTypeFromFunctionType(context.expectedType) : null;
301            JetType returnType = computeUnsafeReturnType(expression, context, functionDescriptor, expectedReturnType);
302    
303            if (!expression.getFunctionLiteral().hasDeclaredReturnType() && functionTypeExpected) {
304                if (KotlinBuiltIns.getInstance().isUnit(expectedReturnType)) {
305                    return KotlinBuiltIns.getInstance().getUnitType();
306                }
307            }
308            return returnType == null ? CANT_INFER_LAMBDA_PARAM_TYPE : returnType;
309        }
310    
311        @Nullable
312        private JetType computeUnsafeReturnType(
313                @NotNull JetFunctionLiteralExpression expression,
314                @NotNull ExpressionTypingContext context,
315                @NotNull SimpleFunctionDescriptorImpl functionDescriptor,
316                @Nullable JetType expectedReturnType
317        ) {
318            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
319            JetBlockExpression bodyExpression = functionLiteral.getBodyExpression();
320            assert bodyExpression != null;
321    
322            JetScope functionInnerScope = FunctionDescriptorUtil.getFunctionInnerScope(context.scope, functionDescriptor, context.trace);
323            JetTypeReference returnTypeRef = functionLiteral.getReturnTypeRef();
324            JetType declaredReturnType = null;
325            if (returnTypeRef != null) {
326                declaredReturnType = components.expressionTypingServices.getTypeResolver().resolveType(context.scope, returnTypeRef, context.trace, true);
327                // This is needed for ControlStructureTypingVisitor#visitReturnExpression() to properly type-check returned expressions
328                functionDescriptor.setReturnType(declaredReturnType);
329                if (expectedReturnType != null) {
330                    if (!JetTypeChecker.INSTANCE.isSubtypeOf(declaredReturnType, expectedReturnType)) {
331                        context.trace.report(EXPECTED_RETURN_TYPE_MISMATCH.on(returnTypeRef, expectedReturnType));
332                    }
333                }
334            }
335    
336            // Type-check the body
337            ExpressionTypingContext newContext = context.replaceScope(functionInnerScope)
338                    .replaceExpectedType(declaredReturnType != null
339                                         ? declaredReturnType
340                                         : (expectedReturnType != null ? expectedReturnType : NO_EXPECTED_TYPE));
341    
342            JetType typeOfBodyExpression = components.expressionTypingServices.getBlockReturnedType(bodyExpression, COERCION_TO_UNIT, newContext).getType();
343    
344            List<JetType> returnedExpressionTypes = Lists.newArrayList(getTypesOfLocallyReturnedExpressions(
345                    functionLiteral, context.trace, collectReturns(bodyExpression)));
346            ContainerUtil.addIfNotNull(returnedExpressionTypes, typeOfBodyExpression);
347    
348            if (declaredReturnType != null) return declaredReturnType;
349            if (returnedExpressionTypes.isEmpty()) return null;
350            return CommonSupertypes.commonSupertype(returnedExpressionTypes);
351        }
352    
353        private static List<JetType> getTypesOfLocallyReturnedExpressions(
354                final JetFunctionLiteral functionLiteral,
355                final BindingTrace trace,
356                Collection<JetReturnExpression> returnExpressions
357        ) {
358            return ContainerUtil.mapNotNull(returnExpressions, new Function<JetReturnExpression, JetType>() {
359                @Override
360                public JetType fun(JetReturnExpression returnExpression) {
361                    JetSimpleNameExpression label = returnExpression.getTargetLabel();
362                    if (label == null) {
363                        // No label => non-local return
364                        return null;
365                    }
366    
367                    PsiElement labelTarget = trace.get(BindingContext.LABEL_TARGET, label);
368                    if (labelTarget != functionLiteral) {
369                        // Either a local return of inner lambda/function or a non-local return
370                        return null;
371                    }
372    
373                    JetExpression returnedExpression = returnExpression.getReturnedExpression();
374                    if (returnedExpression == null) {
375                        return KotlinBuiltIns.getInstance().getUnitType();
376                    }
377                    JetType returnedType = trace.get(EXPRESSION_TYPE, returnedExpression);
378                    assert returnedType != null : "No type for returned expression: " + returnedExpression + ",\n" +
379                                                  "the type should have been computed by getBlockReturnedType() above";
380                    return returnedType;
381                }
382            });
383        }
384    
385        public static Collection<JetReturnExpression> collectReturns(@NotNull JetExpression expression) {
386            Collection<JetReturnExpression> result = Lists.newArrayList();
387            expression.accept(
388                    new JetTreeVisitor<Collection<JetReturnExpression>>() {
389                        @Override
390                        public Void visitReturnExpression(
391                                @NotNull JetReturnExpression expression, Collection<JetReturnExpression> data
392                        ) {
393                            data.add(expression);
394                            return null;
395                        }
396                    },
397                    result
398            );
399            return result;
400        }
401    }