001    /*
002     * Copyright 2010-2013 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.jet.lang.resolve.calls;
018    
019    import com.google.common.collect.Lists;
020    import org.jetbrains.annotations.NotNull;
021    import org.jetbrains.annotations.Nullable;
022    import org.jetbrains.jet.lang.descriptors.CallableDescriptor;
023    import org.jetbrains.jet.lang.descriptors.annotations.AnnotationDescriptor;
024    import org.jetbrains.jet.lang.diagnostics.Errors;
025    import org.jetbrains.jet.lang.psi.*;
026    import org.jetbrains.jet.lang.resolve.*;
027    import org.jetbrains.jet.lang.resolve.calls.context.CallResolutionContext;
028    import org.jetbrains.jet.lang.resolve.calls.context.CheckValueArgumentsMode;
029    import org.jetbrains.jet.lang.resolve.calls.context.ResolutionContext;
030    import org.jetbrains.jet.lang.resolve.calls.model.MutableDataFlowInfoForArguments;
031    import org.jetbrains.jet.lang.resolve.constants.CompileTimeConstant;
032    import org.jetbrains.jet.lang.resolve.constants.CompileTimeConstantResolver;
033    import org.jetbrains.jet.lang.resolve.constants.ErrorValue;
034    import org.jetbrains.jet.lang.resolve.constants.NumberValueTypeConstructor;
035    import org.jetbrains.jet.lang.resolve.scopes.JetScope;
036    import org.jetbrains.jet.lang.types.JetType;
037    import org.jetbrains.jet.lang.types.JetTypeInfo;
038    import org.jetbrains.jet.lang.types.TypeUtils;
039    import org.jetbrains.jet.lang.types.checker.JetTypeChecker;
040    import org.jetbrains.jet.lang.types.expressions.ExpressionTypingServices;
041    import org.jetbrains.jet.lang.types.lang.KotlinBuiltIns;
042    
043    import javax.inject.Inject;
044    import java.util.Collections;
045    import java.util.List;
046    import java.util.Set;
047    
048    import static org.jetbrains.jet.lang.resolve.BindingContextUtils.getRecordedTypeInfo;
049    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.ResolveArgumentsMode;
050    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.ResolveArgumentsMode.RESOLVE_FUNCTION_ARGUMENTS;
051    import static org.jetbrains.jet.lang.resolve.calls.CallResolverUtil.ResolveArgumentsMode.SHAPE_FUNCTION_ARGUMENTS;
052    import static org.jetbrains.jet.lang.resolve.calls.context.ContextDependency.DEPENDENT;
053    import static org.jetbrains.jet.lang.types.TypeUtils.*;
054    
055    public class ArgumentTypeResolver {
056    
057        @NotNull
058        private TypeResolver typeResolver;
059        @NotNull
060        private ExpressionTypingServices expressionTypingServices;
061    
062        @Inject
063        public void setTypeResolver(@NotNull TypeResolver typeResolver) {
064            this.typeResolver = typeResolver;
065        }
066    
067        @Inject
068        public void setExpressionTypingServices(@NotNull ExpressionTypingServices expressionTypingServices) {
069            this.expressionTypingServices = expressionTypingServices;
070        }
071    
072        public static boolean isSubtypeOfForArgumentType(
073                @NotNull JetType actualType,
074                @NotNull JetType expectedType
075        ) {
076            if (actualType == PLACEHOLDER_FUNCTION_TYPE) {
077                return isFunctionOrErrorType(expectedType) || KotlinBuiltIns.getInstance().isAny(expectedType); //todo function type extends
078            }
079            return JetTypeChecker.INSTANCE.isSubtypeOf(actualType, expectedType);
080        }
081    
082        private static boolean isFunctionOrErrorType(@NotNull JetType supertype) {
083            return KotlinBuiltIns.getInstance().isFunctionOrExtensionFunctionType(supertype) || supertype.isError();
084        }
085    
086        public void checkTypesWithNoCallee(@NotNull CallResolutionContext<?> context) {
087            checkTypesWithNoCallee(context, SHAPE_FUNCTION_ARGUMENTS);
088        }
089    
090        public void checkTypesWithNoCallee(@NotNull CallResolutionContext<?> context, @NotNull ResolveArgumentsMode resolveFunctionArgumentBodies) {
091            if (context.checkArguments == CheckValueArgumentsMode.DISABLED) return;
092    
093            for (ValueArgument valueArgument : context.call.getValueArguments()) {
094                JetExpression argumentExpression = valueArgument.getArgumentExpression();
095                if (argumentExpression != null && !(argumentExpression instanceof JetFunctionLiteralExpression)) {
096                    checkArgumentTypeWithNoCallee(context, argumentExpression);
097                }
098            }
099    
100            if (resolveFunctionArgumentBodies == RESOLVE_FUNCTION_ARGUMENTS) {
101                checkTypesForFunctionArgumentsWithNoCallee(context);
102            }
103    
104            for (JetTypeProjection typeProjection : context.call.getTypeArguments()) {
105                JetTypeReference typeReference = typeProjection.getTypeReference();
106                if (typeReference == null) {
107                    context.trace.report(Errors.PROJECTION_ON_NON_CLASS_TYPE_ARGUMENT.on(typeProjection));
108                }
109                else {
110                    typeResolver.resolveType(context.scope, typeReference, context.trace, true);
111                }
112            }
113        }
114    
115        public void checkTypesForFunctionArgumentsWithNoCallee(@NotNull CallResolutionContext<?> context) {
116            if (context.checkArguments == CheckValueArgumentsMode.DISABLED) return;
117    
118            for (ValueArgument valueArgument : context.call.getValueArguments()) {
119                JetExpression argumentExpression = valueArgument.getArgumentExpression();
120                if (argumentExpression != null && (argumentExpression instanceof JetFunctionLiteralExpression)) {
121                    checkArgumentTypeWithNoCallee(context, argumentExpression);
122                }
123            }
124    
125            for (JetExpression expression : context.call.getFunctionLiteralArguments()) {
126                checkArgumentTypeWithNoCallee(context, expression);
127            }
128        }
129    
130        public void checkUnmappedArgumentTypes(CallResolutionContext<?> context, Set<ValueArgument> unmappedArguments) {
131            for (ValueArgument valueArgument : unmappedArguments) {
132                JetExpression argumentExpression = valueArgument.getArgumentExpression();
133                if (argumentExpression != null) {
134                    checkArgumentTypeWithNoCallee(context, argumentExpression);
135                }
136            }
137        }
138    
139        private void checkArgumentTypeWithNoCallee(CallResolutionContext<?> context, JetExpression argumentExpression) {
140            expressionTypingServices.getTypeInfo(argumentExpression, context.replaceExpectedType(NO_EXPECTED_TYPE));
141            updateResultArgumentTypeIfNotDenotable(context, argumentExpression);
142        }
143    
144        public static boolean isFunctionLiteralArgument(@NotNull JetExpression expression) {
145            return getFunctionLiteralArgumentIfAny(expression) != null;
146        }
147    
148        @NotNull
149        public static JetFunctionLiteralExpression getFunctionLiteralArgument(@NotNull JetExpression expression) {
150            assert isFunctionLiteralArgument(expression);
151            //noinspection ConstantConditions
152            return getFunctionLiteralArgumentIfAny(expression);
153        }
154    
155        @Nullable
156        private static JetFunctionLiteralExpression getFunctionLiteralArgumentIfAny(@NotNull JetExpression expression) {
157            JetExpression deparenthesizedExpression = JetPsiUtil.deparenthesize(expression, false);
158            if (deparenthesizedExpression instanceof JetBlockExpression) {
159                // todo
160                // This case is a temporary hack for 'if' branches.
161                // The right way to implement this logic is to interpret 'if' branches as function literals with explicitly-typed signatures
162                // (no arguments and no receiver) and therefore analyze them straight away (not in the 'complete' phase).
163                JetElement lastStatementInABlock = JetPsiUtil.getLastStatementInABlock((JetBlockExpression) deparenthesizedExpression);
164                if (lastStatementInABlock instanceof JetExpression) {
165                    deparenthesizedExpression = JetPsiUtil.deparenthesize((JetExpression) lastStatementInABlock, false);
166                }
167            }
168            if (deparenthesizedExpression instanceof JetFunctionLiteralExpression) {
169                return (JetFunctionLiteralExpression) deparenthesizedExpression;
170            }
171            return null;
172        }
173    
174        @NotNull
175        public JetTypeInfo getArgumentTypeInfo(
176                @Nullable JetExpression expression,
177                @NotNull CallResolutionContext<?> context,
178                @NotNull ResolveArgumentsMode resolveArgumentsMode
179        ) {
180            if (expression == null) {
181                return JetTypeInfo.create(null, context.dataFlowInfo);
182            }
183            if (isFunctionLiteralArgument(expression)) {
184                return getFunctionLiteralTypeInfo(expression, getFunctionLiteralArgument(expression), context, resolveArgumentsMode);
185            }
186            JetTypeInfo recordedTypeInfo = getRecordedTypeInfo(expression, context.trace.getBindingContext());
187            if (recordedTypeInfo != null) {
188                return recordedTypeInfo;
189            }
190            ResolutionContext newContext = context.replaceExpectedType(NO_EXPECTED_TYPE).replaceContextDependency(DEPENDENT);
191    
192            return expressionTypingServices.getTypeInfo(expression, newContext);
193        }
194    
195        @NotNull
196        public JetTypeInfo getFunctionLiteralTypeInfo(
197                @NotNull JetExpression expression,
198                @NotNull JetFunctionLiteralExpression functionLiteralExpression,
199                @NotNull CallResolutionContext<?> context,
200                @NotNull ResolveArgumentsMode resolveArgumentsMode
201        ) {
202            if (resolveArgumentsMode == SHAPE_FUNCTION_ARGUMENTS) {
203                JetType type = getShapeTypeOfFunctionLiteral(functionLiteralExpression, context.scope, context.trace, true);
204                return JetTypeInfo.create(type, context.dataFlowInfo);
205            }
206            return expressionTypingServices.getTypeInfo(expression, context);
207        }
208    
209        @Nullable
210        public JetType getShapeTypeOfFunctionLiteral(
211                @NotNull JetFunctionLiteralExpression expression,
212                @NotNull JetScope scope,
213                @NotNull BindingTrace trace,
214                boolean expectedTypeIsUnknown
215        ) {
216            if (expression.getFunctionLiteral().getValueParameterList() == null) {
217                return expectedTypeIsUnknown ? PLACEHOLDER_FUNCTION_TYPE : KotlinBuiltIns.getInstance().getFunctionType(
218                        Collections.<AnnotationDescriptor>emptyList(), null, Collections.<JetType>emptyList(), DONT_CARE);
219            }
220            List<JetParameter> valueParameters = expression.getValueParameters();
221            TemporaryBindingTrace temporaryTrace = TemporaryBindingTrace.create(
222                    trace, "trace to resolve function literal parameter types");
223            List<JetType> parameterTypes = Lists.newArrayList();
224            for (JetParameter parameter : valueParameters) {
225                parameterTypes.add(resolveTypeRefWithDefault(parameter.getTypeReference(), scope, temporaryTrace, DONT_CARE));
226            }
227            JetFunctionLiteral functionLiteral = expression.getFunctionLiteral();
228            JetType returnType = resolveTypeRefWithDefault(functionLiteral.getReturnTypeRef(), scope, temporaryTrace, DONT_CARE);
229            assert returnType != null;
230            JetType receiverType = resolveTypeRefWithDefault(functionLiteral.getReceiverTypeRef(), scope, temporaryTrace, null);
231            return KotlinBuiltIns.getInstance().getFunctionType(Collections.<AnnotationDescriptor>emptyList(), receiverType, parameterTypes,
232                                                                returnType);
233        }
234    
235        @Nullable
236        public JetType resolveTypeRefWithDefault(
237                @Nullable JetTypeReference returnTypeRef,
238                @NotNull JetScope scope,
239                @NotNull BindingTrace trace,
240                @Nullable JetType defaultValue
241        ) {
242            if (returnTypeRef != null) {
243                return expressionTypingServices.getTypeResolver().resolveType(scope, returnTypeRef, trace, true);
244            }
245            return defaultValue;
246        }
247    
248        public <D extends CallableDescriptor> void analyzeArgumentsAndRecordTypes(
249                @NotNull CallResolutionContext<?> context
250        ) {
251            MutableDataFlowInfoForArguments infoForArguments = context.dataFlowInfoForArguments;
252            infoForArguments.setInitialDataFlowInfo(context.dataFlowInfo);
253    
254            for (ValueArgument argument : context.call.getValueArguments()) {
255                JetExpression expression = argument.getArgumentExpression();
256                if (expression == null) continue;
257    
258                CallResolutionContext<?> newContext = context.replaceDataFlowInfo(infoForArguments.getInfo(argument));
259                JetTypeInfo typeInfoForCall = getArgumentTypeInfo(expression, newContext, SHAPE_FUNCTION_ARGUMENTS);
260                infoForArguments.updateInfo(argument, typeInfoForCall.getDataFlowInfo());
261            }
262        }
263    
264        @Nullable
265        public <D extends CallableDescriptor> JetType updateResultArgumentTypeIfNotDenotable(
266                @NotNull ResolutionContext context,
267                @NotNull JetExpression expression
268        ) {
269            JetType type = context.trace.get(BindingContext.EXPRESSION_TYPE, expression);
270            if (type != null && !type.getConstructor().isDenotable()) {
271                if (type.getConstructor() instanceof NumberValueTypeConstructor) {
272                    NumberValueTypeConstructor constructor = (NumberValueTypeConstructor) type.getConstructor();
273                    JetType primitiveType = TypeUtils.getPrimitiveNumberType(constructor, context.expectedType);
274                    updateNumberType(primitiveType, expression, context);
275                    return primitiveType;
276                }
277            }
278            return type;
279        }
280    
281        private <D extends CallableDescriptor> void updateNumberType(
282                @NotNull JetType numberType,
283                @Nullable JetExpression expression,
284                @NotNull ResolutionContext context
285        ) {
286            if (expression == null) return;
287            BindingContextUtils.updateRecordedType(numberType, expression, context.trace, false);
288    
289            if (!(expression instanceof JetConstantExpression)) {
290                JetExpression deparenthesized = JetPsiUtil.deparenthesize(expression, false);
291                if (deparenthesized != expression) {
292                    updateNumberType(numberType, deparenthesized, context);
293                }
294                if (deparenthesized instanceof JetBlockExpression) {
295                    JetElement lastStatement = JetPsiUtil.getLastStatementInABlock((JetBlockExpression) deparenthesized);
296                    if (lastStatement instanceof JetExpression) {
297                        updateNumberType(numberType, (JetExpression) lastStatement, context);
298                    }
299                }
300                return;
301            }
302            CompileTimeConstant<?> constant =
303                    new CompileTimeConstantResolver().getCompileTimeConstant((JetConstantExpression) expression, numberType);
304    
305            if (!(constant instanceof ErrorValue)) {
306                context.trace.record(BindingContext.COMPILE_TIME_VALUE, expression, constant);
307            }
308        }
309    }