001 /*
002 * Copyright 2010-2015 JetBrains s.r.o.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017 package org.jetbrains.kotlin.storage;
018
019 import kotlin.Unit;
020 import kotlin.jvm.functions.Function0;
021 import kotlin.jvm.functions.Function1;
022 import kotlin.text.StringsKt;
023 import org.jetbrains.annotations.NotNull;
024 import org.jetbrains.annotations.Nullable;
025 import org.jetbrains.kotlin.utils.ExceptionUtilsKt;
026 import org.jetbrains.kotlin.utils.WrappedValues;
027
028 import java.util.Arrays;
029 import java.util.List;
030 import java.util.concurrent.ConcurrentHashMap;
031 import java.util.concurrent.ConcurrentMap;
032 import java.util.concurrent.locks.Lock;
033 import java.util.concurrent.locks.ReentrantLock;
034
035 public class LockBasedStorageManager implements StorageManager {
036 private static final String PACKAGE_NAME = StringsKt.substringBeforeLast(LockBasedStorageManager.class.getCanonicalName(), ".", "");
037
038 public interface ExceptionHandlingStrategy {
039 ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() {
040 @NotNull
041 @Override
042 public RuntimeException handleException(@NotNull Throwable throwable) {
043 throw ExceptionUtilsKt.rethrow(throwable);
044 }
045 };
046
047 /*
048 * The signature of this method is a trick: it is used as
049 *
050 * throw strategy.handleException(...)
051 *
052 * most implementations of this method throw exceptions themselves, so it does not matter what they return
053 */
054 @NotNull
055 RuntimeException handleException(@NotNull Throwable throwable);
056 }
057
058 public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, NoLock.INSTANCE) {
059 @NotNull
060 @Override
061 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
062 return RecursionDetectedResult.fallThrough();
063 }
064 };
065
066 @NotNull
067 public static LockBasedStorageManager createWithExceptionHandling(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
068 return new LockBasedStorageManager(exceptionHandlingStrategy);
069 }
070
071 protected final Lock lock;
072 private final ExceptionHandlingStrategy exceptionHandlingStrategy;
073 private final String debugText;
074
075 private LockBasedStorageManager(
076 @NotNull String debugText,
077 @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy,
078 @NotNull Lock lock
079 ) {
080 this.lock = lock;
081 this.exceptionHandlingStrategy = exceptionHandlingStrategy;
082 this.debugText = debugText;
083 }
084
085 public LockBasedStorageManager() {
086 this(defaultDebugName(), ExceptionHandlingStrategy.THROW, new ReentrantLock());
087 }
088
089 protected LockBasedStorageManager(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
090 this(defaultDebugName(), exceptionHandlingStrategy, new ReentrantLock());
091 }
092
093 private static String defaultDebugName() {
094 return "<unknown creating class>";
095 }
096
097 @Override
098 public String toString() {
099 return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")";
100 }
101
102 @NotNull
103 @Override
104 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<? super K, ? extends V> compute) {
105 return createMemoizedFunction(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
106 }
107
108 @NotNull
109 @Override
110 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
111 @NotNull Function1<? super K, ? extends V> compute,
112 @NotNull ConcurrentMap<K, Object> map
113 ) {
114 return new MapBasedMemoizedFunctionToNotNull<K, V>(this, map, compute);
115 }
116
117 @NotNull
118 @Override
119 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) {
120 return createMemoizedFunctionWithNullableValues(compute, LockBasedStorageManager.<K>createConcurrentHashMap());
121 }
122
123 @Override
124 @NotNull
125 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(
126 @NotNull Function1<? super K, ? extends V> compute,
127 @NotNull ConcurrentMap<K, Object> map
128 ) {
129 return new MapBasedMemoizedFunction<K, V>(this, map, compute);
130 }
131
132 @NotNull
133 @Override
134 public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<? extends T> computable) {
135 return new LockBasedNotNullLazyValue<T>(this, computable);
136 }
137
138 @NotNull
139 @Override
140 public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue(
141 @NotNull Function0<? extends T> computable, @NotNull final T onRecursiveCall
142 ) {
143 return new LockBasedNotNullLazyValue<T>(this, computable) {
144 @NotNull
145 @Override
146 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
147 return RecursionDetectedResult.value(onRecursiveCall);
148 }
149 };
150 }
151
152 @NotNull
153 @Override
154 public <T> NotNullLazyValue<T> createLazyValueWithPostCompute(
155 @NotNull Function0<? extends T> computable,
156 final Function1<? super Boolean, ? extends T> onRecursiveCall,
157 @NotNull final Function1<? super T, Unit> postCompute
158 ) {
159 return new LockBasedNotNullLazyValue<T>(this, computable) {
160 @NotNull
161 @Override
162 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
163 if (onRecursiveCall == null) {
164 return super.recursionDetected(firstTime);
165 }
166 return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime));
167 }
168
169 @Override
170 protected void postCompute(@NotNull T value) {
171 postCompute.invoke(value);
172 }
173 };
174 }
175
176 @NotNull
177 @Override
178 public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<? extends T> computable) {
179 return new LockBasedLazyValue<T>(this, computable);
180 }
181
182 @NotNull
183 @Override
184 public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<? extends T> computable, final T onRecursiveCall) {
185 return new LockBasedLazyValue<T>(this, computable) {
186 @NotNull
187 @Override
188 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
189 return RecursionDetectedResult.value(onRecursiveCall);
190 }
191 };
192 }
193
194 @NotNull
195 @Override
196 public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute(
197 @NotNull Function0<? extends T> computable, @NotNull final Function1<? super T, Unit> postCompute
198 ) {
199 return new LockBasedLazyValue<T>(this, computable) {
200 @Override
201 protected void postCompute(@Nullable T value) {
202 postCompute.invoke(value);
203 }
204 };
205 }
206
207 @Override
208 public <T> T compute(@NotNull Function0<? extends T> computable) {
209 lock.lock();
210 try {
211 return computable.invoke();
212 }
213 catch (Throwable throwable) {
214 throw exceptionHandlingStrategy.handleException(throwable);
215 }
216 finally {
217 lock.unlock();
218 }
219 }
220
221 @NotNull
222 private static <K> ConcurrentMap<K, Object> createConcurrentHashMap() {
223 // memory optimization: fewer segments and entries stored
224 return new ConcurrentHashMap<K, Object>(3, 1, 2);
225 }
226
227 @NotNull
228 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
229 throw sanitizeStackTrace(new IllegalStateException("Recursive call in a lazy value under " + this));
230 }
231
232 private static class RecursionDetectedResult<T> {
233
234 @NotNull
235 public static <T> RecursionDetectedResult<T> value(T value) {
236 return new RecursionDetectedResult<T>(value, false);
237 }
238
239 @NotNull
240 public static <T> RecursionDetectedResult<T> fallThrough() {
241 return new RecursionDetectedResult<T>(null, true);
242 }
243
244 private final T value;
245 private final boolean fallThrough;
246
247 private RecursionDetectedResult(T value, boolean fallThrough) {
248 this.value = value;
249 this.fallThrough = fallThrough;
250 }
251
252 public T getValue() {
253 assert !fallThrough : "A value requested from FALL_THROUGH in " + this;
254 return value;
255 }
256
257 public boolean isFallThrough() {
258 return fallThrough;
259 }
260
261 @Override
262 public String toString() {
263 return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value);
264 }
265 }
266
267 private enum NotValue {
268 NOT_COMPUTED,
269 COMPUTING,
270 RECURSION_WAS_DETECTED
271 }
272
273 // Being static is memory optimization to prevent capturing outer-class reference at each level of inheritance hierarchy
274 private static class LockBasedLazyValue<T> implements NullableLazyValue<T> {
275 private final LockBasedStorageManager storageManager;
276 private final Function0<? extends T> computable;
277
278 @Nullable
279 private volatile Object value = NotValue.NOT_COMPUTED;
280
281 public LockBasedLazyValue(@NotNull LockBasedStorageManager storageManager, @NotNull Function0<? extends T> computable) {
282 this.storageManager = storageManager;
283 this.computable = computable;
284 }
285
286 @Override
287 public boolean isComputed() {
288 return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
289 }
290
291 @Override
292 public boolean isComputing() {
293 return value == NotValue.COMPUTING;
294 }
295
296 @Override
297 public T invoke() {
298 Object _value = value;
299 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
300
301 storageManager.lock.lock();
302 try {
303 _value = value;
304 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
305
306 if (_value == NotValue.COMPUTING) {
307 value = NotValue.RECURSION_WAS_DETECTED;
308 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true);
309 if (!result.isFallThrough()) {
310 return result.getValue();
311 }
312 }
313
314 if (_value == NotValue.RECURSION_WAS_DETECTED) {
315 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false);
316 if (!result.isFallThrough()) {
317 return result.getValue();
318 }
319 }
320
321 value = NotValue.COMPUTING;
322 try {
323 T typedValue = computable.invoke();
324 value = typedValue;
325 postCompute(typedValue);
326 return typedValue;
327 }
328 catch (Throwable throwable) {
329 if (value == NotValue.COMPUTING) {
330 // Store only if it's a genuine result, not something thrown through recursionDetected()
331 value = WrappedValues.escapeThrowable(throwable);
332 }
333 throw storageManager.exceptionHandlingStrategy.handleException(throwable);
334 }
335 }
336 finally {
337 storageManager.lock.unlock();
338 }
339 }
340
341 /**
342 * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
343 * @return a value to be returned on a recursive call or subsequent calls
344 */
345 @NotNull
346 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
347 return storageManager.recursionDetectedDefault();
348 }
349
350 protected void postCompute(T value) {
351 // Doing something in post-compute helps prevent infinite recursion
352 }
353 }
354
355 private static class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> {
356
357 public LockBasedNotNullLazyValue(@NotNull LockBasedStorageManager storageManager, @NotNull Function0<? extends T> computable) {
358 super(storageManager, computable);
359 }
360
361 @Override
362 @NotNull
363 public T invoke() {
364 T result = super.invoke();
365 assert result != null : "compute() returned null";
366 return result;
367 }
368 }
369
370 private static class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> {
371 private final LockBasedStorageManager storageManager;
372 private final ConcurrentMap<K, Object> cache;
373 private final Function1<? super K, ? extends V> compute;
374
375 public MapBasedMemoizedFunction(
376 @NotNull LockBasedStorageManager storageManager,
377 @NotNull ConcurrentMap<K, Object> map,
378 @NotNull Function1<? super K, ? extends V> compute
379 ) {
380 this.storageManager = storageManager;
381 this.cache = map;
382 this.compute = compute;
383 }
384
385 @Override
386 @Nullable
387 public V invoke(K input) {
388 Object value = cache.get(input);
389 if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value);
390
391 storageManager.lock.lock();
392 try {
393 value = cache.get(input);
394 if (value == NotValue.COMPUTING) {
395 throw recursionDetected(input);
396 }
397 if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
398
399 AssertionError error = null;
400 try {
401 cache.put(input, NotValue.COMPUTING);
402 V typedValue = compute.invoke(input);
403 Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));
404
405 // This code effectively asserts that oldValue is null
406 // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored
407 // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that
408 // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed
409 if (oldValue != NotValue.COMPUTING) {
410 error = raceCondition(input, oldValue);
411 throw error;
412 }
413
414 return typedValue;
415 }
416 catch (Throwable throwable) {
417 if (throwable == error) throw storageManager.exceptionHandlingStrategy.handleException(throwable);
418
419 Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
420 if (oldValue != NotValue.COMPUTING) {
421 throw raceCondition(input, oldValue);
422 }
423
424 throw storageManager.exceptionHandlingStrategy.handleException(throwable);
425 }
426 }
427 finally {
428 storageManager.lock.unlock();
429 }
430 }
431
432 @NotNull
433 private AssertionError recursionDetected(K input) {
434 return sanitizeStackTrace(
435 new AssertionError("Recursion detected on input: " + input + " under " + storageManager)
436 );
437 }
438
439 @NotNull
440 private AssertionError raceCondition(K input, Object oldValue) {
441 return sanitizeStackTrace(
442 new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue +
443 " under " + storageManager)
444 );
445 }
446
447 @Override
448 public boolean isComputed(K key) {
449 Object value = cache.get(key);
450 return value != null && value != NotValue.COMPUTING;
451 }
452
453 protected LockBasedStorageManager getStorageManager() {
454 return storageManager;
455 }
456 }
457
458 private static class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> {
459
460 public MapBasedMemoizedFunctionToNotNull(
461 @NotNull LockBasedStorageManager storageManager, @NotNull ConcurrentMap<K, Object> map,
462 @NotNull Function1<? super K, ? extends V> compute
463 ) {
464 super(storageManager, map, compute);
465 }
466
467 @NotNull
468 @Override
469 public V invoke(K input) {
470 V result = super.invoke(input);
471 assert result != null : "compute() returned null under " + getStorageManager();
472 return result;
473 }
474 }
475
476 @NotNull
477 public static LockBasedStorageManager createDelegatingWithSameLock(
478 @NotNull LockBasedStorageManager base,
479 @NotNull ExceptionHandlingStrategy newStrategy
480 ) {
481 return new LockBasedStorageManager(defaultDebugName(), newStrategy, base.lock);
482 }
483
484 @NotNull
485 private static <T extends Throwable> T sanitizeStackTrace(@NotNull T throwable) {
486 StackTraceElement[] stackTrace = throwable.getStackTrace();
487 int size = stackTrace.length;
488
489 int firstNonStorage = -1;
490 for (int i = 0; i < size; i++) {
491 // Skip everything (memoized functions and lazy values) from package org.jetbrains.kotlin.storage
492 if (!stackTrace[i].getClassName().startsWith(PACKAGE_NAME)) {
493 firstNonStorage = i;
494 break;
495 }
496 }
497 assert firstNonStorage >= 0 : "This method should only be called on exceptions created in LockBasedStorageManager";
498
499 List<StackTraceElement> list = Arrays.asList(stackTrace).subList(firstNonStorage, size);
500 throwable.setStackTrace(list.toArray(new StackTraceElement[list.size()]));
501 return throwable;
502 }
503
504 @NotNull
505 @Override
506 public <K, V> CacheWithNullableValues<K, V> createCacheWithNullableValues() {
507 return new CacheWithNullableValuesBasedOnMemoizedFunction<K, V>(
508 this, LockBasedStorageManager.<KeyWithComputation<K,V>>createConcurrentHashMap());
509 }
510
511 private static class CacheWithNullableValuesBasedOnMemoizedFunction<K, V> extends MapBasedMemoizedFunction<KeyWithComputation<K, V>, V> implements CacheWithNullableValues<K, V> {
512
513 private CacheWithNullableValuesBasedOnMemoizedFunction(
514 @NotNull LockBasedStorageManager storageManager,
515 @NotNull ConcurrentMap<KeyWithComputation<K, V>, Object> map
516 ) {
517 super(storageManager, map, new Function1<KeyWithComputation<K, V>, V>() {
518 @Override
519 public V invoke(KeyWithComputation<K, V> computation) {
520 return computation.computation.invoke();
521 }
522 });
523 }
524
525 @Nullable
526 @Override
527 public V computeIfAbsent(K key, @NotNull Function0<? extends V> computation) {
528 return invoke(new KeyWithComputation<K, V>(key, computation));
529 }
530 }
531
532 @NotNull
533 @Override
534 public <K, V> CacheWithNotNullValues<K, V> createCacheWithNotNullValues() {
535 return new CacheWithNotNullValuesBasedOnMemoizedFunction<K, V>(this, LockBasedStorageManager.<KeyWithComputation<K,V>>createConcurrentHashMap());
536 }
537
538 private static class CacheWithNotNullValuesBasedOnMemoizedFunction<K, V> extends CacheWithNullableValuesBasedOnMemoizedFunction<K, V> implements CacheWithNotNullValues<K, V> {
539
540 private CacheWithNotNullValuesBasedOnMemoizedFunction(
541 @NotNull LockBasedStorageManager storageManager,
542 @NotNull ConcurrentMap<KeyWithComputation<K, V>, Object> map
543 ) {
544 super(storageManager, map);
545 }
546
547 @NotNull
548 @Override
549 public V computeIfAbsent(K key, @NotNull Function0<? extends V> computation) {
550 V result = super.computeIfAbsent(key, computation);
551 assert result != null : "computeIfAbsent() returned null under " + getStorageManager();
552 return result;
553 }
554 }
555
556 // equals and hashCode use only key
557 private static class KeyWithComputation<K, V> {
558 private final K key;
559 private final Function0<? extends V> computation;
560
561 public KeyWithComputation(K key, Function0<? extends V> computation) {
562 this.key = key;
563 this.computation = computation;
564 }
565
566 @Override
567 public boolean equals(Object o) {
568 if (this == o) return true;
569 if (o == null || getClass() != o.getClass()) return false;
570
571 KeyWithComputation<?, ?> that = (KeyWithComputation<?, ?>) o;
572
573 if (!key.equals(that.key)) return false;
574
575 return true;
576 }
577
578 @Override
579 public int hashCode() {
580 return key.hashCode();
581 }
582 }
583 }