You are viewing a plain-text version of this content. The canonical link for it was a hyperlink in the original page and is not preserved in this text.
Posted to commits@ratis.apache.org by ru...@apache.org on 2020/11/18 23:44:54 UTC

[incubator-ratis] branch master updated: RATIS-1163. RefCountingMap is not thread-safe. (#285)

This is an automated email from the ASF dual-hosted git repository.

runzhiwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 375b738  RATIS-1163. RefCountingMap is not thread-safe. (#285)
375b738 is described below

commit 375b738d8f9d0cd5364daf9bd257f1e784bccf18
Author: Tsz-Wo Nicholas Sze <sz...@apache.org>
AuthorDate: Thu Nov 19 07:43:59 2020 +0800

    RATIS-1163. RefCountingMap is not thread-safe. (#285)
---
 .../ratis/metrics/MetricRegistriesLoader.java      |  4 +-
 .../apache/ratis/metrics/impl/RefCountingMap.java  | 50 +++++++++++++---------
 2 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/ratis-metrics/src/main/java/org/apache/ratis/metrics/MetricRegistriesLoader.java b/ratis-metrics/src/main/java/org/apache/ratis/metrics/MetricRegistriesLoader.java
index e2da686..a60e253 100644
--- a/ratis-metrics/src/main/java/org/apache/ratis/metrics/MetricRegistriesLoader.java
+++ b/ratis-metrics/src/main/java/org/apache/ratis/metrics/MetricRegistriesLoader.java
@@ -1,4 +1,4 @@
-/**
+/*
  *
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
@@ -65,7 +65,7 @@ public final class MetricRegistriesLoader {
       return impl;
     } else if (availableImplementations.isEmpty()) {
       try {
-        return ReflectionUtils.newInstance((Class<MetricRegistries>)Class.forName(DEFAULT_CLASS));
+        return ReflectionUtils.newInstance(Class.forName(DEFAULT_CLASS).asSubclass(MetricRegistries.class));
       } catch (ClassNotFoundException e) {
         throw new RuntimeException(e);
       }
diff --git a/ratis-metrics/src/main/java/org/apache/ratis/metrics/impl/RefCountingMap.java b/ratis-metrics/src/main/java/org/apache/ratis/metrics/impl/RefCountingMap.java
index 1afaab9..4975978 100644
--- a/ratis-metrics/src/main/java/org/apache/ratis/metrics/impl/RefCountingMap.java
+++ b/ratis-metrics/src/main/java/org/apache/ratis/metrics/impl/RefCountingMap.java
@@ -1,5 +1,4 @@
-/**
- *
+/*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
@@ -21,6 +20,8 @@ package org.apache.ratis.metrics.impl;
 import java.util.Collection;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
@@ -31,31 +32,39 @@ import java.util.stream.Collectors;
  * from the map iff ref count == 0.
  */
 class RefCountingMap<K, V> {
-
-  private ConcurrentHashMap<K, Payload<V>> map = new ConcurrentHashMap<>();
   private static class Payload<V> {
-    private V v;
-    private int refCount;
+    private final V value;
+    private final AtomicInteger refCount = new AtomicInteger();
+
     Payload(V v) {
-      this.v = v;
-      this.refCount = 1; // create with ref count = 1
+      this.value = v;
+    }
+
+    V get() {
+      return value;
+    }
+
+    V increment() {
+      return refCount.incrementAndGet() > 0? value: null;
+    }
+
+    Payload<V> decrement() {
+      return refCount.decrementAndGet() > 0? this: null;
     }
   }
 
+  private final ConcurrentMap<K, Payload<V>> map = new ConcurrentHashMap<>();
+
   V put(K k, Supplier<V> supplier) {
-    return ((Payload<V>)map.compute(k, (k1, oldValue) -> {
-      if (oldValue != null) {
-        oldValue.refCount++;
-        return oldValue;
-      } else {
-        return new Payload(supplier.get());
-      }
-    })).v;
+    return map.compute(k, (k1, old) -> old != null? old: new Payload<>(supplier.get())).increment();
+  }
+
+  static <V> V get(Payload<V> p) {
+    return p == null ? null : p.get();
   }
 
   V get(K k) {
-    Payload<V> p = map.get(k);
-    return p == null ? null : p.v;
+    return get(map.get(k));
   }
 
   /**
@@ -64,8 +73,7 @@ class RefCountingMap<K, V> {
    * @return the value associated with the specified key or null if key is removed from map.
    */
   V remove(K k) {
-    Payload<V> p = map.computeIfPresent(k, (k1, v) -> --v.refCount <= 0 ? null : v);
-    return p == null ? null : p.v;
+    return get(map.computeIfPresent(k, (k1, v) -> v.decrement()));
   }
 
   void clear() {
@@ -77,7 +85,7 @@ class RefCountingMap<K, V> {
   }
 
   Collection<V> values() {
-    return map.values().stream().map(v -> v.v).collect(Collectors.toList());
+    return map.values().stream().map(Payload::get).collect(Collectors.toList());
   }
 
   int size() {