001 /*
002 * Copyright 2010-2013 JetBrains s.r.o.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017 package org.jetbrains.jet.storage;
018
019 import jet.Function0;
020 import jet.Function1;
021 import jet.Unit;
022 import org.jetbrains.annotations.NotNull;
023 import org.jetbrains.annotations.Nullable;
024 import org.jetbrains.jet.utils.ExceptionUtils;
025 import org.jetbrains.jet.utils.WrappedValues;
026
027 import java.util.concurrent.ConcurrentHashMap;
028 import java.util.concurrent.ConcurrentMap;
029 import java.util.concurrent.locks.Lock;
030 import java.util.concurrent.locks.ReentrantLock;
031
032 public class LockBasedStorageManager implements StorageManager {
033
034 public static final StorageManager NO_LOCKS = new LockBasedStorageManager(NoLock.INSTANCE) {
035 @Override
036 public String toString() {
037 return "NO_LOCKS";
038 }
039 };
040
041 protected final Lock lock;
042
043 public LockBasedStorageManager() {
044 this(new ReentrantLock());
045 }
046
047 private LockBasedStorageManager(@NotNull Lock lock) {
048 this.lock = lock;
049 }
050
051 @NotNull
052 @Override
053 public <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(@NotNull Function1<K, V> compute) {
054 return createMemoizedFunction(compute, new ConcurrentHashMap<K, Object>());
055 }
056
057 @NotNull
058 protected <K, V> MemoizedFunctionToNotNull<K, V> createMemoizedFunction(
059 @NotNull Function1<K, V> compute,
060 @NotNull ConcurrentMap<K, Object> map
061 ) {
062 return new MapBasedMemoizedFunctionToNotNull<K, V>(lock, map, compute);
063 }
064
065 @NotNull
066 @Override
067 public <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(@NotNull Function1<K, V> compute) {
068 return createMemoizedFunctionWithNullableValues(compute, new ConcurrentHashMap<K, Object>());
069 }
070
071 @NotNull
072 protected <K, V> MemoizedFunctionToNullable<K, V> createMemoizedFunctionWithNullableValues(
073 @NotNull Function1<K, V> compute,
074 @NotNull ConcurrentMap<K, Object> map
075 ) {
076 return new MapBasedMemoizedFunction<K, V>(lock, map, compute);
077 }
078
079 @NotNull
080 @Override
081 public <T> NotNullLazyValue<T> createLazyValue(@NotNull Function0<T> computable) {
082 return new LockBasedNotNullLazyValue<T>(lock, computable);
083 }
084
085 @NotNull
086 @Override
087 public <T> NotNullLazyValue<T> createRecursionTolerantLazyValue(
088 @NotNull Function0<T> computable, @NotNull final T onRecursiveCall
089 ) {
090 return new LockBasedNotNullLazyValue<T>(lock, computable) {
091 @Override
092 protected T recursionDetected(boolean firstTime) {
093 return onRecursiveCall;
094 }
095 };
096 }
097
098 @NotNull
099 @Override
100 public <T> NotNullLazyValue<T> createLazyValueWithPostCompute(
101 @NotNull Function0<T> computable,
102 final Function1<Boolean, T> onRecursiveCall,
103 @NotNull final Function1<T, Unit> postCompute
104 ) {
105 return new LockBasedNotNullLazyValue<T>(lock, computable) {
106 @Nullable
107 @Override
108 protected T recursionDetected(boolean firstTime) {
109 if (onRecursiveCall == null) {
110 return super.recursionDetected(firstTime);
111 }
112 return onRecursiveCall.invoke(firstTime);
113 }
114
115 @Override
116 protected void postCompute(@NotNull T value) {
117 postCompute.invoke(value);
118 }
119 };
120 }
121
122 @NotNull
123 @Override
124 public <T> NullableLazyValue<T> createNullableLazyValue(@NotNull Function0<T> computable) {
125 return new LockBasedLazyValue<T>(lock, computable);
126 }
127
128 @NotNull
129 @Override
130 public <T> NullableLazyValue<T> createRecursionTolerantNullableLazyValue(@NotNull Function0<T> computable, final T onRecursiveCall) {
131 return new LockBasedLazyValue<T>(lock, computable) {
132 @Override
133 protected T recursionDetected(boolean firstTime) {
134 return onRecursiveCall;
135 }
136 };
137 }
138
139 @NotNull
140 @Override
141 public <T> NullableLazyValue<T> createNullableLazyValueWithPostCompute(
142 @NotNull Function0<T> computable, @NotNull final Function1<T, Unit> postCompute
143 ) {
144 return new LockBasedLazyValue<T>(lock, computable) {
145 @Override
146 protected void postCompute(@Nullable T value) {
147 postCompute.invoke(value);
148 }
149 };
150 }
151
152 @Override
153 public <T> T compute(@NotNull Function0<T> computable) {
154 lock.lock();
155 try {
156 return computable.invoke();
157 }
158 finally {
159 lock.unlock();
160 }
161 }
162
163 private static class LockBasedLazyValue<T> implements NullableLazyValue<T> {
164
165 private enum NotValue {
166 NOT_COMPUTED,
167 COMPUTING,
168 RECURSION_WAS_DETECTED
169 }
170
171 private final Lock lock;
172 private final Function0<T> computable;
173
174 @Nullable
175 private volatile Object value = NotValue.NOT_COMPUTED;
176
177 public LockBasedLazyValue(@NotNull Lock lock, @NotNull Function0<T> computable) {
178 this.lock = lock;
179 this.computable = computable;
180 }
181
182 @Override
183 public boolean isComputed() {
184 return value != NotValue.NOT_COMPUTED && value != NotValue.COMPUTING;
185 }
186
187 @Override
188 public T invoke() {
189 Object _value = value;
190 if (!(value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
191
192 lock.lock();
193 try {
194 _value = value;
195 if (!(_value instanceof NotValue)) return WrappedValues.unescapeThrowable(_value);
196
197 if (_value == NotValue.COMPUTING) {
198 value = NotValue.RECURSION_WAS_DETECTED;
199 return recursionDetected(/*firstTime = */ true);
200 }
201
202 if (_value == NotValue.RECURSION_WAS_DETECTED) {
203 return recursionDetected(/*firstTime = */ false);
204 }
205
206 value = NotValue.COMPUTING;
207 try {
208 T typedValue = computable.invoke();
209 value = typedValue;
210 postCompute(typedValue);
211 return typedValue;
212 }
213 catch (Throwable throwable) {
214 if (value == NotValue.COMPUTING) {
215 // Store only if it's a genuine result, not something thrown through recursionDetected()
216 value = WrappedValues.escapeThrowable(throwable);
217 }
218 throw ExceptionUtils.rethrow(throwable);
219 }
220 }
221 finally {
222 lock.unlock();
223 }
224 }
225
226 /**
227 * @param firstTime {@code true} when recursion has been just detected, {@code false} otherwise
228 * @return a value to be returned on a recursive call or subsequent calls
229 */
230 @Nullable
231 protected T recursionDetected(boolean firstTime) {
232 throw new IllegalStateException("Recursive call in a lazy value");
233 }
234
235 protected void postCompute(T value) {
236 // Doing something in post-compute helps prevent infinite recursion
237 }
238 }
239
240 private static class LockBasedNotNullLazyValue<T> extends LockBasedLazyValue<T> implements NotNullLazyValue<T> {
241
242 public LockBasedNotNullLazyValue(@NotNull Lock lock, @NotNull Function0<T> computable) {
243 super(lock, computable);
244 }
245
246 @Override
247 @NotNull
248 public T invoke() {
249 T result = super.invoke();
250 assert result != null : "compute() returned null";
251 return result;
252 }
253 }
254
255 private static class MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNullable<K, V> {
256 private final Lock lock;
257 private final ConcurrentMap<K, Object> cache;
258 private final Function1<K, V> compute;
259
260 public MapBasedMemoizedFunction(@NotNull Lock lock, @NotNull ConcurrentMap<K, Object> map, @NotNull Function1<K, V> compute) {
261 this.lock = lock;
262 this.cache = map;
263 this.compute = compute;
264 }
265
266 @Override
267 @Nullable
268 public V invoke(K input) {
269 Object value = cache.get(input);
270 if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
271
272 lock.lock();
273 try {
274 value = cache.get(input);
275 if (value != null) return WrappedValues.unescapeExceptionOrNull(value);
276
277 try {
278 V typedValue = compute.invoke(input);
279 Object oldValue = cache.put(input, WrappedValues.escapeNull(typedValue));
280 assert oldValue == null : "Race condition or recursion detected. Old value is " + oldValue;
281
282 return typedValue;
283 }
284 catch (Throwable throwable) {
285 Object oldValue = cache.put(input, WrappedValues.escapeThrowable(throwable));
286 assert oldValue == null : "Race condition or recursion detected. Old value is " + oldValue;
287
288 throw ExceptionUtils.rethrow(throwable);
289 }
290 }
291 finally {
292 lock.unlock();
293 }
294 }
295 }
296
297 private static class MapBasedMemoizedFunctionToNotNull<K, V> extends MapBasedMemoizedFunction<K, V> implements MemoizedFunctionToNotNull<K, V> {
298
299 public MapBasedMemoizedFunctionToNotNull(
300 @NotNull Lock lock,
301 @NotNull ConcurrentMap<K, Object> map,
302 @NotNull Function1<K, V> compute
303 ) {
304 super(lock, map, compute);
305 }
306
307 @NotNull
308 @Override
309 public V invoke(K input) {
310 V result = super.invoke(input);
311 assert result != null : "compute() returned null";
312 return result;
313 }
314 }
315 }