001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.lang.types.expressions;
018    
019    import com.google.common.collect.Lists;
020    import com.intellij.psi.PsiElement;
021    import com.intellij.util.Function;
022    import com.intellij.util.containers.ContainerUtil;
023    import kotlin.Function0;
024    import org.jetbrains.annotations.NotNull;
025    import org.jetbrains.annotations.Nullable;
026    import org.jetbrains.jet.lang.descriptors.*;
027    import org.jetbrains.jet.lang.descriptors.annotations.Annotations;
028    import org.jetbrains.jet.lang.descriptors.impl.*;
029    import org.jetbrains.jet.lang.psi.*;
030    import org.jetbrains.jet.lang.resolve.*;
031    import org.jetbrains.jet.lang.resolve.name.Name;
032    import org.jetbrains.jet.lang.resolve.scopes.JetScope;
033    import org.jetbrains.jet.lang.types.*;
034    import org.jetbrains.jet.lang.types.checker.JetTypeChecker;
035    import org.jetbrains.jet.lang.types.lang.KotlinBuiltIns;
036    import org.jetbrains.jet.util.slicedmap.WritableSlice;
037    
038    import java.util.Collection;
039    import java.util.Collections;
040    import java.util.List;
041    
042    import static org.jetbrains.jet.lang.diagnostics.Errors.*;
043    import static org.jetbrains.jet.lang.resolve.BindingContext.*;
044    import static org.jetbrains.jet.lang.resolve.calls.context.ContextDependency.INDEPENDENT;
045    import static org.jetbrains.jet.lang.resolve.source.SourcePackage.toSourceElement;
046    import static org.jetbrains.jet.lang.types.TypeUtils.*;
047    import static org.jetbrains.jet.lang.types.expressions.CoercionStrategy.COERCION_TO_UNIT;
048    
049    public class ClosureExpressionsTypingVisitor extends ExpressionTypingVisitor {
050    
051        protected ClosureExpressionsTypingVisitor(@NotNull ExpressionTypingInternals facade) {
052            super(facade);
053        }
054    
055        @Override
056        public JetTypeInfo visitObjectLiteralExpression(@NotNull final JetObjectLiteralExpression expression, final ExpressionTypingContext context) {
057            DelegatingBindingTrace delegatingBindingTrace = context.trace.get(TRACE_DELTAS_CACHE, expression.getObjectDeclaration());
058            if (delegatingBindingTrace != null) {
059                delegatingBindingTrace.addAllMyDataTo(context.trace);
060                JetType type = context.trace.get(EXPRESSION_TYPE, expression);
061                return DataFlowUtils.checkType(type, expression, context, context.dataFlowInfo);
062            }
063            final JetType[] result = new JetType[1];
064            final TemporaryBindingTrace temporaryTrace = TemporaryBindingTrace.create(context.trace, "trace to resolve object literal expression", expression);
065            ObservableBindingTrace.RecordHandler<PsiElement, ClassDescriptor> handler = new ObservableBindingTrace.RecordHandler<PsiElement, ClassDescriptor>() {
066    
067                @Override
068                public void handleRecord(WritableSlice<PsiElement, ClassDescriptor> slice, PsiElement declaration, final ClassDescriptor descriptor) {
069                    if (slice == CLASS && declaration == expression.getObjectDeclaration()) {
070                        JetType defaultType = DeferredType.createRecursionIntolerant(components.globalContext.getStorageManager(),
071                                                                                     context.trace,
072                                                                                     new Function0<JetType>() {
073                                                                                         @Override
074                                                                                         public JetType invoke() {
075                                                                                             return descriptor.getDefaultType();
076                                                                                         }
077                                                                                     });
078                        result[0] = defaultType;
079                        if (!context.trace.get(PROCESSED, expression)) {
080                            temporaryTrace.record(EXPRESSION_TYPE, expression, defaultType);
081                            temporaryTrace.record(PROCESSED, expression);
082                        }
083                    }
084                }
085            };
086            ObservableBindingTrace traceAdapter = new ObservableBindingTrace(temporaryTrace);
087            traceAdapter.addHandler(CLASS, handler);
088            TopDownAnalyzer.processClassOrObject(components.globalContext,
089                                                 null, // don't need to add classifier of object literal to any scope
090                                                 context.replaceBindingTrace(traceAdapter).replaceContextDependency(INDEPENDENT),
091                                                 context.scope.getContainingDeclaration(),
092                                                 expression.getObjectDeclaration());
093    
094            DelegatingBindingTrace cloneDelta = new DelegatingBindingTrace(
095                    new BindingTraceContext().getBindingContext(), "cached delta trace for object literal expression resolve", expression);
096            temporaryTrace.addAllMyDataTo(cloneDelta);
097            context.trace.record(TRACE_DELTAS_CACHE, expression.getObjectDeclaration(), cloneDelta);
098            temporaryTrace.commit();
099            return DataFlowUtils.checkType(result[0], expression, context, context.dataFlowInfo);
100        }
101    
102        @Override
103        public JetTypeInfo visitFunctionLiteralExpression(@NotNull JetFunctionLiteralExpression expression, ExpressionTypingContext context) {
104            if (!expression.getFunctionLiteral().hasBody()) return null;
105    
106            JetType expectedType = context.expectedType;
107            boolean functionTypeExpected = !noExpectedType(expectedType) && KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(
108                    expectedType);
109    
110            AnonymousFunctionDescriptor functionDescriptor = createFunctionDescriptor(expression, context, functionTypeExpected);
111            JetType safeReturnType = computeReturnType(expression, context, functionDescriptor, functionTypeExpected);
112            functionDescriptor.setReturnType(safeReturnType);
113    
114            JetType receiver = DescriptorUtils.getReceiverParameterType(functionDescriptor.getReceiverParameter());
115            List<JetType> valueParametersTypes = ExpressionTypingUtils.getValueParametersTypes(functionDescriptor.getValueParameters());
116            JetType resultType = KotlinBuiltIns.getInstance().getFunctionType(
117                    Annotations.EMPTY, receiver, valueParametersTypes, safeReturnType);
118            if (!noExpectedType(expectedType) && KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(expectedType)) {
119                // all checks were done before
120                return JetTypeInfo.create(resultType, context.dataFlowInfo);
121            }
122    
123            return DataFlowUtils.checkType(resultType, expression, context, context.dataFlowInfo);
124        }
125    
126        @NotNull
127        private AnonymousFunctionDescriptor createFunctionDescriptor(
128                @NotNull JetFunctionLiteralExpression expression,
129                @NotNull ExpressionTypingContext context,
130                boolean functionTypeExpected
131        ) {
132            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
133            JetTypeReference receiverTypeRef = functionLiteral.getReceiverTypeRef();
134            AnonymousFunctionDescriptor functionDescriptor = new AnonymousFunctionDescriptor(
135                    context.scope.getContainingDeclaration(), Annotations.EMPTY, CallableMemberDescriptor.Kind.DECLARATION,
136                    toSourceElement(functionLiteral)
137            );
138    
139            List<ValueParameterDescriptor> valueParameterDescriptors = createValueParameterDescriptors(context, functionLiteral,
140                                                                                                       functionDescriptor, functionTypeExpected);
141    
142            JetType effectiveReceiverType;
143            if (receiverTypeRef == null) {
144                if (functionTypeExpected) {
145                    effectiveReceiverType = KotlinBuiltIns.getInstance().getReceiverType(context.expectedType);
146                }
147                else {
148                    effectiveReceiverType = null;
149                }
150            }
151            else {
152                effectiveReceiverType = components.expressionTypingServices.getTypeResolver().resolveType(context.scope, receiverTypeRef, context.trace, true);
153            }
154            functionDescriptor.initialize(effectiveReceiverType,
155                                          ReceiverParameterDescriptor.NO_RECEIVER_PARAMETER,
156                                          Collections.<TypeParameterDescriptorImpl>emptyList(),
157                                          valueParameterDescriptors,
158                                          /*unsubstitutedReturnType = */ null,
159                                          Modality.FINAL,
160                                          Visibilities.LOCAL
161            );
162            BindingContextUtils.recordFunctionDeclarationToDescriptor(context.trace, functionLiteral, functionDescriptor);
163            return functionDescriptor;
164        }
165    
166        @NotNull
167        private List<ValueParameterDescriptor> createValueParameterDescriptors(
168                @NotNull ExpressionTypingContext context,
169                @NotNull JetFunctionLiteral functionLiteral,
170                @NotNull FunctionDescriptorImpl functionDescriptor,
171                boolean functionTypeExpected
172        ) {
173            List<ValueParameterDescriptor> valueParameterDescriptors = Lists.newArrayList();
174            List<JetParameter> declaredValueParameters = functionLiteral.getValueParameters();
175    
176            List<ValueParameterDescriptor> expectedValueParameters =  (functionTypeExpected)
177                                                              ? KotlinBuiltIns.getInstance().getValueParameters(functionDescriptor, context.expectedType)
178                                                              : null;
179    
180            JetParameterList valueParameterList = functionLiteral.getValueParameterList();
181            boolean hasDeclaredValueParameters = valueParameterList != null;
182            if (functionTypeExpected && !hasDeclaredValueParameters && expectedValueParameters.size() == 1) {
183                ValueParameterDescriptor valueParameterDescriptor = expectedValueParameters.get(0);
184                ValueParameterDescriptor it = new ValueParameterDescriptorImpl(
185                        functionDescriptor, null, 0, Annotations.EMPTY, Name.identifier("it"),
186                        valueParameterDescriptor.getType(), valueParameterDescriptor.hasDefaultValue(), valueParameterDescriptor.getVarargElementType(),
187                        SourceElement.NO_SOURCE
188                );
189                valueParameterDescriptors.add(it);
190                context.trace.record(AUTO_CREATED_IT, it);
191            }
192            else {
193                if (expectedValueParameters != null && declaredValueParameters.size() != expectedValueParameters.size()) {
194                    List<JetType> expectedParameterTypes = ExpressionTypingUtils.getValueParametersTypes(expectedValueParameters);
195                    context.trace.report(EXPECTED_PARAMETERS_NUMBER_MISMATCH.on(functionLiteral, expectedParameterTypes.size(), expectedParameterTypes));
196                }
197                for (int i = 0; i < declaredValueParameters.size(); i++) {
198                    ValueParameterDescriptor valueParameterDescriptor = createValueParameterDescriptor(
199                            context, functionDescriptor, declaredValueParameters, expectedValueParameters, i);
200                    valueParameterDescriptors.add(valueParameterDescriptor);
201                }
202            }
203            return valueParameterDescriptors;
204        }
205    
206        @NotNull
207        private ValueParameterDescriptor createValueParameterDescriptor(
208                @NotNull ExpressionTypingContext context,
209                @NotNull FunctionDescriptorImpl functionDescriptor,
210                @NotNull List<JetParameter> declaredValueParameters,
211                @Nullable List<ValueParameterDescriptor> expectedValueParameters,
212                int index
213        ) {
214            JetParameter declaredParameter = declaredValueParameters.get(index);
215            JetTypeReference typeReference = declaredParameter.getTypeReference();
216    
217            JetType expectedType;
218            if (expectedValueParameters != null && index < expectedValueParameters.size()) {
219                expectedType = expectedValueParameters.get(index).getType();
220            }
221            else {
222                expectedType = null;
223            }
224            JetType type;
225            if (typeReference != null) {
226                type = components.expressionTypingServices.getTypeResolver().resolveType(context.scope, typeReference, context.trace, true);
227                if (expectedType != null) {
228                    if (!JetTypeChecker.DEFAULT.isSubtypeOf(expectedType, type)) {
229                        context.trace.report(EXPECTED_PARAMETER_TYPE_MISMATCH.on(declaredParameter, expectedType));
230                    }
231                }
232            }
233            else {
234                if (expectedType == null || expectedType == DONT_CARE || ErrorUtils.isUninferredParameter(expectedType)) {
235                    context.trace.report(CANNOT_INFER_PARAMETER_TYPE.on(declaredParameter));
236                }
237                if (expectedType != null) {
238                    type = expectedType;
239                }
240                else {
241                    type = CANT_INFER_LAMBDA_PARAM_TYPE;
242                }
243            }
244            return components.expressionTypingServices.getDescriptorResolver().resolveValueParameterDescriptorWithAnnotationArguments(
245                    context.scope, functionDescriptor, declaredParameter, index, type, context.trace);
246        }
247    
248        @NotNull
249        private JetType computeReturnType(
250                @NotNull JetFunctionLiteralExpression expression,
251                @NotNull ExpressionTypingContext context,
252                @NotNull SimpleFunctionDescriptorImpl functionDescriptor,
253                boolean functionTypeExpected
254        ) {
255            JetType expectedReturnType = functionTypeExpected ? KotlinBuiltIns.getInstance().getReturnTypeFromFunctionType(context.expectedType) : null;
256            JetType returnType = computeUnsafeReturnType(expression, context, functionDescriptor, expectedReturnType);
257    
258            if (!expression.getFunctionLiteral().hasDeclaredReturnType() && functionTypeExpected) {
259                if (KotlinBuiltIns.getInstance().isUnit(expectedReturnType)) {
260                    return KotlinBuiltIns.getInstance().getUnitType();
261                }
262            }
263            return returnType == null ? CANT_INFER_LAMBDA_PARAM_TYPE : returnType;
264        }
265    
266        @Nullable
267        private JetType computeUnsafeReturnType(
268                @NotNull JetFunctionLiteralExpression expression,
269                @NotNull ExpressionTypingContext context,
270                @NotNull SimpleFunctionDescriptorImpl functionDescriptor,
271                @Nullable JetType expectedReturnType
272        ) {
273            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
274            JetBlockExpression bodyExpression = functionLiteral.getBodyExpression();
275            assert bodyExpression != null;
276    
277            JetScope functionInnerScope = FunctionDescriptorUtil.getFunctionInnerScope(context.scope, functionDescriptor, context.trace);
278            JetTypeReference returnTypeRef = functionLiteral.getReturnTypeRef();
279            JetType declaredReturnType = null;
280            if (returnTypeRef != null) {
281                declaredReturnType = components.expressionTypingServices.getTypeResolver().resolveType(context.scope, returnTypeRef, context.trace, true);
282                // This is needed for ControlStructureTypingVisitor#visitReturnExpression() to properly type-check returned expressions
283                functionDescriptor.setReturnType(declaredReturnType);
284                if (expectedReturnType != null) {
285                    if (!JetTypeChecker.DEFAULT.isSubtypeOf(declaredReturnType, expectedReturnType)) {
286                        context.trace.report(EXPECTED_RETURN_TYPE_MISMATCH.on(returnTypeRef, expectedReturnType));
287                    }
288                }
289            }
290    
291            // Type-check the body
292            ExpressionTypingContext newContext = context.replaceScope(functionInnerScope)
293                    .replaceExpectedType(declaredReturnType != null
294                                         ? declaredReturnType
295                                         : (expectedReturnType != null ? expectedReturnType : NO_EXPECTED_TYPE));
296    
297            JetType typeOfBodyExpression = components.expressionTypingServices.getBlockReturnedType(bodyExpression, COERCION_TO_UNIT, newContext).getType();
298    
299            List<JetType> returnedExpressionTypes = Lists.newArrayList(getTypesOfLocallyReturnedExpressions(
300                    functionLiteral, context.trace, collectReturns(bodyExpression)));
301            ContainerUtil.addIfNotNull(returnedExpressionTypes, typeOfBodyExpression);
302    
303            if (declaredReturnType != null) return declaredReturnType;
304            if (returnedExpressionTypes.isEmpty()) return null;
305            return CommonSupertypes.commonSupertype(returnedExpressionTypes);
306        }
307    
308        private static List<JetType> getTypesOfLocallyReturnedExpressions(
309                final JetFunctionLiteral functionLiteral,
310                final BindingTrace trace,
311                Collection<JetReturnExpression> returnExpressions
312        ) {
313            return ContainerUtil.mapNotNull(returnExpressions, new Function<JetReturnExpression, JetType>() {
314                @Override
315                public JetType fun(JetReturnExpression returnExpression) {
316                    JetSimpleNameExpression label = returnExpression.getTargetLabel();
317                    if (label == null) {
318                        // No label => non-local return
319                        return null;
320                    }
321    
322                    PsiElement labelTarget = trace.get(BindingContext.LABEL_TARGET, label);
323                    if (labelTarget != functionLiteral) {
324                        // Either a local return of inner lambda/function or a non-local return
325                        return null;
326                    }
327    
328                    JetExpression returnedExpression = returnExpression.getReturnedExpression();
329                    if (returnedExpression == null) {
330                        return KotlinBuiltIns.getInstance().getUnitType();
331                    }
332                    JetType returnedType = trace.get(EXPRESSION_TYPE, returnedExpression);
333                    assert returnedType != null : "No type for returned expression: " + returnedExpression + ",\n" +
334                                                  "the type should have been computed by getBlockReturnedType() above";
335                    return returnedType;
336                }
337            });
338        }
339    
340        public static Collection<JetReturnExpression> collectReturns(@NotNull JetExpression expression) {
341            Collection<JetReturnExpression> result = Lists.newArrayList();
342            expression.accept(
343                    new JetTreeVisitor<Collection<JetReturnExpression>>() {
344                        @Override
345                        public Void visitReturnExpression(
346                                @NotNull JetReturnExpression expression, Collection<JetReturnExpression> data
347                        ) {
348                            data.add(expression);
349                            return null;
350                        }
351                    },
352                    result
353            );
354            return result;
355        }
356    }