001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.lang.resolve.calls;
018    
019    import com.google.common.collect.Lists;
020    import org.jetbrains.annotations.NotNull;
021    import org.jetbrains.annotations.Nullable;
022    import org.jetbrains.jet.lang.descriptors.CallableDescriptor;
023    import org.jetbrains.jet.lang.descriptors.ValueParameterDescriptor;
024    import org.jetbrains.jet.lang.descriptors.annotations.AnnotationDescriptor;
025    import org.jetbrains.jet.lang.diagnostics.Errors;
026    import org.jetbrains.jet.lang.psi.*;
027    import org.jetbrains.jet.lang.resolve.*;
028    import org.jetbrains.jet.lang.resolve.calls.context.*;
029    import org.jetbrains.jet.lang.resolve.calls.model.MutableDataFlowInfoForArguments;
030    import org.jetbrains.jet.lang.resolve.calls.model.ResolvedCallImpl;
031    import org.jetbrains.jet.lang.resolve.calls.model.ResolvedValueArgument;
032    import org.jetbrains.jet.lang.resolve.constants.CompileTimeConstant;
033    import org.jetbrains.jet.lang.resolve.constants.CompileTimeConstantResolver;
034    import org.jetbrains.jet.lang.resolve.constants.ErrorValue;
035    import org.jetbrains.jet.lang.resolve.constants.NumberValueTypeConstructor;
036    import org.jetbrains.jet.lang.resolve.scopes.JetScope;
037    import org.jetbrains.jet.lang.types.*;
038    import org.jetbrains.jet.lang.types.checker.JetTypeChecker;
039    import org.jetbrains.jet.lang.types.expressions.ExpressionTypingServices;
040    import org.jetbrains.jet.lang.types.lang.KotlinBuiltIns;
041    
042    import javax.inject.Inject;
043    import java.util.Collections;
044    import java.util.List;
045    import java.util.Map;
046    import java.util.Set;
047    
048    import static org.jetbrains.jet.lang.resolve.BindingContextUtils.getRecordedTypeInfo;
049    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.*;
050    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.ResolveArgumentsMode.RESOLVE_FUNCTION_ARGUMENTS;
051    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.ResolveArgumentsMode.SKIP_FUNCTION_ARGUMENTS;
052    import static org.jetbrains.jet.lang.types.TypeUtils.NO_EXPECTED_TYPE;
053    
054    public class ArgumentTypeResolver {
055    
056        @NotNull
057        private TypeResolver typeResolver;
058        @NotNull
059        private ExpressionTypingServices expressionTypingServices;
060    
061        @Inject
062        public void setTypeResolver(@NotNull TypeResolver typeResolver) {
063            this.typeResolver = typeResolver;
064        }
065    
066        @Inject
067        public void setExpressionTypingServices(@NotNull ExpressionTypingServices expressionTypingServices) {
068            this.expressionTypingServices = expressionTypingServices;
069        }
070    
071        public static boolean isSubtypeOfForArgumentType(
072                @NotNull JetType actualType,
073                @NotNull JetType expectedType
074        ) {
075            if (actualType == PLACEHOLDER_FUNCTION_TYPE) {
076                return isFunctionOrErrorType(expectedType) || KotlinBuiltIns.getInstance().isAny(expectedType); //todo function type extends
077            }
078            return JetTypeChecker.INSTANCE.isSubtypeOf(actualType, expectedType);
079        }
080    
081        private static boolean isFunctionOrErrorType(@NotNull JetType supertype) {
082            return KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(supertype) || ErrorUtils.isErrorType(supertype);
083        }
084    
085        public void checkTypesWithNoCallee(@NotNull CallResolutionContext<?> context) {
086            checkTypesWithNoCallee(context, SKIP_FUNCTION_ARGUMENTS);
087        }
088    
089        public void checkTypesWithNoCallee(@NotNull CallResolutionContext<?> context, @NotNull ResolveArgumentsMode resolveFunctionArgumentBodies) {
090            if (context.checkArguments == CheckValueArgumentsMode.DISABLED) return;
091    
092            for (ValueArgument valueArgument : context.call.getValueArguments()) {
093                JetExpression argumentExpression = valueArgument.getArgumentExpression();
094                if (argumentExpression != null && !(argumentExpression instanceof JetFunctionLiteralExpression)) {
095                    checkArgumentType(context, argumentExpression);
096                }
097            }
098    
099            if (resolveFunctionArgumentBodies == RESOLVE_FUNCTION_ARGUMENTS) {
100                checkTypesForFunctionArgumentsWithNoCallee(context);
101            }
102    
103            for (JetTypeProjection typeProjection : context.call.getTypeArguments()) {
104                JetTypeReference typeReference = typeProjection.getTypeReference();
105                if (typeReference == null) {
106                    context.trace.report(Errors.PROJECTION_ON_NON_CLASS_TYPE_ARGUMENT.on(typeProjection));
107                }
108                else {
109                    typeResolver.resolveType(context.scope, typeReference, context.trace, true);
110                }
111            }
112        }
113    
114        public void checkTypesForFunctionArgumentsWithNoCallee(@NotNull CallResolutionContext<?> context) {
115            if (context.checkArguments == CheckValueArgumentsMode.DISABLED) return;
116    
117            for (ValueArgument valueArgument : context.call.getValueArguments()) {
118                JetExpression argumentExpression = valueArgument.getArgumentExpression();
119                if (argumentExpression != null && (argumentExpression instanceof JetFunctionLiteralExpression)) {
120                    checkArgumentType(context, argumentExpression);
121                }
122            }
123    
124            for (JetExpression expression : context.call.getFunctionLiteralArguments()) {
125                checkArgumentType(context, expression);
126            }
127        }
128    
129        public void checkUnmappedArgumentTypes(CallResolutionContext<?> context, Set<ValueArgument> unmappedArguments) {
130            for (ValueArgument valueArgument : unmappedArguments) {
131                JetExpression argumentExpression = valueArgument.getArgumentExpression();
132                if (argumentExpression != null) {
133                    checkArgumentType(context, argumentExpression);
134                }
135            }
136        }
137    
138        private void checkArgumentType(CallResolutionContext<?> context, JetExpression argumentExpression) {
139            expressionTypingServices.getType(context.scope, argumentExpression, NO_EXPECTED_TYPE, context.dataFlowInfo, context.trace);
140            updateResultArgumentTypeIfNotDenotable(context, argumentExpression);
141        }
142    
143        public <D extends CallableDescriptor> void checkTypesForFunctionArguments(CallResolutionContext<?> context, ResolvedCallImpl<D> resolvedCall) {
144            Map<ValueParameterDescriptor, ResolvedValueArgument> arguments = resolvedCall.getValueArguments();
145            for (Map.Entry<ValueParameterDescriptor, ResolvedValueArgument> entry : arguments.entrySet()) {
146                ValueParameterDescriptor valueParameterDescriptor = entry.getKey();
147                JetType varargElementType = valueParameterDescriptor.getVarargElementType();
148                JetType functionType;
149                if (varargElementType != null) {
150                    functionType = varargElementType;
151                }
152                else {
153                    functionType = valueParameterDescriptor.getType();
154                }
155                ResolvedValueArgument valueArgument = entry.getValue();
156                List<ValueArgument> valueArguments = valueArgument.getArguments();
157                for (ValueArgument argument : valueArguments) {
158                    JetExpression expression = argument.getArgumentExpression();
159                    if (expression instanceof JetFunctionLiteralExpression) {
160                        expressionTypingServices.getType(context.scope, expression, functionType, context.dataFlowInfo, context.trace);
161                    }
162                }
163            }
164        }
165    
166        @NotNull
167        public JetTypeInfo getArgumentTypeInfo(
168                @Nullable JetExpression expression,
169                @NotNull CallResolutionContext<?> context,
170                @NotNull ResolveArgumentsMode resolveArgumentsMode
171        ) {
172            if (expression == null) {
173                return JetTypeInfo.create(null, context.dataFlowInfo);
174            }
175            JetExpression deparenthesizedExpression = JetPsiUtil.deparenthesize(JetPsiUtil.unwrapFromBlock(expression), false);
176            if (deparenthesizedExpression instanceof JetFunctionLiteralExpression) {
177                return getFunctionLiteralTypeInfo(expression, (JetFunctionLiteralExpression) deparenthesizedExpression, context, resolveArgumentsMode);
178            }
179            JetTypeInfo recordedTypeInfo = getRecordedTypeInfo(expression, context.trace.getBindingContext());
180            if (recordedTypeInfo != null) {
181                return recordedTypeInfo;
182            }
183            ResolutionContext newContext = context.replaceExpectedType(TypeUtils.NO_EXPECTED_TYPE)
184                    .replaceContextDependency(ContextDependency.DEPENDENT).replaceExpressionPosition(ExpressionPosition.FREE);
185    
186            return expressionTypingServices.getTypeInfo(expression, newContext);
187        }
188    
189        @NotNull
190        public JetTypeInfo getFunctionLiteralTypeInfo(
191                @NotNull JetExpression expression,
192                @NotNull JetFunctionLiteralExpression functionLiteralExpression,
193                @NotNull CallResolutionContext<?> context,
194                @NotNull ResolveArgumentsMode resolveArgumentsMode
195        ) {
196            if (resolveArgumentsMode == SKIP_FUNCTION_ARGUMENTS) {
197                JetType type = getFunctionLiteralType(functionLiteralExpression, context.scope, context.trace);
198                return JetTypeInfo.create(type, context.dataFlowInfo);
199            }
200            return expressionTypingServices.getTypeInfo(expression, context);
201        }
202    
203        @Nullable
204        public JetType getFunctionLiteralType(
205                @NotNull JetFunctionLiteralExpression expression,
206                @NotNull JetScope scope,
207                @NotNull BindingTrace trace
208        ) {
209            if (expression.getFunctionLiteral().getValueParameterList() == null) {
210                return PLACEHOLDER_FUNCTION_TYPE;
211            }
212            List<JetParameter> valueParameters = expression.getValueParameters();
213            TemporaryBindingTrace temporaryTrace = TemporaryBindingTrace.create(
214                    trace, "trace to resolve function literal parameter types");
215            List<JetType> parameterTypes = Lists.newArrayList();
216            for (JetParameter parameter : valueParameters) {
217                parameterTypes.add(resolveTypeRefWithDefault(parameter.getTypeReference(), scope, temporaryTrace, DONT_CARE));
218            }
219            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
220            JetType returnType = resolveTypeRefWithDefault(functionLiteral.getReturnTypeRef(), scope, temporaryTrace, DONT_CARE);
221            assert returnType != null;
222            JetType receiverType = resolveTypeRefWithDefault(functionLiteral.getReceiverTypeRef(), scope, temporaryTrace, null);
223            return KotlinBuiltIns.getInstance().getFunctionType(Collections.<AnnotationDescriptor>emptyList(), receiverType, parameterTypes,
224                                                                returnType);
225        }
226    
227        @Nullable
228        public JetType resolveTypeRefWithDefault(
229                @Nullable JetTypeReference returnTypeRef,
230                @NotNull JetScope scope,
231                @NotNull BindingTrace trace,
232                @Nullable JetType defaultValue
233        ) {
234            if (returnTypeRef != null) {
235                return expressionTypingServices.getTypeResolver().resolveType(scope, returnTypeRef, trace, true);
236            }
237            return defaultValue;
238        }
239    
240        public <D extends CallableDescriptor> void analyzeArgumentsAndRecordTypes(
241                @NotNull CallResolutionContext<?> context
242        ) {
243            MutableDataFlowInfoForArguments infoForArguments = context.dataFlowInfoForArguments;
244            infoForArguments.setInitialDataFlowInfo(context.dataFlowInfo);
245    
246            for (ValueArgument argument : context.call.getValueArguments()) {
247                JetExpression expression = argument.getArgumentExpression();
248                if (expression == null) continue;
249    
250                CallResolutionContext<?> newContext = context.replaceDataFlowInfo(infoForArguments.getInfo(argument));
251                JetTypeInfo typeInfoForCall = getArgumentTypeInfo(expression, newContext, SKIP_FUNCTION_ARGUMENTS);
252                infoForArguments.updateInfo(argument, typeInfoForCall.getDataFlowInfo());
253            }
254        }
255    
256        @Nullable
257        public <D extends CallableDescriptor> JetType updateResultArgumentTypeIfNotDenotable(
258                @NotNull ResolutionContext context,
259                @NotNull JetExpression expression
260        ) {
261            JetType type = context.trace.get(BindingContext.EXPRESSION_TYPE, expression);
262            if (type != null && !type.getConstructor().isDenotable()) {
263                if (type.getConstructor() instanceof NumberValueTypeConstructor) {
264                    NumberValueTypeConstructor constructor = (NumberValueTypeConstructor) type.getConstructor();
265                    JetType primitiveType = TypeUtils.getPrimitiveNumberType(constructor, context.expectedType);
266                    updateNumberType(primitiveType, expression, context);
267                    return primitiveType;
268                }
269            }
270            return type;
271        }
272    
273        private <D extends CallableDescriptor> void updateNumberType(
274                @NotNull JetType numberType,
275                @Nullable JetExpression expression,
276                @NotNull ResolutionContext context
277        ) {
278            if (expression == null) return;
279            BindingContextUtils.updateRecordedType(numberType, expression, context.trace, false);
280    
281            if (!(expression instanceof JetConstantExpression)) {
282                JetExpression deparenthesized = JetPsiUtil.deparenthesize(expression, false);
283                if (deparenthesized != expression) {
284                    updateNumberType(numberType, deparenthesized, context);
285                }
286                if (deparenthesized instanceof JetBlockExpression) {
287                    JetElement lastStatement = JetPsiUtil.getLastStatementInABlock((JetBlockExpression) deparenthesized);
288                    if (lastStatement instanceof JetExpression) {
289                        updateNumberType(numberType, (JetExpression) lastStatement, context);
290                    }
291                }
292                return;
293            }
294            CompileTimeConstant<?> constant =
295                    new CompileTimeConstantResolver().getCompileTimeConstant((JetConstantExpression) expression, numberType);
296    
297            if (!(constant instanceof ErrorValue)) {
298                context.trace.record(BindingContext.COMPILE_TIME_VALUE, expression, constant);
299            }
300        }
301    }