001    /*
002     * Copyright 2010-2015 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.kotlin.types.expressions;
018    
019    import kotlin.Pair;
020    import org.jetbrains.annotations.NotNull;
021    import org.jetbrains.annotations.Nullable;
022    import org.jetbrains.kotlin.builtins.KotlinBuiltIns;
023    import org.jetbrains.kotlin.descriptors.FunctionDescriptor;
024    import org.jetbrains.kotlin.diagnostics.DiagnosticFactory1;
025    import org.jetbrains.kotlin.name.Name;
026    import org.jetbrains.kotlin.psi.Call;
027    import org.jetbrains.kotlin.psi.JetExpression;
028    import org.jetbrains.kotlin.resolve.calls.model.ResolvedCall;
029    import org.jetbrains.kotlin.resolve.calls.results.OverloadResolutionResults;
030    import org.jetbrains.kotlin.resolve.scopes.receivers.ExpressionReceiver;
031    import org.jetbrains.kotlin.resolve.scopes.receivers.TransientReceiver;
032    import org.jetbrains.kotlin.resolve.validation.SymbolUsageValidator;
033    import org.jetbrains.kotlin.types.JetType;
034    import org.jetbrains.kotlin.util.slicedMap.WritableSlice;
035    
036    import javax.inject.Inject;
037    import java.util.Collections;
038    
039    import static org.jetbrains.kotlin.diagnostics.Errors.*;
040    import static org.jetbrains.kotlin.resolve.BindingContext.*;
041    
042    public class ForLoopConventionsChecker {
043    
044        private KotlinBuiltIns builtIns;
045        private SymbolUsageValidator symbolUsageValidator;
046        private FakeCallResolver fakeCallResolver;
047    
048        @Inject
049        public void setBuiltIns(@NotNull KotlinBuiltIns builtIns) {
050            this.builtIns = builtIns;
051        }
052    
053        @Inject
054        public void setFakeCallResolver(@NotNull FakeCallResolver fakeCallResolver) {
055            this.fakeCallResolver = fakeCallResolver;
056        }
057    
058        @Inject
059        public void setSymbolUsageValidator(SymbolUsageValidator symbolUsageValidator) {
060            this.symbolUsageValidator = symbolUsageValidator;
061        }
062    
063        @Nullable
064        public JetType checkIterableConvention(@NotNull ExpressionReceiver loopRange, ExpressionTypingContext context) {
065            JetExpression loopRangeExpression = loopRange.getExpression();
066    
067            // Make a fake call loopRange.iterator(), and try to resolve it
068            Name iterator = Name.identifier("iterator");
069            Pair<Call, OverloadResolutionResults<FunctionDescriptor>> calls =
070                    fakeCallResolver.makeAndResolveFakeCall(loopRange, context, Collections.<JetExpression>emptyList(), iterator,
071                                                                 loopRange.getExpression());
072            OverloadResolutionResults<FunctionDescriptor> iteratorResolutionResults = calls.getSecond();
073    
074            if (iteratorResolutionResults.isSuccess()) {
075                ResolvedCall<FunctionDescriptor> iteratorResolvedCall = iteratorResolutionResults.getResultingCall();
076                context.trace.record(LOOP_RANGE_ITERATOR_RESOLVED_CALL, loopRangeExpression, iteratorResolvedCall);
077                FunctionDescriptor iteratorFunction = iteratorResolvedCall.getResultingDescriptor();
078    
079                symbolUsageValidator.validateCall(iteratorFunction, context.trace, loopRangeExpression);
080    
081                JetType iteratorType = iteratorFunction.getReturnType();
082                JetType hasNextType = checkConventionForIterator(context, loopRangeExpression, iteratorType, "hasNext",
083                                                                 HAS_NEXT_FUNCTION_AMBIGUITY, HAS_NEXT_MISSING, HAS_NEXT_FUNCTION_NONE_APPLICABLE,
084                                                                 LOOP_RANGE_HAS_NEXT_RESOLVED_CALL);
085                if (hasNextType != null && !builtIns.isBooleanOrSubtype(hasNextType)) {
086                    context.trace.report(HAS_NEXT_FUNCTION_TYPE_MISMATCH.on(loopRangeExpression, hasNextType));
087                }
088                return checkConventionForIterator(context, loopRangeExpression, iteratorType, "next",
089                                                  NEXT_AMBIGUITY, NEXT_MISSING, NEXT_NONE_APPLICABLE,
090                                                  LOOP_RANGE_NEXT_RESOLVED_CALL);
091            }
092            else {
093                if (iteratorResolutionResults.isAmbiguity()) {
094                    context.trace.report(ITERATOR_AMBIGUITY.on(loopRangeExpression, iteratorResolutionResults.getResultingCalls()));
095                }
096                else {
097                    context.trace.report(ITERATOR_MISSING.on(loopRangeExpression));
098                }
099            }
100            return null;
101        }
102    
103        @Nullable
104        private JetType checkConventionForIterator(
105                @NotNull ExpressionTypingContext context,
106                @NotNull JetExpression loopRangeExpression,
107                @NotNull JetType iteratorType,
108                @NotNull String name,
109                @NotNull DiagnosticFactory1<JetExpression, JetType> ambiguity,
110                @NotNull DiagnosticFactory1<JetExpression, JetType> missing,
111                @NotNull DiagnosticFactory1<JetExpression, JetType> noneApplicable,
112                @NotNull WritableSlice<JetExpression, ResolvedCall<FunctionDescriptor>> resolvedCallKey
113        ) {
114            OverloadResolutionResults<FunctionDescriptor> nextResolutionResults = fakeCallResolver.resolveFakeCall(
115                    context, new TransientReceiver(iteratorType), Name.identifier(name), loopRangeExpression);
116            if (nextResolutionResults.isAmbiguity()) {
117                context.trace.report(ambiguity.on(loopRangeExpression, iteratorType));
118            }
119            else if (nextResolutionResults.isNothing()) {
120                context.trace.report(missing.on(loopRangeExpression, iteratorType));
121            }
122            else if (!nextResolutionResults.isSuccess()) {
123                context.trace.report(noneApplicable.on(loopRangeExpression, iteratorType));
124            }
125            else {
126                assert nextResolutionResults.isSuccess();
127                ResolvedCall<FunctionDescriptor> resolvedCall = nextResolutionResults.getResultingCall();
128                context.trace.record(resolvedCallKey, loopRangeExpression, resolvedCall);
129                FunctionDescriptor functionDescriptor = resolvedCall.getResultingDescriptor();
130                symbolUsageValidator.validateCall(functionDescriptor, context.trace, loopRangeExpression);
131                return functionDescriptor.getReturnType();
132            }
133            return null;
134        }
135    }