001    /*
002     * Copyright 2010-2015 JetBrains s.r.o.
003     *
004     * Licensed under the Apache License, Version 2.0 (the "License");
005     * you may not use this file except in compliance with the License.
006     * You may obtain a copy of the License at
007     *
008     * http://www.apache.org/licenses/LICENSE-2.0
009     *
010     * Unless required by applicable law or agreed to in writing, software
011     * distributed under the License is distributed on an "AS IS" BASIS,
012     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013     * See the License for the specific language governing permissions and
014     * limitations under the License.
015     */
016    
017    package org.jetbrains.kotlin.storage;
018    
019    import kotlin.Unit;
020    import kotlin.jvm.functions.Function0;
021    import kotlin.jvm.functions.Function1;
022    import kotlin.text.StringsKt;
023    import org.jetbrains.annotations.NotNull;
024    import org.jetbrains.annotations.Nullable;
025    import org.jetbrains.kotlin.utils.ExceptionUtilsKt;
026    import org.jetbrains.kotlin.utils.WrappedValues;
027    
028    import java.util.Arrays;
029    import java.util.List;
030    import java.util.concurrent.ConcurrentHashMap;
031    import java.util.concurrent.ConcurrentMap;
032    import java.util.concurrent.locks.Lock;
033    import java.util.concurrent.locks.ReentrantLock;
034    
035    public class LockBasedStorageManager implements StorageManager {
036        private static final String PACKAGE_NAME = StringsKt.substringBeforeLast(LockBasedStorageManager.class.getCanonicalName(), ".", "");
037    
038        public interface ExceptionHandlingStrategy {
039            ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() {
040                @NotNull
041                @Override
042                public RuntimeException handleException(@NotNull Throwable throwable) {
043                    throw ExceptionUtilsKt.rethrow(throwable);
044                }
045            };
046    
047            /*
048             * The signature of this method is a trick: it is used as
049             *
050             *     throw strategy.handleException(...)
051             *
052             * most implementations of this method throw exceptions themselves, so it does not matter what they return
053             */
054            @NotNull
055            RuntimeException handleException(@NotNull Throwable throwable);
056        }
057    
058        public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, NoLock.INSTANCE) {
059            @NotNull
060            @Override
061            protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
062                return RecursionDetectedResult.fallThrough();
063            }
064        };
065    
066        @NotNull
067        public static LockBasedStorageManager createWithExceptionHandling(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
068            return new LockBasedStorageManager(exceptionHandlingStrategy);
069        }
070    
071        protected final Lock lock;
072        private final ExceptionHandlingStrategy exceptionHandlingStrategy;
073        private final String debugText;
074    
075        private LockBasedStorageManager(
076                @NotNull String debugText,
077                @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy,
078                @NotNull Lock lock
079        ) {
080            this.lock = lock;
081            this.exceptionHandlingStrategy = exceptionHandlingStrategy;
082            this.debugText = debugText;
083        }
084    
085        public LockBasedStorageManager() {
086            this(defaultDebugName(), ExceptionHandlingStrategy.THROW, new ReentrantLock());
087        }
088    
089        protected LockBasedStorageManager(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
090            this(defaultDebugName(), exceptionHandlingStrategy, new ReentrantLock());
091        }
092    
093        private static String defaultDebugName() {
094            return "<unknown creating class>";
095        }
096    
097        @Override
098        public String toString() {
099            return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")";
100        }
101    
102        @NotNull
103        @Override
104        public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<? super K, ? extends V> compute) {
105            return createMemoizedFunction(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
106        }
107    
108        @NotNull
109        @Override
110        public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
111                @NotNull Function1<? super K, ? extends V> compute,
112                @NotNull ConcurrentMap<K, Object> map
113        ) {
114            return new MapBasedMemoizedFunctionToNotNull<K, V>(this, map, compute);
115        }
116    
117        @NotNull
118        @Override
119        public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) {
120            return createMemoizedFunctionWithNullableValues(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
121        }
122    
123        @Override
124        @NotNull
125        public  <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(
126                @NotNull Function1<? super K, ? extends V> compute,
127                @NotNull ConcurrentMap<K, Object> map
128        ) {
129            return new MapBasedMemoizedFunction<K, V>(this, map, compute);
130        }
131    
132        @NotNull
133        @Override
134        public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<? extends T> computable) {
135            return new LockBasedNotNullLazyValue<T>(this, computable);
136        }
137    
138        @NotNull
139        @Override
140        public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue(
141                @NotNull Function0<? extends T> computable, @NotNull final T onRecursiveCall
142        ) {
143            return new LockBasedNotNullLazyValue<T>(this, computable) {
144                @NotNull
145                @Override
146                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
147                    return RecursionDetectedResult.value(onRecursiveCall);
148                }
149            };
150        }
151    
152        @NotNull
153        @Override
154        public <T> NotNullLazyValue<T> createLazyValueWithPostCompute(
155                @NotNull Function0<? extends T> computable,
156                final Function1<? super Boolean, ? extends T> onRecursiveCall,
157                @NotNull final Function1<? super T, Unit> postCompute
158        ) {
159            return new LockBasedNotNullLazyValue<T>(this, computable) {
160                @NotNull
161                @Override
162                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
163                    if (onRecursiveCall == null) {
164                        return super.recursionDetected(firstTime);
165                    }
166                    return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime));
167                }
168    
169                @Override
170                protected void postCompute(@NotNull T value) {
171                    postCompute.invoke(value);
172                }
173            };
174        }
175    
176        @NotNull
177        @Override
178        public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<? extends T> computable) {
179            return new LockBasedLazyValue<T>(this, computable);
180        }
181    
182        @NotNull
183        @Override
184        public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<? extends T> computable, final T onRecursiveCall) {
185            return new LockBasedLazyValue<T>(this, computable) {
186                @NotNull
187                @Override
188                protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
189                    return RecursionDetectedResult.value(onRecursiveCall);
190                }
191            };
192        }
193    
194        @NotNull
195        @Override
196        public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute(
197                @NotNull Function0<? extends T> computable, @NotNull final Function1<? super T, Unit> postCompute
198        ) {
199            return new LockBasedLazyValue<T>(this, computable) {
200                @Override
201                protected void postCompute(@Nullable T value) {
202                    postCompute.invoke(value);
203                }
204            };
205        }
206    
207        @Override
208        public <T> T compute(@NotNull Function0<? extends T> computable) {
209            lock.lock();
210            try {
211                return computable.invoke();
212            }
213            catch (Throwable throwable) {
214                throw exceptionHandlingStrategy.handleException(throwable);
215            }
216            finally {
217                lock.unlock();
218            }
219        }
220    
221        @NotNull
222        private static <K> ConcurrentMap<K, Object> createConcurrentHashMap() {
223            // memory optimization: fewer segments and entries stored
224            return new ConcurrentHashMap<K, Object>(3, 1, 2);
225        }
226    
227        @NotNull
228        protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
229            throw sanitizeStackTrace(new IllegalStateException("Recursive call in a lazy value under " + this));
230        }
231    
232        private static class RecursionDetectedResult<T> {
233    
234            @NotNull
235            public static <T> RecursionDetectedResult<T> value(T value) {
236                return new RecursionDetectedResult<T>(value, false);
237            }
238    
239            @NotNull
240            public static <T> RecursionDetectedResult<T> fallThrough() {
241                return new RecursionDetectedResult<T>(null, true);
242            }
243    
244            private final T value;
245            private final boolean fallThrough;
246    
247            private RecursionDetectedResult(T value, boolean fallThrough) {
248                this.value = value;
249                this.fallThrough = fallThrough;
250            }
251    
252            public T getValue() {
253                assert !fallThrough : "A value requested from FALL_THROUGH in " + this;
254                return value;
255            }
256    
257            public boolean isFallThrough() {
258                return fallThrough;
259            }
260    
261            @Override
262            public String toString() {
263                return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value);
264            }
265        }
266    
267        private enum NotValue {
268            NOT_COMPUTED,
269            COMPUTING,
270            RECURSION_WAS_DETECTED
271        }
272    
273        // Being static is memory optimization to prevent capturing outer-class reference at each level of inheritance hierarchy
274        private static class LockBasedLazyValue<T> implements NullableLazyValue<T> {
275            private final LockBasedStorageManager storageManager;
276            private final Function0<? extends T> computable;
277    
278            @Nullable
279            private volatile Object value = NotValue.NOT_COMPUTED;
280    
281            public LockBasedLazyValue(@NotNull LockBasedStorageManager storageManager, @NotNull Function0<? extends T> computable) {
282                this.storageManager = storageManager;
283                this.computable = computable;
284            }
285    
286            @Override
287            public boolean isComputed() {
288                return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
289            }
290    
291            @Override
292            public boolean isComputing() {
293                return value == NotValue.COMPUTING;
294            }
295    
296            @Override
297            public T invoke() {
298                Object _value = value;
299                if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
300    
301                storageManager.lock.lock();
302                try {
303                    _value = value;
304                    if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
305    
306                    if (_value == NotValue.COMPUTING) {
307                        value = NotValue.RECURSION_WAS_DETECTED;
308                        RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true);
309                        if (!result.isFallThrough()) {
310                            return result.getValue();
311                        }
312                    }
313    
314                    if (_value == NotValue.RECURSION_WAS_DETECTED) {
315                        RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false);
316                        if (!result.isFallThrough()) {
317                            return result.getValue();
318                        }
319                    }
320    
321                    value = NotValue.COMPUTING;
322                    try {
323                        T typedValue = computable.invoke();
324                        value = typedValue;
325                        postCompute(typedValue);
326                        return typedValue;
327                    }
328                    catch (Throwable throwable) {
329                        if (value == NotValue.COMPUTING) {
330                            // Store only if it's a genuine result, not something thrown through recursionDetected()
331                            value = WrappedValues.escapeThrowable(throwable);
332                        }
333                        throw storageManager.exceptionHandlingStrategy.handleException(throwable);
334                    }
335                }
336                finally {
337                    storageManager.lock.unlock();
338                }
339            }
340    
341            /**
342             * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
343             * @return a value to be returned on a recursive call or subsequent calls
344             */
345            @NotNull
346            protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
347                return storageManager.recursionDetectedDefault();
348            }
349    
350            protected void postCompute(T value) {
351                // Doing something in post-compute helps prevent infinite recursion
352            }
353        }
354    
355        private static class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> {
356    
357            public LockBasedNotNullLazyValue(@NotNull LockBasedStorageManager storageManager, @NotNull Function0<? extends T> computable) {
358                super(storageManager, computable);
359            }
360    
361            @Override
362            @NotNull
363            public T invoke() {
364                T result = super.invoke();
365                assert result != null : "compute() returned null";
366                return result;
367            }
368        }
369    
370        private static class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> {
371            private final LockBasedStorageManager storageManager;
372            private final ConcurrentMap<K, Object> cache;
373            private final Function1<? super K, ? extends V> compute;
374    
375            public MapBasedMemoizedFunction(
376                    @NotNull LockBasedStorageManager storageManager,
377                    @NotNull ConcurrentMap<K, Object> map,
378                    @NotNull Function1<? super K, ? extends V> compute
379            ) {
380                this.storageManager = storageManager;
381                this.cache = map;
382                this.compute = compute;
383            }
384    
385            @Override
386            @Nullable
387            public V invoke(K input) {
388                Object value = cache.get(input);
389                if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value);
390    
391                storageManager.lock.lock();
392                try {
393                    value = cache.get(input);
394                    if (value == NotValue.COMPUTING) {
395                        throw recursionDetected(input);
396                    }
397                    if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
398    
399                    AssertionError error = null;
400                    try {
401                        cache.put(input, NotValue.COMPUTING);
402                        V typedValue = compute.invoke(input);
403                        Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));
404    
405                        // This code effectively asserts that oldValue is null
406                        // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored
407                        // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that
408                        // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed
409                        if (oldValue != NotValue.COMPUTING) {
410                            error = raceCondition(input, oldValue);
411                            throw error;
412                        }
413    
414                        return typedValue;
415                    }
416                    catch (Throwable throwable) {
417                        if (throwable == error) throw storageManager.exceptionHandlingStrategy.handleException(throwable);
418    
419                        Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
420                        if (oldValue != NotValue.COMPUTING) {
421                            throw raceCondition(input, oldValue);
422                        }
423    
424                        throw storageManager.exceptionHandlingStrategy.handleException(throwable);
425                    }
426                }
427                finally {
428                    storageManager.lock.unlock();
429                }
430            }
431    
432            @NotNull
433            private AssertionError recursionDetected(K input) {
434                return sanitizeStackTrace(
435                        new AssertionError("Recursion detected on input: " + input + " under " + storageManager)
436                );
437            }
438    
439            @NotNull
440            private AssertionError raceCondition(K input, Object oldValue) {
441                return sanitizeStackTrace(
442                        new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue +
443                                           " under " + storageManager)
444                );
445            }
446    
447            @Override
448            public boolean isComputed(K key) {
449                Object value = cache.get(key);
450                return value != null && value != NotValue.COMPUTING;
451            }
452    
453            protected LockBasedStorageManager getStorageManager() {
454                return storageManager;
455            }
456        }
457    
458        private static class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> {
459    
460            public MapBasedMemoizedFunctionToNotNull(
461                    @NotNull LockBasedStorageManager storageManager, @NotNull ConcurrentMap<K, Object> map,
462                    @NotNull Function1<? super K, ? extends V> compute
463            ) {
464                super(storageManager, map, compute);
465            }
466    
467            @NotNull
468            @Override
469            public V invoke(K input) {
470                V result = super.invoke(input);
471                assert result != null : "compute() returned null under " + getStorageManager();
472                return result;
473            }
474        }
475    
476        @NotNull
477        public static LockBasedStorageManager createDelegatingWithSameLock(
478                @NotNull LockBasedStorageManager base,
479                @NotNull ExceptionHandlingStrategy newStrategy
480        ) {
481            return new LockBasedStorageManager(defaultDebugName(), newStrategy, base.lock);
482        }
483    
484        @NotNull
485        private static <T extends Throwable> T sanitizeStackTrace(@NotNull T throwable) {
486            StackTraceElement[] stackTrace = throwable.getStackTrace();
487            int size = stackTrace.length;
488    
489            int firstNonStorage = -1;
490            for (int i = 0; i < size; i++) {
491                // Skip everything (memoized functions and lazy values) from package org.jetbrains.kotlin.storage
492                if (!stackTrace[i].getClassName().startsWith(PACKAGE_NAME)) {
493                    firstNonStorage = i;
494                    break;
495                }
496            }
497            assert firstNonStorage >= 0 : "This method should only be called on exceptions created in LockBasedStorageManager";
498    
499            List<StackTraceElement> list = Arrays.asList(stackTrace).subList(firstNonStorage, size);
500            throwable.setStackTrace(list.toArray(new StackTraceElement[list.size()]));
501            return throwable;
502        }
503    
504        @NotNull
505        @Override
506        public <K, V> CacheWithNullableValues<K, V> createCacheWithNullableValues() {
507            return new CacheWithNullableValuesBasedOnMemoizedFunction<K, V>(
508                    this, LockBasedStorageManager.<KeyWithComputation<K,V>>createConcurrentHashMap());
509        }
510    
511        private static class CacheWithNullableValuesBasedOnMemoizedFunction<K, V> extends MapBasedMemoizedFunction<KeyWithComputation<K, V>, V> implements CacheWithNullableValues<K, V> {
512    
513            private CacheWithNullableValuesBasedOnMemoizedFunction(
514                    @NotNull LockBasedStorageManager storageManager,
515                    @NotNull ConcurrentMap<KeyWithComputation<K, V>, Object> map
516            ) {
517                super(storageManager, map, new Function1<KeyWithComputation<K, V>, V>() {
518                    @Override
519                    public V invoke(KeyWithComputation<K, V> computation) {
520                        return computation.computation.invoke();
521                    }
522                });
523            }
524    
525            @Nullable
526            @Override
527            public V computeIfAbsent(K key, @NotNull Function0<? extends V> computation) {
528                return invoke(new KeyWithComputation<K, V>(key, computation));
529            }
530        }
531    
532        @NotNull
533        @Override
534        public <K, V> CacheWithNotNullValues<K, V> createCacheWithNotNullValues() {
535            return new CacheWithNotNullValuesBasedOnMemoizedFunction<K, V>(this, LockBasedStorageManager.<KeyWithComputation<K,V>>createConcurrentHashMap());
536        }
537    
538        private static class CacheWithNotNullValuesBasedOnMemoizedFunction<K, V> extends CacheWithNullableValuesBasedOnMemoizedFunction<K, V> implements CacheWithNotNullValues<K, V> {
539    
540            private CacheWithNotNullValuesBasedOnMemoizedFunction(
541                    @NotNull LockBasedStorageManager storageManager,
542                    @NotNull ConcurrentMap<KeyWithComputation<K, V>, Object> map
543            ) {
544                super(storageManager, map);
545            }
546    
547            @NotNull
548            @Override
549            public V computeIfAbsent(K key, @NotNull Function0<? extends V> computation) {
550                V result = super.computeIfAbsent(key, computation);
551                assert result != null : "computeIfAbsent() returned null under " + getStorageManager();
552                return result;
553            }
554        }
555    
556        // equals and hashCode use only key
557        private static class KeyWithComputation<K, V> {
558            private final K key;
559            private final Function0<? extends V> computation;
560    
561            public KeyWithComputation(K key, Function0<? extends V> computation) {
562                this.key = key;
563                this.computation = computation;
564            }
565    
566            @Override
567            public boolean equals(Object o) {
568                if (this == o) return true;
569                if (o == null || getClass() != o.getClass()) return false;
570    
571                KeyWithComputation<?, ?> that = (KeyWithComputation<?, ?>) o;
572    
573                if (!key.equals(that.key)) return false;
574    
575                return true;
576            }
577    
578            @Override
579            public int hashCode() {
580                return key.hashCode();
581            }
582        }
583    }