Revise ConcurrentReferenceHashMap for @ConcurrencyLimit race condition

Closes gh-35788
See gh-35794
This commit is contained in:
Juergen Hoeller 2025-11-11 13:39:15 +01:00
parent 721c40b5c5
commit 0552cdb7ed
3 changed files with 200 additions and 46 deletions

View File

@ -19,6 +19,7 @@ package org.springframework.resilience.annotation;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;
@ -73,7 +74,7 @@ public class ConcurrencyLimitBeanPostProcessor extends AbstractBeanFactoryAwareA
private class ConcurrencyLimitInterceptor implements MethodInterceptor {
private final Map<Object, ConcurrencyThrottleCache> cachePerInstance =
private final ConcurrentMap<Object, ConcurrencyThrottleCache> cachePerInstance =
new ConcurrentReferenceHashMap<>(16, ConcurrentReferenceHashMap.ReferenceType.WEAK);
@Override
@ -87,8 +88,11 @@ public class ConcurrencyLimitBeanPostProcessor extends AbstractBeanFactoryAwareA
}
Assert.state(target != null, "Target must not be null");
// Build unique ConcurrencyThrottleCache instance per target object
ConcurrencyThrottleCache cache = this.cachePerInstance.computeIfAbsent(target,
k -> new ConcurrencyThrottleCache());
// Determine method-specific interceptor instance with isolated concurrency count
MethodInterceptor interceptor = cache.methodInterceptors.get(method);
if (interceptor == null) {
synchronized (cache) {

View File

@ -33,11 +33,13 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.jspecify.annotations.Nullable;
/**
* A {@link ConcurrentHashMap} that uses {@link ReferenceType#SOFT soft} or
* A {@link ConcurrentHashMap} variant that uses {@link ReferenceType#SOFT soft} or
* {@linkplain ReferenceType#WEAK weak} references for both {@code keys} and {@code values}.
*
* <p>This class can be used as an alternative to
@ -320,7 +322,7 @@ public class ConcurrentReferenceHashMap<K, V> extends AbstractMap<K, V> implemen
return false;
}
});
return (Boolean.TRUE.equals(result));
return Boolean.TRUE.equals(result);
}
@Override
@ -335,7 +337,7 @@ public class ConcurrentReferenceHashMap<K, V> extends AbstractMap<K, V> implemen
return false;
}
});
return (Boolean.TRUE.equals(result));
return Boolean.TRUE.equals(result);
}
@Override
@ -353,6 +355,114 @@ public class ConcurrentReferenceHashMap<K, V> extends AbstractMap<K, V> implemen
});
}
@Override
public @Nullable V computeIfAbsent(@Nullable K key, Function<@Nullable ? super K, @Nullable ? extends V> mappingFunction) {
return doTask(key, new Task<V>(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) {
@Override
protected @Nullable V execute(@Nullable Reference<K, V> ref, @Nullable Entry<K, V> entry, @Nullable Entries<V> entries) {
if (entry != null) {
return entry.getValue();
}
V value = mappingFunction.apply(key);
// Add entry only if not null
if (value != null) {
Assert.state(entries != null, "No entries segment");
entries.add(value);
}
return value;
}
});
}
@Override
public @Nullable V computeIfPresent(@Nullable K key, BiFunction<@Nullable ? super K, @Nullable ? super V, @Nullable ? extends V> remappingFunction) {
return doTask(key, new Task<V>(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) {
@Override
protected @Nullable V execute(@Nullable Reference<K, V> ref, @Nullable Entry<K, V> entry, @Nullable Entries<V> entries) {
if (entry != null) {
V oldValue = entry.getValue();
V value = remappingFunction.apply(key, oldValue);
if (value != null) {
// Replace entry
entry.setValue(value);
return value;
}
else {
// Remove entry
if (ref != null) {
ref.release();
}
}
}
return null;
}
});
}
@Override
public @Nullable V compute(@Nullable K key, BiFunction<@Nullable ? super K, @Nullable ? super V, @Nullable ? extends V> remappingFunction) {
return doTask(key, new Task<V>(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) {
@Override
protected @Nullable V execute(@Nullable Reference<K, V> ref, @Nullable Entry<K, V> entry, @Nullable Entries<V> entries) {
V oldValue = null;
if (entry != null) {
oldValue = entry.getValue();
}
V value = remappingFunction.apply(key, oldValue);
if (value != null) {
if (entry != null) {
// Replace entry
entry.setValue(value);
}
else {
// Add entry
Assert.state(entries != null, "No entries segment");
entries.add(value);
}
return value;
}
else {
// Remove entry
if (ref != null) {
ref.release();
}
}
return null;
}
});
}
@Override
public @Nullable V merge(@Nullable K key, @Nullable V value, BiFunction<@Nullable ? super V, @Nullable ? super V, @Nullable ? extends V> remappingFunction) {
return doTask(key, new Task<V>(TaskOption.RESTRUCTURE_BEFORE, TaskOption.RESIZE) {
@Override
protected @Nullable V execute(@Nullable Reference<K, V> ref, @Nullable Entry<K, V> entry, @Nullable Entries<V> entries) {
if (entry != null) {
V oldValue = entry.getValue();
V newValue = remappingFunction.apply(oldValue, value);
if (newValue != null) {
// Replace entry
entry.setValue(newValue);
return newValue;
}
else {
// Remove entry
if (ref != null) {
ref.release();
}
return null;
}
}
else {
// Add entry
Assert.state(entries != null, "No entries segment");
entries.add(value);
return value;
}
}
});
}
@Override
public void clear() {
for (Segment segment : this.segments) {

View File

@ -53,7 +53,7 @@ class ConcurrentReferenceHashMapTests {
@Test
void shouldCreateWithDefaults() {
void createWithDefaults() {
ConcurrentReferenceHashMap<Integer, String> map = new ConcurrentReferenceHashMap<>();
assertThat(map.getSegmentsSize()).isEqualTo(16);
assertThat(map.getSegment(0).getSize()).isEqualTo(1);
@ -61,7 +61,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldCreateWithInitialCapacity() {
void createWithInitialCapacity() {
ConcurrentReferenceHashMap<Integer, String> map = new ConcurrentReferenceHashMap<>(32);
assertThat(map.getSegmentsSize()).isEqualTo(16);
assertThat(map.getSegment(0).getSize()).isEqualTo(2);
@ -69,7 +69,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldCreateWithInitialCapacityAndLoadFactor() {
void createWithInitialCapacityAndLoadFactor() {
ConcurrentReferenceHashMap<Integer, String> map = new ConcurrentReferenceHashMap<>(32, 0.5f);
assertThat(map.getSegmentsSize()).isEqualTo(16);
assertThat(map.getSegment(0).getSize()).isEqualTo(2);
@ -77,7 +77,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldCreateWithInitialCapacityAndConcurrentLevel() {
void createWithInitialCapacityAndConcurrentLevel() {
ConcurrentReferenceHashMap<Integer, String> map = new ConcurrentReferenceHashMap<>(16, 2);
assertThat(map.getSegmentsSize()).isEqualTo(2);
assertThat(map.getSegment(0).getSize()).isEqualTo(8);
@ -85,7 +85,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldCreateFullyCustom() {
void createFullyCustom() {
ConcurrentReferenceHashMap<Integer, String> map = new ConcurrentReferenceHashMap<>(5, 0.5f, 3);
// concurrencyLevel of 3 ends up as 4 (nearest power of 2)
assertThat(map.getSegmentsSize()).isEqualTo(4);
@ -95,28 +95,28 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldNeedNonNegativeInitialCapacity() {
void nonNegativeInitialCapacity() {
assertThatNoException().isThrownBy(() -> new ConcurrentReferenceHashMap<Integer, String>(0, 1));
assertThatIllegalArgumentException().isThrownBy(() -> new ConcurrentReferenceHashMap<Integer, String>(-1, 1))
.withMessageContaining("Initial capacity must not be negative");
}
@Test
void shouldNeedPositiveLoadFactor() {
void positiveLoadFactor() {
assertThatNoException().isThrownBy(() -> new ConcurrentReferenceHashMap<Integer, String>(0, 0.1f, 1));
assertThatIllegalArgumentException().isThrownBy(() -> new ConcurrentReferenceHashMap<Integer, String>(0, 0.0f, 1))
.withMessageContaining("Load factor must be positive");
}
@Test
void shouldNeedPositiveConcurrencyLevel() {
void positiveConcurrencyLevel() {
assertThatNoException().isThrownBy(() -> new ConcurrentReferenceHashMap<Integer, String>(1, 1));
assertThatIllegalArgumentException().isThrownBy(() -> new ConcurrentReferenceHashMap<Integer, String>(1, 0))
.withMessageContaining("Concurrency level must be positive");
}
@Test
void shouldPutAndGet() {
void putAndGet() {
// NOTE we are using mock references so we don't need to worry about GC
assertThat(this.map).isEmpty();
this.map.put(123, "123");
@ -129,14 +129,14 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldReplaceOnDoublePut() {
void replaceOnDoublePut() {
this.map.put(123, "321");
this.map.put(123, "123");
assertThat(this.map.get(123)).isEqualTo("123");
}
@Test
void shouldPutNullKey() {
void putNullKey() {
assertThat(this.map.get(null)).isNull();
assertThat(this.map.getOrDefault(null, "456")).isEqualTo("456");
this.map.put(null, "123");
@ -145,7 +145,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldPutNullValue() {
void putNullValue() {
assertThat(this.map.get(123)).isNull();
assertThat(this.map.getOrDefault(123, "456")).isEqualTo("456");
this.map.put(123, "321");
@ -157,12 +157,12 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldGetWithNoItems() {
void getWithNoItems() {
assertThat(this.map.get(123)).isNull();
}
@Test
void shouldApplySupplementalHash() {
void applySupplementalHash() {
Integer key = 123;
this.map.put(key, "123");
assertThat(this.map.getSupplementalHash()).isNotEqualTo(key.hashCode());
@ -170,7 +170,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldGetFollowingNexts() {
void getFollowingNexts() {
// Use loadFactor to disable resize
this.map = new TestWeakConcurrentCache<>(1, 10.0f, 1);
this.map.put(1, "1");
@ -184,7 +184,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldResize() {
void resize() {
this.map = new TestWeakConcurrentCache<>(1, 0.75f, 1);
this.map.put(1, "1");
assertThat(this.map.getSegment(0).getSize()).isEqualTo(1);
@ -214,7 +214,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldPurgeOnGet() {
void purgeOnGet() {
this.map = new TestWeakConcurrentCache<>(1, 0.75f, 1);
for (int i = 1; i <= 5; i++) {
this.map.put(i, String.valueOf(i));
@ -229,7 +229,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldPurgeOnPut() {
void purgeOnPut() {
this.map = new TestWeakConcurrentCache<>(1, 0.75f, 1);
for (int i = 1; i <= 5; i++) {
this.map.put(i, String.valueOf(i));
@ -245,28 +245,28 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldPutIfAbsent() {
void putIfAbsent() {
assertThat(this.map.putIfAbsent(123, "123")).isNull();
assertThat(this.map.putIfAbsent(123, "123b")).isEqualTo("123");
assertThat(this.map.get(123)).isEqualTo("123");
}
@Test
void shouldPutIfAbsentWithNullValue() {
void putIfAbsentWithNullValue() {
assertThat(this.map.putIfAbsent(123, null)).isNull();
assertThat(this.map.putIfAbsent(123, "123")).isNull();
assertThat(this.map.get(123)).isNull();
}
@Test
void shouldPutIfAbsentWithNullKey() {
void putIfAbsentWithNullKey() {
assertThat(this.map.putIfAbsent(null, "123")).isNull();
assertThat(this.map.putIfAbsent(null, "123b")).isEqualTo("123");
assertThat(this.map.get(null)).isEqualTo("123");
}
@Test
void shouldRemoveKeyAndValue() {
void removeKeyAndValue() {
this.map.put(123, "123");
assertThat(this.map.remove(123, "456")).isFalse();
assertThat(this.map.get(123)).isEqualTo("123");
@ -276,7 +276,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldRemoveKeyAndValueWithExistingNull() {
void removeKeyAndValueWithExistingNull() {
this.map.put(123, null);
assertThat(this.map.remove(123, "456")).isFalse();
assertThat(this.map.get(123)).isNull();
@ -286,7 +286,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldReplaceOldValueWithNewValue() {
void replaceOldValueWithNewValue() {
this.map.put(123, "123");
assertThat(this.map.replace(123, "456", "789")).isFalse();
assertThat(this.map.get(123)).isEqualTo("123");
@ -295,7 +295,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldReplaceOldNullValueWithNewValue() {
void replaceOldNullValueWithNewValue() {
this.map.put(123, null);
assertThat(this.map.replace(123, "456", "789")).isFalse();
assertThat(this.map.get(123)).isNull();
@ -304,21 +304,61 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldReplaceValue() {
void replaceValue() {
this.map.put(123, "123");
assertThat(this.map.replace(123, "456")).isEqualTo("123");
assertThat(this.map.get(123)).isEqualTo("456");
}
@Test
void shouldReplaceNullValue() {
void replaceNullValue() {
this.map.put(123, null);
assertThat(this.map.replace(123, "456")).isNull();
assertThat(this.map.get(123)).isEqualTo("456");
}
@Test
void shouldGetSize() {
void computeIfAbsent() {
assertThat(this.map.computeIfAbsent(123, k -> "123")).isEqualTo("123");
assertThat(this.map.computeIfAbsent(123, k -> "123b")).isEqualTo("123");
assertThat(this.map.get(123)).isEqualTo("123");
this.map.remove(123);
assertThat(this.map.computeIfAbsent(123, k -> null)).isNull();
assertThat(this.map.containsKey(123)).isFalse();
}
@Test
void computeIfPresent() {
assertThat(this.map.computeIfPresent(123, (k, v) -> "123")).isNull();
this.map.put(123, "123");
assertThat(this.map.computeIfPresent(123, (k, v) -> v + "b")).isEqualTo("123b");
assertThat(this.map.get(123)).isEqualTo("123b");
assertThat(this.map.computeIfPresent(123, (k, v) -> null)).isNull();
assertThat(this.map.containsKey(123)).isFalse();
}
@Test
void compute() {
assertThat(this.map.compute(123, (k, v) -> "123" + v)).isEqualTo("123null");
assertThat(this.map.compute(123, (k, v) -> null)).isNull();
assertThat(this.map.compute(123, (k, v) -> null)).isNull();
assertThat(this.map.compute(123, (k, v) -> "123")).isEqualTo("123");
assertThat(this.map.compute(123, (k, v) -> v + "b")).isEqualTo("123b");
assertThat(this.map.get(123)).isEqualTo("123b");
}
@Test
void merge() {
assertThat(this.map.merge(123, "123", (v1, v2) -> v1 + v2)).isEqualTo("123");
assertThat(this.map.merge(123, null, (v1, v2) -> v1 + v2)).isEqualTo("123null");
assertThat(this.map.merge(123, null, (v1, v2) -> null)).isNull();
assertThat(this.map.merge(123, "123", (v1, v2) -> v1 + v2)).isEqualTo("123");
assertThat(this.map.merge(123, "b", (v1, v2) -> v1 + v2)).isEqualTo("123b");
assertThat(this.map.get(123)).isEqualTo("123b");
}
@Test
void size() {
assertThat(this.map).isEmpty();
this.map.put(123, "123");
this.map.put(123, null);
@ -327,7 +367,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldSupportIsEmpty() {
void isEmpty() {
assertThat(this.map).isEmpty();
this.map.put(123, "123");
this.map.put(123, null);
@ -336,7 +376,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldContainKey() {
void containsKey() {
assertThat(this.map.containsKey(123)).isFalse();
assertThat(this.map.containsKey(456)).isFalse();
this.map.put(123, "123");
@ -346,7 +386,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldContainValue() {
void containsValue() {
assertThat(this.map.containsValue("123")).isFalse();
assertThat(this.map.containsValue(null)).isFalse();
this.map.put(123, "123");
@ -356,7 +396,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldRemoveWhenKeyIsInMap() {
void removeWhenKeyIsInMap() {
this.map.put(123, null);
this.map.put(456, "456");
this.map.put(null, "789");
@ -367,14 +407,14 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldRemoveWhenKeyIsNotInMap() {
void removeWhenKeyIsNotInMap() {
assertThat(this.map.remove(123)).isNull();
assertThat(this.map.remove(null)).isNull();
assertThat(this.map).isEmpty();
}
@Test
void shouldPutAll() {
void putAll() {
Map<Integer, String> m = new HashMap<>();
m.put(123, "123");
m.put(456, null);
@ -387,7 +427,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldClear() {
void clear() {
this.map.put(123, "123");
this.map.put(456, null);
this.map.put(null, "789");
@ -399,7 +439,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldGetKeySet() {
void keySet() {
this.map.put(123, "123");
this.map.put(456, null);
this.map.put(null, "789");
@ -411,7 +451,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldGetValues() {
void valuesCollection() {
this.map.put(123, "123");
this.map.put(456, null);
this.map.put(null, "789");
@ -426,7 +466,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldGetEntrySet() {
void getEntrySet() {
this.map.put(123, "123");
this.map.put(456, null);
this.map.put(null, "789");
@ -438,7 +478,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldGetEntrySetFollowingNext() {
void getEntrySetFollowingNext() {
// Use loadFactor to disable resize
this.map = new TestWeakConcurrentCache<>(1, 10.0f, 1);
this.map.put(1, "1");
@ -452,7 +492,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldRemoveViaEntrySet() {
void removeViaEntrySet() {
this.map.put(1, "1");
this.map.put(2, "2");
this.map.put(3, "3");
@ -468,7 +508,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldSetViaEntrySet() {
void setViaEntrySet() {
this.map.put(1, "1");
this.map.put(2, "2");
this.map.put(3, "3");
@ -502,7 +542,7 @@ class ConcurrentReferenceHashMapTests {
}
@Test
void shouldSupportNullReference() {
void supportNullReference() {
// GC could happen during restructure so we must be able to create a reference for a null entry
map.createReferenceManager().createReference(null, 1234, null);
}