001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.lang.resolve.calls;
018    
019    import com.google.common.collect.Lists;
020    import org.jetbrains.annotations.NotNull;
021    import org.jetbrains.annotations.Nullable;
022    import org.jetbrains.jet.lang.descriptors.CallableDescriptor;
023    import org.jetbrains.jet.lang.descriptors.ValueParameterDescriptor;
024    import org.jetbrains.jet.lang.descriptors.annotations.AnnotationDescriptor;
025    import org.jetbrains.jet.lang.diagnostics.Errors;
026    import org.jetbrains.jet.lang.psi.*;
027    import org.jetbrains.jet.lang.resolve.*;
028    import org.jetbrains.jet.lang.resolve.calls.context.CallResolutionContext;
029    import org.jetbrains.jet.lang.resolve.calls.context.CheckValueArgumentsMode;
030    import org.jetbrains.jet.lang.resolve.calls.context.ResolutionContext;
031    import org.jetbrains.jet.lang.resolve.calls.model.MutableDataFlowInfoForArguments;
032    import org.jetbrains.jet.lang.resolve.calls.model.ResolvedCallImpl;
033    import org.jetbrains.jet.lang.resolve.calls.model.ResolvedValueArgument;
034    import org.jetbrains.jet.lang.resolve.constants.CompileTimeConstant;
035    import org.jetbrains.jet.lang.resolve.constants.CompileTimeConstantResolver;
036    import org.jetbrains.jet.lang.resolve.constants.ErrorValue;
037    import org.jetbrains.jet.lang.resolve.constants.NumberValueTypeConstructor;
038    import org.jetbrains.jet.lang.resolve.scopes.JetScope;
039    import org.jetbrains.jet.lang.types.JetType;
040    import org.jetbrains.jet.lang.types.JetTypeInfo;
041    import org.jetbrains.jet.lang.types.TypeUtils;
042    import org.jetbrains.jet.lang.types.checker.JetTypeChecker;
043    import org.jetbrains.jet.lang.types.expressions.ExpressionTypingServices;
044    import org.jetbrains.jet.lang.types.lang.KotlinBuiltIns;
045    
046    import javax.inject.Inject;
047    import java.util.Collections;
048    import java.util.List;
049    import java.util.Map;
050    import java.util.Set;
051    
052    import static org.jetbrains.jet.lang.resolve.BindingContextUtils.getRecordedTypeInfo;
053    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.ResolveArgumentsMode;
054    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.ResolveArgumentsMode.RESOLVE_FUNCTION_ARGUMENTS;
055    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.ResolveArgumentsMode.SHAPE_FUNCTION_ARGUMENTS;
056    import static org.jetbrains.jet.lang.resolve.calls.context.ContextDependency.DEPENDENT;
057    import static org.jetbrains.jet.lang.types.TypeUtils.*;
058    
059    public class ArgumentTypeResolver {
060    
061        @NotNull
062        private TypeResolver typeResolver;
063        @NotNull
064        private ExpressionTypingServices expressionTypingServices;
065    
066        @Inject
067        public void setTypeResolver(@NotNull TypeResolver typeResolver) {
068            this.typeResolver = typeResolver;
069        }
070    
071        @Inject
072        public void setExpressionTypingServices(@NotNull ExpressionTypingServices expressionTypingServices) {
073            this.expressionTypingServices = expressionTypingServices;
074        }
075    
076        public static boolean isSubtypeOfForArgumentType(
077                @NotNull JetType actualType,
078                @NotNull JetType expectedType
079        ) {
080            if (actualType == PLACEHOLDER_FUNCTION_TYPE) {
081                return isFunctionOrErrorType(expectedType) || KotlinBuiltIns.getInstance().isAny(expectedType); //todo function type extends
082            }
083            return JetTypeChecker.INSTANCE.isSubtypeOf(actualType, expectedType);
084        }
085    
086        private static boolean isFunctionOrErrorType(@NotNull JetType supertype) {
087            return KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(supertype) || supertype.isError();
088        }
089    
090        public void checkTypesWithNoCallee(@NotNull CallResolutionContext<?> context) {
091            checkTypesWithNoCallee(context, SHAPE_FUNCTION_ARGUMENTS);
092        }
093    
094        public void checkTypesWithNoCallee(@NotNull CallResolutionContext<?> context, @NotNull ResolveArgumentsMode resolveFunctionArgumentBodies) {
095            if (context.checkArguments == CheckValueArgumentsMode.DISABLED) return;
096    
097            for (ValueArgument valueArgument : context.call.getValueArguments()) {
098                JetExpression argumentExpression = valueArgument.getArgumentExpression();
099                if (argumentExpression != null && !(argumentExpression instanceof JetFunctionLiteralExpression)) {
100                    checkArgumentType(context, argumentExpression);
101                }
102            }
103    
104            if (resolveFunctionArgumentBodies == RESOLVE_FUNCTION_ARGUMENTS) {
105                checkTypesForFunctionArgumentsWithNoCallee(context);
106            }
107    
108            for (JetTypeProjection typeProjection : context.call.getTypeArguments()) {
109                JetTypeReference typeReference = typeProjection.getTypeReference();
110                if (typeReference == null) {
111                    context.trace.report(Errors.PROJECTION_ON_NON_CLASS_TYPE_ARGUMENT.on(typeProjection));
112                }
113                else {
114                    typeResolver.resolveType(context.scope, typeReference, context.trace, true);
115                }
116            }
117        }
118    
119        public void checkTypesForFunctionArgumentsWithNoCallee(@NotNull CallResolutionContext<?> context) {
120            if (context.checkArguments == CheckValueArgumentsMode.DISABLED) return;
121    
122            for (ValueArgument valueArgument : context.call.getValueArguments()) {
123                JetExpression argumentExpression = valueArgument.getArgumentExpression();
124                if (argumentExpression != null && (argumentExpression instanceof JetFunctionLiteralExpression)) {
125                    checkArgumentType(context, argumentExpression);
126                }
127            }
128    
129            for (JetExpression expression : context.call.getFunctionLiteralArguments()) {
130                checkArgumentType(context, expression);
131            }
132        }
133    
134        public void checkUnmappedArgumentTypes(CallResolutionContext<?> context, Set<ValueArgument> unmappedArguments) {
135            for (ValueArgument valueArgument : unmappedArguments) {
136                JetExpression argumentExpression = valueArgument.getArgumentExpression();
137                if (argumentExpression != null) {
138                    checkArgumentType(context, argumentExpression);
139                }
140            }
141        }
142    
143        private void checkArgumentType(CallResolutionContext<?> context, JetExpression argumentExpression) {
144            expressionTypingServices.getType(context.scope, argumentExpression, NO_EXPECTED_TYPE, context.dataFlowInfo, context.trace);
145            updateResultArgumentTypeIfNotDenotable(context, argumentExpression);
146        }
147    
148        public <D extends CallableDescriptor> void checkTypesForFunctionArguments(CallResolutionContext<?> context, ResolvedCallImpl<D> resolvedCall) {
149            Map<ValueParameterDescriptor, ResolvedValueArgument> arguments = resolvedCall.getValueArguments();
150            for (Map.Entry<ValueParameterDescriptor, ResolvedValueArgument> entry : arguments.entrySet()) {
151                ValueParameterDescriptor valueParameterDescriptor = entry.getKey();
152                JetType varargElementType = valueParameterDescriptor.getVarargElementType();
153                JetType functionType;
154                if (varargElementType != null) {
155                    functionType = varargElementType;
156                }
157                else {
158                    functionType = valueParameterDescriptor.getType();
159                }
160                ResolvedValueArgument valueArgument = entry.getValue();
161                List<ValueArgument> valueArguments = valueArgument.getArguments();
162                for (ValueArgument argument : valueArguments) {
163                    JetExpression expression = argument.getArgumentExpression();
164                    if (expression instanceof JetFunctionLiteralExpression) {
165                        expressionTypingServices.getType(context.scope, expression, functionType, context.dataFlowInfo, context.trace);
166                    }
167                }
168            }
169        }
170    
171        public static boolean isFunctionLiteralArgument(@NotNull JetExpression expression) {
172            return getFunctionLiteralArgumentIfAny(expression) != null;
173        }
174    
175        @NotNull
176        public static JetFunctionLiteralExpression getFunctionLiteralArgument(@NotNull JetExpression expression) {
177            assert isFunctionLiteralArgument(expression);
178            //noinspection ConstantConditions
179            return getFunctionLiteralArgumentIfAny(expression);
180        }
181    
182        @Nullable
183        private static JetFunctionLiteralExpression getFunctionLiteralArgumentIfAny(@NotNull JetExpression expression) {
184            JetExpression deparenthesizedExpression = JetPsiUtil.deparenthesize(expression, false);
185            if (deparenthesizedExpression instanceof JetBlockExpression) {
186                // todo
187                // This case is a temporary hack for 'if' branches.
188                // The right way to implement this logic is to interpret 'if' branches as function literals with explicitly-typed signatures
189                // (no arguments and no receiver) and therefore analyze them straight away (not in the 'complete' phase).
190                JetElement lastStatementInABlock = JetPsiUtil.getLastStatementInABlock((JetBlockExpression) deparenthesizedExpression);
191                if (lastStatementInABlock instanceof JetExpression) {
192                    deparenthesizedExpression = JetPsiUtil.deparenthesize((JetExpression) lastStatementInABlock, false);
193                }
194            }
195            if (deparenthesizedExpression instanceof JetFunctionLiteralExpression) {
196                return (JetFunctionLiteralExpression) deparenthesizedExpression;
197            }
198            return null;
199        }
200    
201        @NotNull
202        public JetTypeInfo getArgumentTypeInfo(
203                @Nullable JetExpression expression,
204                @NotNull CallResolutionContext<?> context,
205                @NotNull ResolveArgumentsMode resolveArgumentsMode
206        ) {
207            if (expression == null) {
208                return JetTypeInfo.create(null, context.dataFlowInfo);
209            }
210            if (isFunctionLiteralArgument(expression)) {
211                return getFunctionLiteralTypeInfo(expression, getFunctionLiteralArgument(expression), context, resolveArgumentsMode);
212            }
213            JetTypeInfo recordedTypeInfo = getRecordedTypeInfo(expression, context.trace.getBindingContext());
214            if (recordedTypeInfo != null) {
215                return recordedTypeInfo;
216            }
217            ResolutionContext newContext = context.replaceExpectedType(NO_EXPECTED_TYPE).replaceContextDependency(DEPENDENT);
218    
219            return expressionTypingServices.getTypeInfo(expression, newContext);
220        }
221    
222        @NotNull
223        public JetTypeInfo getFunctionLiteralTypeInfo(
224                @NotNull JetExpression expression,
225                @NotNull JetFunctionLiteralExpression functionLiteralExpression,
226                @NotNull CallResolutionContext<?> context,
227                @NotNull ResolveArgumentsMode resolveArgumentsMode
228        ) {
229            if (resolveArgumentsMode == SHAPE_FUNCTION_ARGUMENTS) {
230                JetType type = getShapeTypeOfFunctionLiteral(functionLiteralExpression, context.scope, context.trace, true);
231                return JetTypeInfo.create(type, context.dataFlowInfo);
232            }
233            return expressionTypingServices.getTypeInfo(expression, context);
234        }
235    
236        @Nullable
237        public JetType getShapeTypeOfFunctionLiteral(
238                @NotNull JetFunctionLiteralExpression expression,
239                @NotNull JetScope scope,
240                @NotNull BindingTrace trace,
241                boolean expectedTypeIsUnknown
242        ) {
243            if (expression.getFunctionLiteral().getValueParameterList() == null) {
244                return expectedTypeIsUnknown ? PLACEHOLDER_FUNCTION_TYPE : KotlinBuiltIns.getInstance().getFunctionType(
245                        Collections.<AnnotationDescriptor>emptyList(), null, Collections.<JetType>emptyList(), DONT_CARE);
246            }
247            List<JetParameter> valueParameters = expression.getValueParameters();
248            TemporaryBindingTrace temporaryTrace = TemporaryBindingTrace.create(
249                    trace, "trace to resolve function literal parameter types");
250            List<JetType> parameterTypes = Lists.newArrayList();
251            for (JetParameter parameter : valueParameters) {
252                parameterTypes.add(resolveTypeRefWithDefault(parameter.getTypeReference(), scope, temporaryTrace, DONT_CARE));
253            }
254            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
255            JetType returnType = resolveTypeRefWithDefault(functionLiteral.getReturnTypeRef(), scope, temporaryTrace, DONT_CARE);
256            assert returnType != null;
257            JetType receiverType = resolveTypeRefWithDefault(functionLiteral.getReceiverTypeRef(), scope, temporaryTrace, null);
258            return KotlinBuiltIns.getInstance().getFunctionType(Collections.<AnnotationDescriptor>emptyList(), receiverType, parameterTypes,
259                                                                returnType);
260        }
261    
262        @Nullable
263        public JetType resolveTypeRefWithDefault(
264                @Nullable JetTypeReference returnTypeRef,
265                @NotNull JetScope scope,
266                @NotNull BindingTrace trace,
267                @Nullable JetType defaultValue
268        ) {
269            if (returnTypeRef != null) {
270                return expressionTypingServices.getTypeResolver().resolveType(scope, returnTypeRef, trace, true);
271            }
272            return defaultValue;
273        }
274    
275        public <D extends CallableDescriptor> void analyzeArgumentsAndRecordTypes(
276                @NotNull CallResolutionContext<?> context
277        ) {
278            MutableDataFlowInfoForArguments infoForArguments = context.dataFlowInfoForArguments;
279            infoForArguments.setInitialDataFlowInfo(context.dataFlowInfo);
280    
281            for (ValueArgument argument : context.call.getValueArguments()) {
282                JetExpression expression = argument.getArgumentExpression();
283                if (expression == null) continue;
284    
285                CallResolutionContext<?> newContext = context.replaceDataFlowInfo(infoForArguments.getInfo(argument));
286                JetTypeInfo typeInfoForCall = getArgumentTypeInfo(expression, newContext, SHAPE_FUNCTION_ARGUMENTS);
287                infoForArguments.updateInfo(argument, typeInfoForCall.getDataFlowInfo());
288            }
289        }
290    
291        @Nullable
292        public <D extends CallableDescriptor> JetType updateResultArgumentTypeIfNotDenotable(
293                @NotNull ResolutionContext context,
294                @NotNull JetExpression expression
295        ) {
296            JetType type = context.trace.get(BindingContext.EXPRESSION_TYPE, expression);
297            if (type != null && !type.getConstructor().isDenotable()) {
298                if (type.getConstructor() instanceof NumberValueTypeConstructor) {
299                    NumberValueTypeConstructor constructor = (NumberValueTypeConstructor) type.getConstructor();
300                    JetType primitiveType = TypeUtils.getPrimitiveNumberType(constructor, context.expectedType);
301                    updateNumberType(primitiveType, expression, context);
302                    return primitiveType;
303                }
304            }
305            return type;
306        }
307    
308        private <D extends CallableDescriptor> void updateNumberType(
309                @NotNull JetType numberType,
310                @Nullable JetExpression expression,
311                @NotNull ResolutionContext context
312        ) {
313            if (expression == null) return;
314            BindingContextUtils.updateRecordedType(numberType, expression, context.trace, false);
315    
316            if (!(expression instanceof JetConstantExpression)) {
317                JetExpression deparenthesized = JetPsiUtil.deparenthesize(expression, false);
318                if (deparenthesized != expression) {
319                    updateNumberType(numberType, deparenthesized, context);
320                }
321                if (deparenthesized instanceof JetBlockExpression) {
322                    JetElement lastStatement = JetPsiUtil.getLastStatementInABlock((JetBlockExpression) deparenthesized);
323                    if (lastStatement instanceof JetExpression) {
324                        updateNumberType(numberType, (JetExpression) lastStatement, context);
325                    }
326                }
327                return;
328            }
329            CompileTimeConstant<?> constant =
330                    new CompileTimeConstantResolver().getCompileTimeConstant((JetConstantExpression) expression, numberType);
331    
332            if (!(constant instanceof ErrorValue)) {
333                context.trace.record(BindingContext.COMPILE_TIME_VALUE, expression, constant);
334            }
335        }
336    }