001 /*
002 * Copyright 2010-2013 JetBrains s.r.o.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017 package org.jetbrains.jet.storage;
018
019 import kotlin.Function0;
020 import kotlin.Function1;
021 import kotlin.Unit;
022 import org.jetbrains.annotations.NotNull;
023 import org.jetbrains.annotations.Nullable;
024 import org.jetbrains.jet.utils.UtilsPackage;
025 import org.jetbrains.jet.utils.WrappedValues;
026
027 import java.util.concurrent.ConcurrentHashMap;
028 import java.util.concurrent.ConcurrentMap;
029 import java.util.concurrent.locks.Lock;
030 import java.util.concurrent.locks.ReentrantLock;
031
032 public class LockBasedStorageManager implements StorageManager {
033
034 public interface ExceptionHandlingStrategy {
035 ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() {
036 @NotNull
037 @Override
038 public RuntimeException handleException(@NotNull Throwable throwable) {
039 throw UtilsPackage.rethrow(throwable);
040 }
041 };
042
043 /*
044 * The signature of this method is a trick: it is used as
045 *
046 * throw strategy.handleException(...)
047 *
048 * most implementations of this method throw exceptions themselves, so it does not matter what they return
049 */
050 @NotNull
051 RuntimeException handleException(@NotNull Throwable throwable);
052 }
053
054 public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, NoLock.INSTANCE) {
055 @NotNull
056 @Override
057 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
058 return RecursionDetectedResult.fallThrough();
059 }
060 };
061
062 @NotNull
063 public static LockBasedStorageManager createWithExceptionHandling(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
064 return new LockBasedStorageManager(exceptionHandlingStrategy);
065 }
066
067 protected final Lock lock;
068 private final ExceptionHandlingStrategy exceptionHandlingStrategy;
069 private final String debugText;
070
/**
 * Designated constructor: stores the three collaborators as they are.
 *
 * @param debugText                 description shown by {@link #toString()}
 * @param exceptionHandlingStrategy strategy that receives throwables caught during computations
 * @param lock                      lock to hold while computing
 */
private LockBasedStorageManager(
        @NotNull String debugText,
        @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy,
        @NotNull Lock lock
) {
    this.debugText = debugText;
    this.exceptionHandlingStrategy = exceptionHandlingStrategy;
    this.lock = lock;
}
080
/**
 * Creates a manager with a fresh {@link ReentrantLock} and the {@link ExceptionHandlingStrategy#THROW} strategy.
 * The debug text is taken from {@code getPointOfConstruction()}, which must be called directly from this constructor.
 */
public LockBasedStorageManager() {
    this(
            getPointOfConstruction(),
            ExceptionHandlingStrategy.THROW,
            new ReentrantLock()
    );
}
084
/**
 * Creates a manager with a fresh {@link ReentrantLock} and the given exception handling strategy.
 * The debug text is taken from {@code getPointOfConstruction()}, which must be called directly from this constructor.
 *
 * @param exceptionHandlingStrategy strategy that receives throwables caught during computations
 */
protected LockBasedStorageManager(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
    this(
            getPointOfConstruction(),
            exceptionHandlingStrategy,
            new ReentrantLock()
    );
}
088
089 private static String getPointOfConstruction() {
090 StackTraceElement[] trace = Thread.currentThread().getStackTrace();
091 // we need to skip frames for getStackTrace(), this method and the constructor that's calling it
092 if (trace.length <= 3) return "<unknown creating class>";
093 return trace[3].toString();
094 }
095
096 @Override
097 public String toString() {
098 return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")";
099 }
100
101 @NotNull
102 @Override
103 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<? super K, ? extends V> compute) {
104 return createMemoizedFunction(compute, new ConcurrentHashMap<K, Object>());
105 }
106
107 @NotNull
108 protected <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
109 @NotNull Function1<? super K, ? extends V> compute,
110 @NotNull ConcurrentMap<K, Object> map
111 ) {
112 return new MapBasedMemoizedFunctionToNotNull<K, V>(map, compute);
113 }
114
115 @NotNull
116 @Override
117 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) {
118 return createMemoizedFunctionWithNullableValues(compute, new ConcurrentHashMap<K, Object>());
119 }
120
121 @NotNull
122 protected <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(
123 @NotNull Function1<? super K, ? extends V> compute,
124 @NotNull ConcurrentMap<K, Object> map
125 ) {
126 return new MapBasedMemoizedFunction<K, V>(map, compute);
127 }
128
129 @NotNull
130 @Override
131 public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<? extends T> computable) {
132 return new LockBasedNotNullLazyValue<T>(computable);
133 }
134
135 @NotNull
136 @Override
137 public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue(
138 @NotNull Function0<? extends T> computable, @NotNull final T onRecursiveCall
139 ) {
140 return new LockBasedNotNullLazyValue<T>(computable) {
141 @NotNull
142 @Override
143 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
144 return RecursionDetectedResult.value(onRecursiveCall);
145 }
146 };
147 }
148
149 @NotNull
150 @Override
151 public <T> NotNullLazyValue<T> createLazyValueWithPostCompute(
152 @NotNull Function0<? extends T> computable,
153 final Function1<? super Boolean, ? extends T> onRecursiveCall,
154 @NotNull final Function1<? super T, ? extends Unit> postCompute
155 ) {
156 return new LockBasedNotNullLazyValue<T>(computable) {
157 @NotNull
158 @Override
159 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
160 if (onRecursiveCall == null) {
161 return super.recursionDetected(firstTime);
162 }
163 return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime));
164 }
165
166 @Override
167 protected void postCompute(@NotNull T value) {
168 postCompute.invoke(value);
169 }
170 };
171 }
172
173 @NotNull
174 @Override
175 public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<? extends T> computable) {
176 return new LockBasedLazyValue<T>(computable);
177 }
178
179 @NotNull
180 @Override
181 public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<? extends T> computable, final T onRecursiveCall) {
182 return new LockBasedLazyValue<T>(computable) {
183 @NotNull
184 @Override
185 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
186 return RecursionDetectedResult.value(onRecursiveCall);
187 }
188 };
189 }
190
191 @NotNull
192 @Override
193 public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute(
194 @NotNull Function0<? extends T> computable, @NotNull final Function1<? super T, ? extends Unit> postCompute
195 ) {
196 return new LockBasedLazyValue<T>(computable) {
197 @Override
198 protected void postCompute(@Nullable T value) {
199 postCompute.invoke(value);
200 }
201 };
202 }
203
204 @Override
205 public <T> T compute(@NotNull Function0<? extends T> computable) {
206 lock.lock();
207 try {
208 return computable.invoke();
209 }
210 catch (Throwable throwable) {
211 throw exceptionHandlingStrategy.handleException(throwable);
212 }
213 finally {
214 lock.unlock();
215 }
216 }
217
218 @NotNull
219 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
220 throw new IllegalStateException("Recursive call in a lazy value under " + this);
221 }
222
223 private static class RecursionDetectedResult<T> {
224
225 @NotNull
226 public static <T> RecursionDetectedResult<T> value(T value) {
227 return new RecursionDetectedResult<T>(value, false);
228 }
229
230 @NotNull
231 public static <T> RecursionDetectedResult<T> fallThrough() {
232 return new RecursionDetectedResult<T>(null, true);
233 }
234
235 private final T value;
236 private final boolean fallThrough;
237
238 private RecursionDetectedResult(T value, boolean fallThrough) {
239 this.value = value;
240 this.fallThrough = fallThrough;
241 }
242
243 public T getValue() {
244 assert !fallThrough : "A value requested from FALL_THROUGH in " + this;
245 return value;
246 }
247
248 public boolean isFallThrough() {
249 return fallThrough;
250 }
251
252 @Override
253 public String toString() {
254 return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value);
255 }
256 }
257
258 private enum NotValue {
259 NOT_COMPUTED,
260 COMPUTING,
261 RECURSION_WAS_DETECTED
262 }
263
264 private class LockBasedLazyValue<T> implements NullableLazyValue<T> {
265
266 private final Function0<? extends T> computable;
267
268 @Nullable
269 private volatile Object value = NotValue.NOT_COMPUTED;
270
271 public LockBasedLazyValue(@NotNull Function0<? extends T> computable) {
272 this.computable = computable;
273 }
274
275 @Override
276 public boolean isComputed() {
277 return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
278 }
279
280 @Override
281 public T invoke() {
282 Object _value = value;
283 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
284
285 lock.lock();
286 try {
287 _value = value;
288 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
289
290 if (_value == NotValue.COMPUTING) {
291 value = NotValue.RECURSION_WAS_DETECTED;
292 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true);
293 if (!result.isFallThrough()) {
294 return result.getValue();
295 }
296 }
297
298 if (_value == NotValue.RECURSION_WAS_DETECTED) {
299 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false);
300 if (!result.isFallThrough()) {
301 return result.getValue();
302 }
303 }
304
305 value = NotValue.COMPUTING;
306 try {
307 T typedValue = computable.invoke();
308 value = typedValue;
309 postCompute(typedValue);
310 return typedValue;
311 }
312 catch (Throwable throwable) {
313 if (value == NotValue.COMPUTING) {
314 // Store only if it's a genuine result, not something thrown through recursionDetected()
315 value = WrappedValues.escapeThrowable(throwable);
316 }
317 throw exceptionHandlingStrategy.handleException(throwable);
318 }
319 }
320 finally {
321 lock.unlock();
322 }
323 }
324
325 /**
326 * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
327 * @return a value to be returned on a recursive call or subsequent calls
328 */
329 @NotNull
330 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
331 return recursionDetectedDefault();
332 }
333
334 protected void postCompute(T value) {
335 // Doing something in post-compute helps prevent infinite recursion
336 }
337 }
338
339 private class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> {
340
341 public LockBasedNotNullLazyValue(@NotNull Function0<? extends T> computable) {
342 super(computable);
343 }
344
345 @Override
346 @NotNull
347 public T invoke() {
348 T result = super.invoke();
349 assert result != null : "compute() returned null";
350 return result;
351 }
352 }
353
354 private class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> {
355 private final ConcurrentMap<K, Object> cache;
356 private final Function1<? super K, ? extends V> compute;
357
358 public MapBasedMemoizedFunction(@NotNull ConcurrentMap<K, Object> map, @NotNull Function1<? super K, ? extends V> compute) {
359 this.cache = map;
360 this.compute = compute;
361 }
362
363 @Override
364 @Nullable
365 public V invoke(K input) {
366 Object value = cache.get(input);
367 if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value);
368
369 lock.lock();
370 try {
371 value = cache.get(input);
372 assert value != NotValue.COMPUTING : "Recursion detected on input: " + input + " under " + LockBasedStorageManager.this;
373 if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
374
375 AssertionError error = null;
376 try {
377 cache.put(input, NotValue.COMPUTING);
378 V typedValue = compute.invoke(input);
379 Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));
380
381 // This code effectively asserts that oldValue is null
382 // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored
383 // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that
384 // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed
385 if (oldValue != NotValue.COMPUTING) {
386 error = new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue +
387 " under " + LockBasedStorageManager.this);
388 throw error;
389 }
390
391 return typedValue;
392 }
393 catch (Throwable throwable) {
394 if (throwable == error) throw exceptionHandlingStrategy.handleException(throwable);
395
396 Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
397 assert oldValue == NotValue.COMPUTING : "Race condition detected on input " + input + ". Old value is " + oldValue +
398 " under " + LockBasedStorageManager.this;
399
400 throw exceptionHandlingStrategy.handleException(throwable);
401 }
402 }
403 finally {
404 lock.unlock();
405 }
406 }
407 }
408
409 private class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> {
410
411 public MapBasedMemoizedFunctionToNotNull(
412 @NotNull ConcurrentMap<K, Object> map,
413 @NotNull Function1<? super K, ? extends V> compute
414 ) {
415 super(map, compute);
416 }
417
418 @NotNull
419 @Override
420 public V invoke(K input) {
421 V result = super.invoke(input);
422 assert result != null : "compute() returned null under " + LockBasedStorageManager.this;
423 return result;
424 }
425 }
426
427 @NotNull
428 public static LockBasedStorageManager createDelegatingWithSameLock(
429 @NotNull LockBasedStorageManager base,
430 @NotNull ExceptionHandlingStrategy newStrategy
431 ) {
432 return new LockBasedStorageManager(getPointOfConstruction(), newStrategy, base.lock);
433 }
434 }