001 /*
002 * Copyright 2010-2013 JetBrains s.r.o.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017 package org.jetbrains.jet.storage;
018
019 import kotlin.Function0;
020 import kotlin.Function1;
021 import kotlin.Unit;
022 import org.jetbrains.annotations.NotNull;
023 import org.jetbrains.annotations.Nullable;
024 import org.jetbrains.jet.utils.UtilsPackage;
025 import org.jetbrains.jet.utils.WrappedValues;
026
027 import java.util.concurrent.ConcurrentHashMap;
028 import java.util.concurrent.ConcurrentMap;
029 import java.util.concurrent.locks.Lock;
030 import java.util.concurrent.locks.ReentrantLock;
031
032 public class LockBasedStorageManager implements StorageManager {
033
034 public interface ExceptionHandlingStrategy {
035 ExceptionHandlingStrategy THROW = new ExceptionHandlingStrategy() {
036 @NotNull
037 @Override
038 public RuntimeException handleException(@NotNull Throwable throwable) {
039 throw UtilsPackage.rethrow(throwable);
040 }
041 };
042
043 /*
044 * The signature of this method is a trick: it is used as
045 *
046 * throw strategy.handleException(...)
047 *
048 * most implementations of this method throw exceptions themselves, so it does not matter what they return
049 */
050 @NotNull
051 RuntimeException handleException(@NotNull Throwable throwable);
052 }
053
054 public static final StorageManager NO_LOCKS = new LockBasedStorageManager("NO_LOCKS", ExceptionHandlingStrategy.THROW, NoLock.INSTANCE) {
055 @NotNull
056 @Override
057 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
058 return RecursionDetectedResult.fallThrough();
059 }
060 };
061
062 @NotNull
063 public static LockBasedStorageManager createWithExceptionHandling(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
064 return new LockBasedStorageManager(exceptionHandlingStrategy);
065 }
066
067 protected final Lock lock;
068 private final ExceptionHandlingStrategy exceptionHandlingStrategy;
069 private final String debugText;
070
071 private LockBasedStorageManager(
072 @NotNull String debugText,
073 @NotNull ExceptionHandlingStrategy exceptionHandlingStrategy,
074 @NotNull Lock lock
075 ) {
076 this.lock = lock;
077 this.exceptionHandlingStrategy = exceptionHandlingStrategy;
078 this.debugText = debugText;
079 }
080
081 public LockBasedStorageManager() {
082 this(getPointOfConstruction(), ExceptionHandlingStrategy.THROW, new ReentrantLock());
083 }
084
085 protected LockBasedStorageManager(@NotNull ExceptionHandlingStrategy exceptionHandlingStrategy) {
086 this(getPointOfConstruction(), exceptionHandlingStrategy, new ReentrantLock());
087 }
088
089 private static String getPointOfConstruction() {
090 StackTraceElement[] trace = Thread.currentThread().getStackTrace();
091 // we need to skip frames for getStackTrace(), this method and the constructor that's calling it
092 if (trace.length <= 3) return "<unknown creating class>";
093 return trace[3].toString();
094 }
095
096 @Override
097 public String toString() {
098 return getClass().getSimpleName() + "@" + Integer.toHexString(hashCode()) + " (" + debugText + ")";
099 }
100
101 @NotNull
102 @Override
103 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<? super K, ? extends V> compute) {
104 return createMemoizedFunction(compute, new ConcurrentHashMap<K, Object>());
105 }
106
107 @NotNull
108 protected <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
109 @NotNull Function1<? super K, ? extends V> compute,
110 @NotNull ConcurrentMap<K, Object> map
111 ) {
112 return new MapBasedMemoizedFunctionToNotNull<K, V>(map, compute);
113 }
114
115 @NotNull
116 @Override
117 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<? super K, ? extends V> compute) {
118 return createMemoizedFunctionWithNullableValues(compute, new ConcurrentHashMap<K, Object>());
119 }
120
121 @NotNull
122 protected <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(
123 @NotNull Function1<? super K, ? extends V> compute,
124 @NotNull ConcurrentMap<K, Object> map
125 ) {
126 return new MapBasedMemoizedFunction<K, V>(map, compute);
127 }
128
129 @NotNull
130 @Override
131 public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<? extends T> computable) {
132 return new LockBasedNotNullLazyValue<T>(computable);
133 }
134
135 @NotNull
136 @Override
137 public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue(
138 @NotNull Function0<? extends T> computable, @NotNull final T onRecursiveCall
139 ) {
140 return new LockBasedNotNullLazyValue<T>(computable) {
141 @NotNull
142 @Override
143 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
144 return RecursionDetectedResult.value(onRecursiveCall);
145 }
146 };
147 }
148
149 @NotNull
150 @Override
151 public <T> NotNullLazyValue<T> createLazyValueWithPostCompute(
152 @NotNull Function0<? extends T> computable,
153 final Function1<? super Boolean, ? extends T> onRecursiveCall,
154 @NotNull final Function1<? super T, ? extends Unit> postCompute
155 ) {
156 return new LockBasedNotNullLazyValue<T>(computable) {
157 @NotNull
158 @Override
159 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
160 if (onRecursiveCall == null) {
161 return super.recursionDetected(firstTime);
162 }
163 return RecursionDetectedResult.value(onRecursiveCall.invoke(firstTime));
164 }
165
166 @Override
167 protected void postCompute(@NotNull T value) {
168 postCompute.invoke(value);
169 }
170 };
171 }
172
173 @NotNull
174 @Override
175 public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<? extends T> computable) {
176 return new LockBasedLazyValue<T>(computable);
177 }
178
179 @NotNull
180 @Override
181 public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<? extends T> computable, final T onRecursiveCall) {
182 return new LockBasedLazyValue<T>(computable) {
183 @NotNull
184 @Override
185 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
186 return RecursionDetectedResult.value(onRecursiveCall);
187 }
188 };
189 }
190
191 @NotNull
192 @Override
193 public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute(
194 @NotNull Function0<? extends T> computable, @NotNull final Function1<? super T, ? extends Unit> postCompute
195 ) {
196 return new LockBasedLazyValue<T>(computable) {
197 @Override
198 protected void postCompute(@Nullable T value) {
199 postCompute.invoke(value);
200 }
201 };
202 }
203
204 @Override
205 public <T> T compute(@NotNull Function0<? extends T> computable) {
206 lock.lock();
207 try {
208 return computable.invoke();
209 }
210 finally {
211 lock.unlock();
212 }
213 }
214
215 @NotNull
216 protected <T> RecursionDetectedResult<T> recursionDetectedDefault() {
217 throw new IllegalStateException("Recursive call in a lazy value under " + this);
218 }
219
220 private static class RecursionDetectedResult<T> {
221
222 @NotNull
223 public static <T> RecursionDetectedResult<T> value(T value) {
224 return new RecursionDetectedResult<T>(value, false);
225 }
226
227 @NotNull
228 public static <T> RecursionDetectedResult<T> fallThrough() {
229 return new RecursionDetectedResult<T>(null, true);
230 }
231
232 private final T value;
233 private final boolean fallThrough;
234
235 private RecursionDetectedResult(T value, boolean fallThrough) {
236 this.value = value;
237 this.fallThrough = fallThrough;
238 }
239
240 public T getValue() {
241 assert !fallThrough : "A value requested from FALL_THROUGH in " + this;
242 return value;
243 }
244
245 public boolean isFallThrough() {
246 return fallThrough;
247 }
248
249 @Override
250 public String toString() {
251 return isFallThrough() ? "FALL_THROUGH" : String.valueOf(value);
252 }
253 }
254
255 private enum NotValue {
256 NOT_COMPUTED,
257 COMPUTING,
258 RECURSION_WAS_DETECTED
259 }
260
261 private class LockBasedLazyValue<T> implements NullableLazyValue<T> {
262
263 private final Function0<? extends T> computable;
264
265 @Nullable
266 private volatile Object value = NotValue.NOT_COMPUTED;
267
268 public LockBasedLazyValue(@NotNull Function0<? extends T> computable) {
269 this.computable = computable;
270 }
271
272 @Override
273 public boolean isComputed() {
274 return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
275 }
276
277 @Override
278 public T invoke() {
279 Object _value = value;
280 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
281
282 lock.lock();
283 try {
284 _value = value;
285 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
286
287 if (_value == NotValue.COMPUTING) {
288 value = NotValue.RECURSION_WAS_DETECTED;
289 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ true);
290 if (!result.isFallThrough()) {
291 return result.getValue();
292 }
293 }
294
295 if (_value == NotValue.RECURSION_WAS_DETECTED) {
296 RecursionDetectedResult<T> result = recursionDetected(/*firstTime = */ false);
297 if (!result.isFallThrough()) {
298 return result.getValue();
299 }
300 }
301
302 value = NotValue.COMPUTING;
303 try {
304 T typedValue = computable.invoke();
305 value = typedValue;
306 postCompute(typedValue);
307 return typedValue;
308 }
309 catch (Throwable throwable) {
310 if (value == NotValue.COMPUTING) {
311 // Store only if it's a genuine result, not something thrown through recursionDetected()
312 value = WrappedValues.escapeThrowable(throwable);
313 }
314 throw exceptionHandlingStrategy.handleException(throwable);
315 }
316 }
317 finally {
318 lock.unlock();
319 }
320 }
321
322 /**
323 * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
324 * @return a value to be returned on a recursive call or subsequent calls
325 */
326 @NotNull
327 protected RecursionDetectedResult<T> recursionDetected(boolean firstTime) {
328 return recursionDetectedDefault();
329 }
330
331 protected void postCompute(T value) {
332 // Doing something in post-compute helps prevent infinite recursion
333 }
334 }
335
336 private class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> {
337
338 public LockBasedNotNullLazyValue(@NotNull Function0<? extends T> computable) {
339 super(computable);
340 }
341
342 @Override
343 @NotNull
344 public T invoke() {
345 T result = super.invoke();
346 assert result != null : "compute() returned null";
347 return result;
348 }
349 }
350
351 private class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> {
352 private final ConcurrentMap<K, Object> cache;
353 private final Function1<? super K, ? extends V> compute;
354
355 public MapBasedMemoizedFunction(@NotNull ConcurrentMap<K, Object> map, @NotNull Function1<? super K, ? extends V> compute) {
356 this.cache = map;
357 this.compute = compute;
358 }
359
360 @Override
361 @Nullable
362 public V invoke(K input) {
363 Object value = cache.get(input);
364 if (value != null && value != NotValue.COMPUTING) return WrappedValues.unescapeExceptionOrNull(value);
365
366 lock.lock();
367 try {
368 value = cache.get(input);
369 assert value != NotValue.COMPUTING : "Recursion detected on input: " + input + " under " + LockBasedStorageManager.this;
370 if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
371
372 AssertionError error = null;
373 try {
374 cache.put(input, NotValue.COMPUTING);
375 V typedValue = compute.invoke(input);
376 Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));
377
378 // This code effectively asserts that oldValue is null
379 // The trickery is here because below we catch all exceptions thrown here, and this is the only exception that shouldn't be stored
380 // A seemingly obvious way to come about this case would be to declare a special exception class, but the problem is that
381 // one memoized function is likely to (indirectly) call another, and if this second one throws this exception, we are screwed
382 if (oldValue != NotValue.COMPUTING) {
383 error = new AssertionError("Race condition detected on input " + input + ". Old value is " + oldValue +
384 " under " + LockBasedStorageManager.this);
385 throw error;
386 }
387
388 return typedValue;
389 }
390 catch (Throwable throwable) {
391 if (throwable == error) throw exceptionHandlingStrategy.handleException(throwable);
392
393 Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
394 assert oldValue == NotValue.COMPUTING : "Race condition detected on input " + input + ". Old value is " + oldValue +
395 " under " + LockBasedStorageManager.this;
396
397 throw exceptionHandlingStrategy.handleException(throwable);
398 }
399 }
400 finally {
401 lock.unlock();
402 }
403 }
404 }
405
406 private class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> {
407
408 public MapBasedMemoizedFunctionToNotNull(
409 @NotNull ConcurrentMap<K, Object> map,
410 @NotNull Function1<? super K, ? extends V> compute
411 ) {
412 super(map, compute);
413 }
414
415 @NotNull
416 @Override
417 public V invoke(K input) {
418 V result = super.invoke(input);
419 assert result != null : "compute() returned null under " + LockBasedStorageManager.this;
420 return result;
421 }
422 }
423
424 }