001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.lang.types.expressions;
018    
019    import com.google.common.collect.Lists;
020    import com.intellij.psi.PsiElement;
021    import com.intellij.util.Function;
022    import com.intellij.util.containers.ContainerUtil;
023    import jet.Function0;
024    import org.jetbrains.annotations.NotNull;
025    import org.jetbrains.annotations.Nullable;
026    import org.jetbrains.jet.lang.descriptors.*;
027    import org.jetbrains.jet.lang.descriptors.annotations.AnnotationDescriptor;
028    import org.jetbrains.jet.lang.descriptors.impl.*;
029    import org.jetbrains.jet.lang.psi.*;
030    import org.jetbrains.jet.lang.resolve.*;
031    import org.jetbrains.jet.lang.resolve.name.LabelName;
032    import org.jetbrains.jet.lang.resolve.name.Name;
033    import org.jetbrains.jet.lang.resolve.scopes.JetScope;
034    import org.jetbrains.jet.lang.types.*;
035    import org.jetbrains.jet.lang.types.checker.JetTypeChecker;
036    import org.jetbrains.jet.lang.types.lang.KotlinBuiltIns;
037    import org.jetbrains.jet.util.slicedmap.WritableSlice;
038    
039    import java.util.Collection;
040    import java.util.Collections;
041    import java.util.List;
042    
043    import static org.jetbrains.jet.lang.diagnostics.Errors.*;
044    import static org.jetbrains.jet.lang.resolve.BindingContext.*;
045    import static org.jetbrains.jet.lang.resolve.calls.context.ContextDependency.INDEPENDENT;
046    import static org.jetbrains.jet.lang.types.TypeUtils.*;
047    import static org.jetbrains.jet.lang.types.expressions.CoercionStrategy.COERCION_TO_UNIT;
048    import static org.jetbrains.jet.util.StorageUtil.createRecursionIntolerantLazyValueWithDefault;
049    
050    public class ClosureExpressionsTypingVisitor extends ExpressionTypingVisitor {
051        protected ClosureExpressionsTypingVisitor(@NotNull ExpressionTypingInternals facade) {
052            super(facade);
053        }
054    
055        @Override
056        public JetTypeInfo visitObjectLiteralExpression(@NotNull final JetObjectLiteralExpression expression, final ExpressionTypingContext context) {
057            DelegatingBindingTrace delegatingBindingTrace = context.trace.get(TRACE_DELTAS_CACHE, expression.getObjectDeclaration());
058            if (delegatingBindingTrace != null) {
059                delegatingBindingTrace.addAllMyDataTo(context.trace);
060                JetType type = context.trace.get(EXPRESSION_TYPE, expression);
061                return DataFlowUtils.checkType(type, expression, context, context.dataFlowInfo);
062            }
063            final JetType[] result = new JetType[1];
064            final TemporaryBindingTrace temporaryTrace = TemporaryBindingTrace.create(context.trace, "trace to resolve object literal expression", expression);
065            ObservableBindingTrace.RecordHandler<PsiElement, ClassDescriptor> handler = new ObservableBindingTrace.RecordHandler<PsiElement, ClassDescriptor>() {
066    
067                @Override
068                public void handleRecord(WritableSlice<PsiElement, ClassDescriptor> slice, PsiElement declaration, final ClassDescriptor descriptor) {
069                    if (slice == CLASS && declaration == expression.getObjectDeclaration()) {
070                        JetType defaultType = DeferredType.create(context.trace, createRecursionIntolerantLazyValueWithDefault(
071                                ErrorUtils.createErrorType("Recursive dependency"),
072                                new Function0<JetType>() {
073                                    @Override
074                                    public JetType invoke() {
075                                        return descriptor.getDefaultType();
076                                    }
077                                }));
078                        result[0] = defaultType;
079                        if (!context.trace.get(PROCESSED, expression)) {
080                            temporaryTrace.record(EXPRESSION_TYPE, expression, defaultType);
081                            temporaryTrace.record(PROCESSED, expression);
082                        }
083                    }
084                }
085            };
086            ObservableBindingTrace traceAdapter = new ObservableBindingTrace(temporaryTrace);
087            traceAdapter.addHandler(CLASS, handler);
088            TopDownAnalyzer.processClassOrObject(context.replaceBindingTrace(traceAdapter).replaceContextDependency(INDEPENDENT),
089                                                 context.scope.getContainingDeclaration(),
090                                                 expression.getObjectDeclaration());
091    
092            DelegatingBindingTrace cloneDelta = new DelegatingBindingTrace(
093                    new BindingTraceContext().getBindingContext(), "cached delta trace for object literal expression resolve", expression);
094            temporaryTrace.addAllMyDataTo(cloneDelta);
095            context.trace.record(TRACE_DELTAS_CACHE, expression.getObjectDeclaration(), cloneDelta);
096            temporaryTrace.commit();
097            return DataFlowUtils.checkType(result[0], expression, context, context.dataFlowInfo);
098        }
099    
100        @Override
101        public JetTypeInfo visitFunctionLiteralExpression(@NotNull JetFunctionLiteralExpression expression, ExpressionTypingContext context) {
102            JetBlockExpression bodyExpression = expression.getFunctionLiteral().getBodyExpression();
103            if (bodyExpression == null) return null;
104    
105            Name callerName = getCallerName(expression);
106            if (callerName != null) {
107                context.labelResolver.enterLabeledElement(new LabelName(callerName.asString()), expression);
108            }
109    
110            JetType expectedType = context.expectedType;
111            boolean functionTypeExpected = !noExpectedType(expectedType) && KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(
112                    expectedType);
113    
114            AnonymousFunctionDescriptor functionDescriptor = createFunctionDescriptor(expression, context, functionTypeExpected);
115            JetType safeReturnType = computeReturnType(expression, context, functionDescriptor, functionTypeExpected);
116            functionDescriptor.setReturnType(safeReturnType);
117    
118            JetType receiver = DescriptorUtils.getReceiverParameterType(functionDescriptor.getReceiverParameter());
119            List<JetType> valueParametersTypes = DescriptorUtils.getValueParametersTypes(functionDescriptor.getValueParameters());
120            JetType resultType = KotlinBuiltIns.getInstance().getFunctionType(
121                    Collections.<AnnotationDescriptor>emptyList(), receiver, valueParametersTypes, safeReturnType);
122            if (!noExpectedType(expectedType) && KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(expectedType)) {
123                // all checks were done before
124                return JetTypeInfo.create(resultType, context.dataFlowInfo);
125            }
126    
127            if (callerName != null) {
128                context.labelResolver.exitLabeledElement(expression);
129            }
130    
131            return DataFlowUtils.checkType(resultType, expression, context, context.dataFlowInfo);
132        }
133    
134        @Nullable
135        private static Name getCallerName(@NotNull JetFunctionLiteralExpression expression) {
136            JetCallExpression callExpression = getContainingCallExpression(expression);
137            if (callExpression == null) return null;
138    
139            JetExpression calleeExpression = callExpression.getCalleeExpression();
140            if (calleeExpression instanceof JetSimpleNameExpression) {
141                JetSimpleNameExpression nameExpression = (JetSimpleNameExpression) calleeExpression;
142                return nameExpression.getReferencedNameAsName();
143            }
144    
145            return null;
146        }
147    
148        @Nullable
149        private static JetCallExpression getContainingCallExpression(JetFunctionLiteralExpression expression) {
150            PsiElement parent = expression.getParent();
151            if (parent instanceof JetCallExpression) {
152                // f {}
153                return (JetCallExpression) parent;
154            }
155    
156            if (parent instanceof JetValueArgument) {
157                // f ({}) or f(p = {})
158                JetValueArgument argument = (JetValueArgument) parent;
159                PsiElement argList = argument.getParent();
160                if (argList == null) return null;
161                PsiElement call = argList.getParent();
162                if (call instanceof JetCallExpression) {
163                    return (JetCallExpression) call;
164                }
165            }
166            return null;
167        }
168    
169        @NotNull
170        private static AnonymousFunctionDescriptor createFunctionDescriptor(
171                @NotNull JetFunctionLiteralExpression expression,
172                @NotNull ExpressionTypingContext context,
173                boolean functionTypeExpected
174        ) {
175            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
176            JetTypeReference receiverTypeRef = functionLiteral.getReceiverTypeRef();
177            AnonymousFunctionDescriptor functionDescriptor = new AnonymousFunctionDescriptor(
178                    context.scope.getContainingDeclaration(), Collections.<AnnotationDescriptor>emptyList(), CallableMemberDescriptor.Kind.DECLARATION);
179    
180            List<ValueParameterDescriptor> valueParameterDescriptors = createValueParameterDescriptors(context, functionLiteral,
181                                                                                                       functionDescriptor, functionTypeExpected);
182    
183            JetType effectiveReceiverType;
184            if (receiverTypeRef == null) {
185                if (functionTypeExpected) {
186                    effectiveReceiverType = KotlinBuiltIns.getInstance().getReceiverType(context.expectedType);
187                }
188                else {
189                    effectiveReceiverType = null;
190                }
191            }
192            else {
193                effectiveReceiverType = context.expressionTypingServices.getTypeResolver().resolveType(context.scope, receiverTypeRef, context.trace, true);
194            }
195            functionDescriptor.initialize(effectiveReceiverType,
196                                          ReceiverParameterDescriptor.NO_RECEIVER_PARAMETER,
197                                          Collections.<TypeParameterDescriptorImpl>emptyList(),
198                                          valueParameterDescriptors,
199                                          /*unsubstitutedReturnType = */ null,
200                                          Modality.FINAL,
201                                          Visibilities.LOCAL
202            );
203            BindingContextUtils.recordFunctionDeclarationToDescriptor(context.trace, functionLiteral, functionDescriptor);
204            return functionDescriptor;
205        }
206    
207        @NotNull
208        private static List<ValueParameterDescriptor> createValueParameterDescriptors(
209                @NotNull ExpressionTypingContext context,
210                @NotNull JetFunctionLiteral functionLiteral,
211                @NotNull FunctionDescriptorImpl functionDescriptor,
212                boolean functionTypeExpected
213        ) {
214            List<ValueParameterDescriptor> valueParameterDescriptors = Lists.newArrayList();
215            List<JetParameter> declaredValueParameters = functionLiteral.getValueParameters();
216    
217            List<ValueParameterDescriptor> expectedValueParameters =  (functionTypeExpected)
218                                                              ? KotlinBuiltIns.getInstance().getValueParameters(functionDescriptor, context.expectedType)
219                                                              : null;
220    
221            JetParameterList valueParameterList = functionLiteral.getValueParameterList();
222            boolean hasDeclaredValueParameters = valueParameterList != null;
223            if (functionTypeExpected && !hasDeclaredValueParameters && expectedValueParameters.size() == 1) {
224                ValueParameterDescriptor valueParameterDescriptor = expectedValueParameters.get(0);
225                ValueParameterDescriptor it = new ValueParameterDescriptorImpl(
226                        functionDescriptor, 0, Collections.<AnnotationDescriptor>emptyList(), Name.identifier("it"),
227                        valueParameterDescriptor.getType(), valueParameterDescriptor.hasDefaultValue(), valueParameterDescriptor.getVarargElementType()
228                );
229                valueParameterDescriptors.add(it);
230                context.trace.record(AUTO_CREATED_IT, it);
231            }
232            else {
233                if (expectedValueParameters != null && declaredValueParameters.size() != expectedValueParameters.size()) {
234                    List<JetType> expectedParameterTypes = DescriptorUtils.getValueParametersTypes(expectedValueParameters);
235                    context.trace.report(EXPECTED_PARAMETERS_NUMBER_MISMATCH.on(functionLiteral, expectedParameterTypes.size(), expectedParameterTypes));
236                }
237                for (int i = 0; i < declaredValueParameters.size(); i++) {
238                    ValueParameterDescriptor valueParameterDescriptor = createValueParameterDescriptor(
239                            context, functionDescriptor, declaredValueParameters, expectedValueParameters, i);
240                    valueParameterDescriptors.add(valueParameterDescriptor);
241                }
242            }
243            return valueParameterDescriptors;
244        }
245    
246        @NotNull
247        private static ValueParameterDescriptor createValueParameterDescriptor(
248                @NotNull ExpressionTypingContext context,
249                @NotNull FunctionDescriptorImpl functionDescriptor,
250                @NotNull List<JetParameter> declaredValueParameters,
251                @Nullable List<ValueParameterDescriptor> expectedValueParameters,
252                int index
253        ) {
254            JetParameter declaredParameter = declaredValueParameters.get(index);
255            JetTypeReference typeReference = declaredParameter.getTypeReference();
256    
257            JetType expectedType;
258            if (expectedValueParameters != null && index < expectedValueParameters.size()) {
259                expectedType = expectedValueParameters.get(index).getType();
260            }
261            else {
262                expectedType = null;
263            }
264            JetType type;
265            if (typeReference != null) {
266                type = context.expressionTypingServices.getTypeResolver().resolveType(context.scope, typeReference, context.trace, true);
267                if (expectedType != null) {
268                    if (!JetTypeChecker.INSTANCE.isSubtypeOf(expectedType, type)) {
269                        context.trace.report(EXPECTED_PARAMETER_TYPE_MISMATCH.on(declaredParameter, expectedType));
270                    }
271                }
272            }
273            else {
274                if (expectedType == null || expectedType == DONT_CARE || expectedType == CANT_INFER_TYPE_PARAMETER) {
275                    context.trace.report(CANNOT_INFER_PARAMETER_TYPE.on(declaredParameter));
276                }
277                if (expectedType != null) {
278                    type = expectedType;
279                }
280                else {
281                    type = CANT_INFER_LAMBDA_PARAM_TYPE;
282                }
283            }
284            return context.expressionTypingServices.getDescriptorResolver().resolveValueParameterDescriptorWithAnnotationArguments(
285                    context.scope, functionDescriptor, declaredParameter, index, type, context.trace);
286        }
287    
288        @NotNull
289        private static JetType computeReturnType(
290                @NotNull JetFunctionLiteralExpression expression,
291                @NotNull ExpressionTypingContext context,
292                @NotNull SimpleFunctionDescriptorImpl functionDescriptor,
293                boolean functionTypeExpected
294        ) {
295            JetType expectedReturnType = functionTypeExpected ? KotlinBuiltIns.getInstance().getReturnTypeFromFunctionType(context.expectedType) : null;
296            JetType returnType = computeUnsafeReturnType(expression, context, functionDescriptor, expectedReturnType);
297    
298            if (!expression.getFunctionLiteral().hasDeclaredReturnType() && functionTypeExpected) {
299                if (KotlinBuiltIns.getInstance().isUnit(expectedReturnType)) {
300                    return KotlinBuiltIns.getInstance().getUnitType();
301                }
302            }
303            return returnType == null ? CANT_INFER_LAMBDA_PARAM_TYPE : returnType;
304        }
305    
306        @Nullable
307        private static JetType computeUnsafeReturnType(
308                @NotNull JetFunctionLiteralExpression expression,
309                @NotNull ExpressionTypingContext context,
310                @NotNull SimpleFunctionDescriptorImpl functionDescriptor,
311                @Nullable JetType expectedReturnType
312        ) {
313            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
314            JetBlockExpression bodyExpression = functionLiteral.getBodyExpression();
315            assert bodyExpression != null;
316    
317            JetScope functionInnerScope = FunctionDescriptorUtil.getFunctionInnerScope(context.scope, functionDescriptor, context.trace);
318            JetTypeReference returnTypeRef = functionLiteral.getReturnTypeRef();
319            JetType declaredReturnType = null;
320            if (returnTypeRef != null) {
321                declaredReturnType = context.expressionTypingServices.getTypeResolver().resolveType(context.scope, returnTypeRef, context.trace, true);
322                // This is needed for ControlStructureTypingVisitor#visitReturnExpression() to properly type-check returned expressions
323                functionDescriptor.setReturnType(declaredReturnType);
324                if (expectedReturnType != null) {
325                    if (!JetTypeChecker.INSTANCE.isSubtypeOf(declaredReturnType, expectedReturnType)) {
326                        context.trace.report(EXPECTED_RETURN_TYPE_MISMATCH.on(returnTypeRef, expectedReturnType));
327                    }
328                }
329            }
330    
331            // Type-check the body
332            ExpressionTypingContext newContext = context.replaceScope(functionInnerScope)
333                    .replaceExpectedType(declaredReturnType != null
334                                         ? declaredReturnType
335                                         : (expectedReturnType != null ? expectedReturnType : NO_EXPECTED_TYPE));
336    
337            JetType typeOfBodyExpression = context.expressionTypingServices.getBlockReturnedType(bodyExpression, COERCION_TO_UNIT, newContext).getType();
338    
339            List<JetType> returnedExpressionTypes = Lists.newArrayList(getTypesOfLocallyReturnedExpressions(
340                    functionLiteral, context.trace, collectReturns(bodyExpression)));
341            ContainerUtil.addIfNotNull(returnedExpressionTypes, typeOfBodyExpression);
342    
343            if (declaredReturnType != null) return declaredReturnType;
344            if (returnedExpressionTypes.isEmpty()) return null;
345            return CommonSupertypes.commonSupertype(returnedExpressionTypes);
346        }
347    
348        private static List<JetType> getTypesOfLocallyReturnedExpressions(
349                final JetFunctionLiteral functionLiteral,
350                final BindingTrace trace,
351                Collection<JetReturnExpression> returnExpressions
352        ) {
353            return ContainerUtil.mapNotNull(returnExpressions, new Function<JetReturnExpression, JetType>() {
354                @Override
355                public JetType fun(JetReturnExpression returnExpression) {
356                    JetSimpleNameExpression label = returnExpression.getTargetLabel();
357                    if (label == null) {
358                        // No label => non-local return
359                        return null;
360                    }
361    
362                    PsiElement labelTarget = trace.get(BindingContext.LABEL_TARGET, label);
363                    if (labelTarget != functionLiteral) {
364                        // Either a local return of inner lambda/function or a non-local return
365                        return null;
366                    }
367    
368                    JetExpression returnedExpression = returnExpression.getReturnedExpression();
369                    if (returnedExpression == null) {
370                        return KotlinBuiltIns.getInstance().getUnitType();
371                    }
372                    JetType returnedType = trace.get(EXPRESSION_TYPE, returnedExpression);
373                    assert returnedType != null : "No type for returned expression: " + returnedExpression + ",\n" +
374                                                  "the type should have been computed by getBlockReturnedType() above";
375                    return returnedType;
376                }
377            });
378        }
379    
380        public static Collection<JetReturnExpression> collectReturns(@NotNull JetExpression expression) {
381            Collection<JetReturnExpression> result = Lists.newArrayList();
382            expression.accept(
383                    new JetTreeVisitor<Collection<JetReturnExpression>>() {
384                        @Override
385                        public Void visitReturnExpression(
386                                @NotNull JetReturnExpression expression, Collection<JetReturnExpression> data
387                        ) {
388                            data.add(expression);
389                            return null;
390                        }
391                    },
392                    result
393            );
394            return result;
395        }
396    }