001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.lang.resolve.calls;
018    
019    import com.google.common.collect.Lists;
020    import org.jetbrains.annotations.NotNull;
021    import org.jetbrains.annotations.Nullable;
022    import org.jetbrains.jet.lang.descriptors.CallableDescriptor;
023    import org.jetbrains.jet.lang.descriptors.ValueParameterDescriptor;
024    import org.jetbrains.jet.lang.descriptors.annotations.AnnotationDescriptor;
025    import org.jetbrains.jet.lang.diagnostics.Errors;
026    import org.jetbrains.jet.lang.psi.*;
027    import org.jetbrains.jet.lang.resolve.*;
028    import org.jetbrains.jet.lang.resolve.calls.context.*;
029    import org.jetbrains.jet.lang.resolve.calls.model.MutableDataFlowInfoForArguments;
030    import org.jetbrains.jet.lang.resolve.calls.model.ResolvedCallImpl;
031    import org.jetbrains.jet.lang.resolve.calls.model.ResolvedValueArgument;
032    import org.jetbrains.jet.lang.resolve.constants.CompileTimeConstant;
033    import org.jetbrains.jet.lang.resolve.constants.CompileTimeConstantResolver;
034    import org.jetbrains.jet.lang.resolve.constants.ErrorValue;
035    import org.jetbrains.jet.lang.resolve.constants.NumberValueTypeConstructor;
036    import org.jetbrains.jet.lang.resolve.scopes.JetScope;
037    import org.jetbrains.jet.lang.types.JetType;
038    import org.jetbrains.jet.lang.types.JetTypeInfo;
039    import org.jetbrains.jet.lang.types.TypeUtils;
040    import org.jetbrains.jet.lang.types.checker.JetTypeChecker;
041    import org.jetbrains.jet.lang.types.expressions.ExpressionTypingServices;
042    import org.jetbrains.jet.lang.types.lang.KotlinBuiltIns;
043    
044    import javax.inject.Inject;
045    import java.util.Collections;
046    import java.util.List;
047    import java.util.Map;
048    import java.util.Set;
049    
050    import static org.jetbrains.jet.lang.resolve.BindingContextUtils.getRecordedTypeInfo;
051    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.ResolveArgumentsMode;
052    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.ResolveArgumentsMode.RESOLVE_FUNCTION_ARGUMENTS;
053    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.ResolveArgumentsMode.SHAPE_FUNCTION_ARGUMENTS;
054    import static org.jetbrains.jet.lang.types.TypeUtils.*;
055    
056    public class ArgumentTypeResolver {
057    
058        @NotNull
059        private TypeResolver typeResolver;
060        @NotNull
061        private ExpressionTypingServices expressionTypingServices;
062    
063        @Inject
064        public void setTypeResolver(@NotNull TypeResolver typeResolver) {
065            this.typeResolver = typeResolver;
066        }
067    
068        @Inject
069        public void setExpressionTypingServices(@NotNull ExpressionTypingServices expressionTypingServices) {
070            this.expressionTypingServices = expressionTypingServices;
071        }
072    
073        public static boolean isSubtypeOfForArgumentType(
074                @NotNull JetType actualType,
075                @NotNull JetType expectedType
076        ) {
077            if (actualType == PLACEHOLDER_FUNCTION_TYPE) {
078                return isFunctionOrErrorType(expectedType) || KotlinBuiltIns.getInstance().isAny(expectedType); //todo function type extends
079            }
080            return JetTypeChecker.INSTANCE.isSubtypeOf(actualType, expectedType);
081        }
082    
083        private static boolean isFunctionOrErrorType(@NotNull JetType supertype) {
084            return KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(supertype) || supertype.isError();
085        }
086    
087        public void checkTypesWithNoCallee(@NotNull CallResolutionContext<?> context) {
088            checkTypesWithNoCallee(context, SHAPE_FUNCTION_ARGUMENTS);
089        }
090    
091        public void checkTypesWithNoCallee(@NotNull CallResolutionContext<?> context, @NotNull ResolveArgumentsMode resolveFunctionArgumentBodies) {
092            if (context.checkArguments == CheckValueArgumentsMode.DISABLED) return;
093    
094            for (ValueArgument valueArgument : context.call.getValueArguments()) {
095                JetExpression argumentExpression = valueArgument.getArgumentExpression();
096                if (argumentExpression != null && !(argumentExpression instanceof JetFunctionLiteralExpression)) {
097                    checkArgumentType(context, argumentExpression);
098                }
099            }
100    
101            if (resolveFunctionArgumentBodies == RESOLVE_FUNCTION_ARGUMENTS) {
102                checkTypesForFunctionArgumentsWithNoCallee(context);
103            }
104    
105            for (JetTypeProjection typeProjection : context.call.getTypeArguments()) {
106                JetTypeReference typeReference = typeProjection.getTypeReference();
107                if (typeReference == null) {
108                    context.trace.report(Errors.PROJECTION_ON_NON_CLASS_TYPE_ARGUMENT.on(typeProjection));
109                }
110                else {
111                    typeResolver.resolveType(context.scope, typeReference, context.trace, true);
112                }
113            }
114        }
115    
116        public void checkTypesForFunctionArgumentsWithNoCallee(@NotNull CallResolutionContext<?> context) {
117            if (context.checkArguments == CheckValueArgumentsMode.DISABLED) return;
118    
119            for (ValueArgument valueArgument : context.call.getValueArguments()) {
120                JetExpression argumentExpression = valueArgument.getArgumentExpression();
121                if (argumentExpression != null && (argumentExpression instanceof JetFunctionLiteralExpression)) {
122                    checkArgumentType(context, argumentExpression);
123                }
124            }
125    
126            for (JetExpression expression : context.call.getFunctionLiteralArguments()) {
127                checkArgumentType(context, expression);
128            }
129        }
130    
131        public void checkUnmappedArgumentTypes(CallResolutionContext<?> context, Set<ValueArgument> unmappedArguments) {
132            for (ValueArgument valueArgument : unmappedArguments) {
133                JetExpression argumentExpression = valueArgument.getArgumentExpression();
134                if (argumentExpression != null) {
135                    checkArgumentType(context, argumentExpression);
136                }
137            }
138        }
139    
140        private void checkArgumentType(CallResolutionContext<?> context, JetExpression argumentExpression) {
141            expressionTypingServices.getType(context.scope, argumentExpression, NO_EXPECTED_TYPE, context.dataFlowInfo, context.trace);
142            updateResultArgumentTypeIfNotDenotable(context, argumentExpression);
143        }
144    
145        public <D extends CallableDescriptor> void checkTypesForFunctionArguments(CallResolutionContext<?> context, ResolvedCallImpl<D> resolvedCall) {
146            Map<ValueParameterDescriptor, ResolvedValueArgument> arguments = resolvedCall.getValueArguments();
147            for (Map.Entry<ValueParameterDescriptor, ResolvedValueArgument> entry : arguments.entrySet()) {
148                ValueParameterDescriptor valueParameterDescriptor = entry.getKey();
149                JetType varargElementType = valueParameterDescriptor.getVarargElementType();
150                JetType functionType;
151                if (varargElementType != null) {
152                    functionType = varargElementType;
153                }
154                else {
155                    functionType = valueParameterDescriptor.getType();
156                }
157                ResolvedValueArgument valueArgument = entry.getValue();
158                List<ValueArgument> valueArguments = valueArgument.getArguments();
159                for (ValueArgument argument : valueArguments) {
160                    JetExpression expression = argument.getArgumentExpression();
161                    if (expression instanceof JetFunctionLiteralExpression) {
162                        expressionTypingServices.getType(context.scope, expression, functionType, context.dataFlowInfo, context.trace);
163                    }
164                }
165            }
166        }
167    
168        public static boolean isFunctionLiteralArgument(@NotNull JetExpression expression) {
169            return getFunctionLiteralArgumentIfAny(expression) != null;
170        }
171    
172        @NotNull
173        public static JetFunctionLiteralExpression getFunctionLiteralArgument(@NotNull JetExpression expression) {
174            assert isFunctionLiteralArgument(expression);
175            //noinspection ConstantConditions
176            return getFunctionLiteralArgumentIfAny(expression);
177        }
178    
179        @Nullable
180        private static JetFunctionLiteralExpression getFunctionLiteralArgumentIfAny(@NotNull JetExpression expression) {
181            JetExpression deparenthesizedExpression = JetPsiUtil.deparenthesize(expression, false);
182            if (deparenthesizedExpression instanceof JetBlockExpression) {
183                // todo
184                // This case is a temporary hack for 'if' branches.
185                // The right way to implement this logic is to interpret 'if' branches as function literals with explicitly-typed signatures
186                // (no arguments and no receiver) and therefore analyze them straight away (not in the 'complete' phase).
187                JetElement lastStatementInABlock = JetPsiUtil.getLastStatementInABlock((JetBlockExpression) deparenthesizedExpression);
188                if (lastStatementInABlock instanceof JetExpression) {
189                    deparenthesizedExpression = JetPsiUtil.deparenthesize((JetExpression) lastStatementInABlock, false);
190                }
191            }
192            if (deparenthesizedExpression instanceof JetFunctionLiteralExpression) {
193                return (JetFunctionLiteralExpression) deparenthesizedExpression;
194            }
195            return null;
196        }
197    
198        @NotNull
199        public JetTypeInfo getArgumentTypeInfo(
200                @Nullable JetExpression expression,
201                @NotNull CallResolutionContext<?> context,
202                @NotNull ResolveArgumentsMode resolveArgumentsMode
203        ) {
204            if (expression == null) {
205                return JetTypeInfo.create(null, context.dataFlowInfo);
206            }
207            if (isFunctionLiteralArgument(expression)) {
208                return getFunctionLiteralTypeInfo(expression, getFunctionLiteralArgument(expression), context, resolveArgumentsMode);
209            }
210            JetTypeInfo recordedTypeInfo = getRecordedTypeInfo(expression, context.trace.getBindingContext());
211            if (recordedTypeInfo != null) {
212                return recordedTypeInfo;
213            }
214            ResolutionContext newContext = context.replaceExpectedType(TypeUtils.NO_EXPECTED_TYPE)
215                    .replaceContextDependency(ContextDependency.DEPENDENT).replaceExpressionPosition(ExpressionPosition.FREE);
216    
217            return expressionTypingServices.getTypeInfo(expression, newContext);
218        }
219    
220        @NotNull
221        public JetTypeInfo getFunctionLiteralTypeInfo(
222                @NotNull JetExpression expression,
223                @NotNull JetFunctionLiteralExpression functionLiteralExpression,
224                @NotNull CallResolutionContext<?> context,
225                @NotNull ResolveArgumentsMode resolveArgumentsMode
226        ) {
227            if (resolveArgumentsMode == SHAPE_FUNCTION_ARGUMENTS) {
228                JetType type = getShapeTypeOfFunctionLiteral(functionLiteralExpression, context.scope, context.trace, true);
229                return JetTypeInfo.create(type, context.dataFlowInfo);
230            }
231            return expressionTypingServices.getTypeInfo(expression, context);
232        }
233    
234        @Nullable
235        public JetType getShapeTypeOfFunctionLiteral(
236                @NotNull JetFunctionLiteralExpression expression,
237                @NotNull JetScope scope,
238                @NotNull BindingTrace trace,
239                boolean expectedTypeIsUnknown
240        ) {
241            if (expression.getFunctionLiteral().getValueParameterList() == null) {
242                return expectedTypeIsUnknown ? PLACEHOLDER_FUNCTION_TYPE : KotlinBuiltIns.getInstance().getFunctionType(
243                        Collections.<AnnotationDescriptor>emptyList(), null, Collections.<JetType>emptyList(), DONT_CARE);
244            }
245            List<JetParameter> valueParameters = expression.getValueParameters();
246            TemporaryBindingTrace temporaryTrace = TemporaryBindingTrace.create(
247                    trace, "trace to resolve function literal parameter types");
248            List<JetType> parameterTypes = Lists.newArrayList();
249            for (JetParameter parameter : valueParameters) {
250                parameterTypes.add(resolveTypeRefWithDefault(parameter.getTypeReference(), scope, temporaryTrace, DONT_CARE));
251            }
252            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
253            JetType returnType = resolveTypeRefWithDefault(functionLiteral.getReturnTypeRef(), scope, temporaryTrace, DONT_CARE);
254            assert returnType != null;
255            JetType receiverType = resolveTypeRefWithDefault(functionLiteral.getReceiverTypeRef(), scope, temporaryTrace, null);
256            return KotlinBuiltIns.getInstance().getFunctionType(Collections.<AnnotationDescriptor>emptyList(), receiverType, parameterTypes,
257                                                                returnType);
258        }
259    
260        @Nullable
261        public JetType resolveTypeRefWithDefault(
262                @Nullable JetTypeReference returnTypeRef,
263                @NotNull JetScope scope,
264                @NotNull BindingTrace trace,
265                @Nullable JetType defaultValue
266        ) {
267            if (returnTypeRef != null) {
268                return expressionTypingServices.getTypeResolver().resolveType(scope, returnTypeRef, trace, true);
269            }
270            return defaultValue;
271        }
272    
273        public <D extends CallableDescriptor> void analyzeArgumentsAndRecordTypes(
274                @NotNull CallResolutionContext<?> context
275        ) {
276            MutableDataFlowInfoForArguments infoForArguments = context.dataFlowInfoForArguments;
277            infoForArguments.setInitialDataFlowInfo(context.dataFlowInfo);
278    
279            for (ValueArgument argument : context.call.getValueArguments()) {
280                JetExpression expression = argument.getArgumentExpression();
281                if (expression == null) continue;
282    
283                CallResolutionContext<?> newContext = context.replaceDataFlowInfo(infoForArguments.getInfo(argument));
284                JetTypeInfo typeInfoForCall = getArgumentTypeInfo(expression, newContext, SHAPE_FUNCTION_ARGUMENTS);
285                infoForArguments.updateInfo(argument, typeInfoForCall.getDataFlowInfo());
286            }
287        }
288    
289        @Nullable
290        public <D extends CallableDescriptor> JetType updateResultArgumentTypeIfNotDenotable(
291                @NotNull ResolutionContext context,
292                @NotNull JetExpression expression
293        ) {
294            JetType type = context.trace.get(BindingContext.EXPRESSION_TYPE, expression);
295            if (type != null && !type.getConstructor().isDenotable()) {
296                if (type.getConstructor() instanceof NumberValueTypeConstructor) {
297                    NumberValueTypeConstructor constructor = (NumberValueTypeConstructor) type.getConstructor();
298                    JetType primitiveType = TypeUtils.getPrimitiveNumberType(constructor, context.expectedType);
299                    updateNumberType(primitiveType, expression, context);
300                    return primitiveType;
301                }
302            }
303            return type;
304        }
305    
306        private <D extends CallableDescriptor> void updateNumberType(
307                @NotNull JetType numberType,
308                @Nullable JetExpression expression,
309                @NotNull ResolutionContext context
310        ) {
311            if (expression == null) return;
312            BindingContextUtils.updateRecordedType(numberType, expression, context.trace, false);
313    
314            if (!(expression instanceof JetConstantExpression)) {
315                JetExpression deparenthesized = JetPsiUtil.deparenthesize(expression, false);
316                if (deparenthesized != expression) {
317                    updateNumberType(numberType, deparenthesized, context);
318                }
319                if (deparenthesized instanceof JetBlockExpression) {
320                    JetElement lastStatement = JetPsiUtil.getLastStatementInABlock((JetBlockExpression) deparenthesized);
321                    if (lastStatement instanceof JetExpression) {
322                        updateNumberType(numberType, (JetExpression) lastStatement, context);
323                    }
324                }
325                return;
326            }
327            CompileTimeConstant<?> constant =
328                    new CompileTimeConstantResolver().getCompileTimeConstant((JetConstantExpression) expression, numberType);
329    
330            if (!(constant instanceof ErrorValue)) {
331                context.trace.record(BindingContext.COMPILE_TIME_VALUE, expression, constant);
332            }
333        }
334    }