You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ja...@apache.org on 2022/02/21 17:57:03 UTC

[pinot] branch master updated: Ensure partition function never return negative partition (#8221)

This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 21632da  Ensure partition function never return negative partition (#8221)
21632da is described below

commit 21632dadb8cd2d8b77aec523a758d73a64f70b07
Author: Xiaotian (Jackie) Jiang <17...@users.noreply.github.com>
AuthorDate: Mon Feb 21 09:56:32 2022 -0800

    Ensure partition function never return negative partition (#8221)
    
    Some corner cases were not handled in the partition functions, and they could lead to a negative partition:
    - In HashCodePartitionFunction, when the hash code is Integer.MIN_VALUE (Math.abs(Integer.MIN_VALUE) is still negative)
    - In ModuloPartitionFunction, when the value is negative (Java's % operator keeps the sign of the dividend)
    
    Also enhance the tests to ensure that passing a number and its string representation returns the same partition
---
 .../spi/partition/ByteArrayPartitionFunction.java  |   9 +-
 .../spi/partition/HashCodePartitionFunction.java   |  10 +-
 .../spi/partition/ModuloPartitionFunction.java     |  10 +-
 .../spi/partition/MurmurPartitionFunction.java     |   5 +-
 .../spi/partition/PartitionFunctionTest.java       | 151 ++++++++++-----------
 5 files changed, 90 insertions(+), 95 deletions(-)

diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ByteArrayPartitionFunction.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ByteArrayPartitionFunction.java
index fc52fb8..aa97021 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ByteArrayPartitionFunction.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ByteArrayPartitionFunction.java
@@ -21,6 +21,8 @@ package org.apache.pinot.segment.spi.partition;
 import com.google.common.base.Preconditions;
 import java.util.Arrays;
 
+import static java.nio.charset.StandardCharsets.UTF_8;
+
 
 /**
  * Implementation of {@link Byte array partitioner}
@@ -40,8 +42,8 @@ public class ByteArrayPartitionFunction implements PartitionFunction {
   }
 
   @Override
-  public int getPartition(Object valueIn) {
-    return abs(Arrays.hashCode(valueIn.toString().getBytes())) % _numPartitions;
+  public int getPartition(Object value) {
+    return abs(Arrays.hashCode(value.toString().getBytes(UTF_8))) % _numPartitions;
   }
 
   @Override
@@ -60,7 +62,8 @@ public class ByteArrayPartitionFunction implements PartitionFunction {
     return NAME;
   }
 
-  private int abs(int n) {
+  // NOTE: This matches the Utils.abs() in Kafka
+  private static int abs(int n) {
     return (n == Integer.MIN_VALUE) ? 0 : Math.abs(n);
   }
 }
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/HashCodePartitionFunction.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/HashCodePartitionFunction.java
index f3846e7..11d34e3 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/HashCodePartitionFunction.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/HashCodePartitionFunction.java
@@ -37,9 +37,8 @@ public class HashCodePartitionFunction implements PartitionFunction {
   }
 
   @Override
-  public int getPartition(Object valueIn) {
-    String value = valueIn instanceof String ? (String) valueIn : valueIn.toString();
-    return Math.abs(value.hashCode()) % _numPartitions;
+  public int getPartition(Object value) {
+    return abs(value.toString().hashCode()) % _numPartitions;
   }
 
   @Override
@@ -57,4 +56,9 @@ public class HashCodePartitionFunction implements PartitionFunction {
   public String toString() {
     return NAME;
   }
+
+  // NOTE: This matches the Utils.abs() in Kafka
+  private static int abs(int n) {
+    return (n == Integer.MIN_VALUE) ? 0 : Math.abs(n);
+  }
 }
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ModuloPartitionFunction.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ModuloPartitionFunction.java
index e2e1d32..4d62d10 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ModuloPartitionFunction.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ModuloPartitionFunction.java
@@ -51,13 +51,13 @@ public class ModuloPartitionFunction implements PartitionFunction {
   @Override
   public int getPartition(Object value) {
     if (value instanceof Integer) {
-      return ((Integer) value) % _numPartitions;
+      return toNonNegative((Integer) value % _numPartitions);
     } else if (value instanceof Long) {
       // Since _numPartitions is int, the modulo should also be int.
-      return (int) (((Long) value) % _numPartitions);
+      return toNonNegative((int) ((Long) value % _numPartitions));
     } else if (value instanceof String) {
       // Parse String as Long, to support both Integer and Long.
-      return (int) ((Long.parseLong((String) value)) % _numPartitions);
+      return toNonNegative((int) (Long.parseLong((String) value) % _numPartitions));
     } else {
       throw new IllegalArgumentException(
           "Illegal argument for partitioning, expected Integer, got: " + value.getClass().getName());
@@ -79,4 +79,8 @@ public class ModuloPartitionFunction implements PartitionFunction {
   public String toString() {
     return NAME;
   }
+
+  private int toNonNegative(int partition) {
+    return partition < 0 ? partition + _numPartitions : partition;
+  }
 }
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/MurmurPartitionFunction.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/MurmurPartitionFunction.java
index dd566e0..7d2b9b8 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/MurmurPartitionFunction.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/MurmurPartitionFunction.java
@@ -41,9 +41,8 @@ public class MurmurPartitionFunction implements PartitionFunction {
   }
 
   @Override
-  public int getPartition(Object valueIn) {
-    String value = (valueIn instanceof String) ? (String) valueIn : valueIn.toString();
-    return (murmur2(value.getBytes(UTF_8)) & 0x7fffffff) % _numPartitions;
+  public int getPartition(Object value) {
+    return (murmur2(value.toString().getBytes(UTF_8)) & Integer.MAX_VALUE) % _numPartitions;
   }
 
   @Override
diff --git a/pinot-segment-spi/src/test/java/org/apache/pinot/segment/spi/partition/PartitionFunctionTest.java b/pinot-segment-spi/src/test/java/org/apache/pinot/segment/spi/partition/PartitionFunctionTest.java
index 8e813ec..1f089ad 100644
--- a/pinot-segment-spi/src/test/java/org/apache/pinot/segment/spi/partition/PartitionFunctionTest.java
+++ b/pinot-segment-spi/src/test/java/org/apache/pinot/segment/spi/partition/PartitionFunctionTest.java
@@ -32,6 +32,8 @@ import static org.testng.Assert.assertTrue;
  * Unit test for {@link PartitionFunction}
  */
 public class PartitionFunctionTest {
+  private static final int NUM_ROUNDS = 1000;
+  private static final int MAX_NUM_PARTITIONS = 100;
 
   /**
    * Unit test for {@link ModuloPartitionFunction}.
@@ -46,42 +48,35 @@ public class PartitionFunctionTest {
     long seed = System.currentTimeMillis();
     Random random = new Random(seed);
 
-    for (int i = 0; i < 1000; i++) {
-      int expectedNumPartitions = Math.abs(random.nextInt());
-
-      // Avoid divide-by-zero.
-      if (expectedNumPartitions == 0) {
-        expectedNumPartitions = 1;
-      }
+    for (int i = 0; i < NUM_ROUNDS; i++) {
+      int numPartitions = random.nextInt(MAX_NUM_PARTITIONS) + 1;
 
       String functionName = "MoDuLo";
-      PartitionFunction partitionFunction =
-          PartitionFunctionFactory.getPartitionFunction(functionName, expectedNumPartitions);
-
-      testBasicProperties(partitionFunction, functionName, expectedNumPartitions);
-
-      // Test int values.
-      for (int j = 0; j < 1000; j++) {
-        int value = random.nextInt();
-        assertEquals(partitionFunction.getPartition(value), (value % expectedNumPartitions));
-      }
-
-      // Test long values.
-      for (int j = 0; j < 1000; j++) {
-        long value = random.nextLong();
-        assertEquals(partitionFunction.getPartition(value), (value % expectedNumPartitions));
-      }
-
-      // Test Integer represented as String.
-      for (int j = 0; j < 1000; j++) {
-        int value = random.nextInt();
-        assertEquals(partitionFunction.getPartition(Integer.toString(value)), (value % expectedNumPartitions));
+      PartitionFunction partitionFunction = PartitionFunctionFactory.getPartitionFunction(functionName, numPartitions);
+
+      testBasicProperties(partitionFunction, functionName, numPartitions);
+
+      // Test int values
+      for (int j = 0; j < NUM_ROUNDS; j++) {
+        int value = j == 0 ? Integer.MIN_VALUE : random.nextInt();
+        int expectedPartition = value % numPartitions;
+        if (expectedPartition < 0) {
+          expectedPartition += numPartitions;
+        }
+        assertEquals(partitionFunction.getPartition(value), expectedPartition);
+        assertEquals(partitionFunction.getPartition((long) value), expectedPartition);
+        assertEquals(partitionFunction.getPartition(Integer.toString(value)), expectedPartition);
       }
 
-      // Test Long represented as String.
-      for (int j = 0; j < 1000; j++) {
-        long value = random.nextLong();
-        assertEquals(partitionFunction.getPartition(Long.toString(value)), (value % expectedNumPartitions));
+      // Test long values
+      for (int j = 0; j < NUM_ROUNDS; j++) {
+        long value = j == 0 ? Long.MIN_VALUE : random.nextLong();
+        int expectedPartition = (int) (value % numPartitions);
+        if (expectedPartition < 0) {
+          expectedPartition += numPartitions;
+        }
+        assertEquals(partitionFunction.getPartition(value), expectedPartition);
+        assertEquals(partitionFunction.getPartition(Long.toString(value)), expectedPartition);
       }
     }
   }
@@ -98,25 +93,20 @@ public class PartitionFunctionTest {
     long seed = System.currentTimeMillis();
     Random random = new Random(seed);
 
-    for (int i = 0; i < 1000; i++) {
-      int expectedNumPartitions = Math.abs(random.nextInt());
-
-      // Avoid divide-by-zero.
-      if (expectedNumPartitions == 0) {
-        expectedNumPartitions = 1;
-      }
+    for (int i = 0; i < NUM_ROUNDS; i++) {
+      int numPartitions = random.nextInt(MAX_NUM_PARTITIONS) + 1;
 
       String functionName = "mUrmur";
-      PartitionFunction partitionFunction =
-          PartitionFunctionFactory.getPartitionFunction(functionName, expectedNumPartitions);
+      PartitionFunction partitionFunction = PartitionFunctionFactory.getPartitionFunction(functionName, numPartitions);
 
-      testBasicProperties(partitionFunction, functionName, expectedNumPartitions);
+      testBasicProperties(partitionFunction, functionName, numPartitions);
 
-      for (int j = 0; j < 1000; j++) {
-        int value = random.nextInt();
-        String stringValue = Integer.toString(value);
-        assertTrue(partitionFunction.getPartition(stringValue) < expectedNumPartitions,
-            "Illegal: " + partitionFunction.getPartition(stringValue) + " " + expectedNumPartitions);
+      for (int j = 0; j < NUM_ROUNDS; j++) {
+        int value = j == 0 ? Integer.MIN_VALUE : random.nextInt();
+        int partition1 = partitionFunction.getPartition(value);
+        int partition2 = partitionFunction.getPartition(Integer.toString(value));
+        assertEquals(partition1, partition2);
+        assertTrue(partition1 >= 0 && partition1 < numPartitions);
       }
     }
   }
@@ -133,24 +123,20 @@ public class PartitionFunctionTest {
     long seed = System.currentTimeMillis();
     Random random = new Random(seed);
 
-    for (int i = 0; i < 1000; i++) {
-      int expectedNumPartitions = Math.abs(random.nextInt());
-
-      // Avoid divide-by-zero.
-      if (expectedNumPartitions == 0) {
-        expectedNumPartitions = 1;
-      }
+    for (int i = 0; i < NUM_ROUNDS; i++) {
+      int numPartitions = random.nextInt(MAX_NUM_PARTITIONS) + 1;
 
       String functionName = "bYteArray";
-      PartitionFunction partitionFunction =
-          PartitionFunctionFactory.getPartitionFunction(functionName, expectedNumPartitions);
+      PartitionFunction partitionFunction = PartitionFunctionFactory.getPartitionFunction(functionName, numPartitions);
 
-      testBasicProperties(partitionFunction, functionName, expectedNumPartitions);
+      testBasicProperties(partitionFunction, functionName, numPartitions);
 
-      for (int j = 0; j < 1000; j++) {
-        Integer value = random.nextInt();
-        assertTrue(partitionFunction.getPartition(value) < expectedNumPartitions,
-            "Illegal: " + partitionFunction.getPartition(value) + " " + expectedNumPartitions);
+      for (int j = 0; j < NUM_ROUNDS; j++) {
+        int value = j == 0 ? Integer.MIN_VALUE : random.nextInt();
+        int partition1 = partitionFunction.getPartition(value);
+        int partition2 = partitionFunction.getPartition(Integer.toString(value));
+        assertEquals(partition1, partition2);
+        assertTrue(partition1 >= 0 && partition1 < numPartitions);
       }
     }
   }
@@ -160,31 +146,30 @@ public class PartitionFunctionTest {
     long seed = System.currentTimeMillis();
     Random random = new Random(seed);
 
-    for (int i = 0; i < 1000; i++) {
-      int expectedNumPartitions = Math.abs(random.nextInt());
-
-      // Avoid divide-by-zero.
-      if (expectedNumPartitions == 0) {
-        expectedNumPartitions = 1;
-      }
+    for (int i = 0; i < NUM_ROUNDS; i++) {
+      int numPartitions = random.nextInt(MAX_NUM_PARTITIONS) + 1;
 
       String functionName = "HaShCoDe";
-      PartitionFunction partitionFunction =
-          PartitionFunctionFactory.getPartitionFunction(functionName, expectedNumPartitions);
+      PartitionFunction partitionFunction = PartitionFunctionFactory.getPartitionFunction(functionName, numPartitions);
 
-      testBasicProperties(partitionFunction, functionName, expectedNumPartitions);
+      testBasicProperties(partitionFunction, functionName, numPartitions);
 
-      // Test Integer values
-      for (int j = 0; j < 1000; j++) {
-        Integer value = random.nextInt();
-        assertEquals(partitionFunction.getPartition(value),
-            Math.abs(value.toString().hashCode()) % expectedNumPartitions);
+      // Test int values
+      for (int j = 0; j < NUM_ROUNDS; j++) {
+        int value = j == 0 ? Integer.MIN_VALUE : random.nextInt();
+        int hashCode = Integer.toString(value).hashCode();
+        int expectedPartition = ((hashCode == Integer.MIN_VALUE) ? 0 : Math.abs(hashCode)) % numPartitions;
+        assertEquals(partitionFunction.getPartition(value), expectedPartition);
+        assertEquals(partitionFunction.getPartition(Integer.toString(value)), expectedPartition);
       }
 
-      // Test Double values represented as String.
-      for (int j = 0; j < 1000; j++) {
-        String value = String.valueOf(random.nextDouble());
-        assertEquals(partitionFunction.getPartition(value), Math.abs(value.hashCode()) % expectedNumPartitions);
+      // Test double values
+      for (int j = 0; j < NUM_ROUNDS; j++) {
+        double value = j == 0 ? Double.NEGATIVE_INFINITY : random.nextDouble();
+        int hashCode = Double.toString(value).hashCode();
+        int expectedPartition = ((hashCode == Integer.MIN_VALUE) ? 0 : Math.abs(hashCode)) % numPartitions;
+        assertEquals(partitionFunction.getPartition(value), expectedPartition);
+        assertEquals(partitionFunction.getPartition(Double.toString(value)), expectedPartition);
       }
     }
   }
@@ -200,7 +185,8 @@ public class PartitionFunctionTest {
   }
 
   /**
-   * Tests the equivalence of org.apache.kafka.common.utils.Utils::murmur2 and {@link MurmurPartitionFunction::murmur2}
+   * Tests the equivalence of org.apache.kafka.common.utils.Utils::murmur2 and
+   * {@link MurmurPartitionFunction#getPartition}
    * Our implementation of murmur2 has been copied over from Utils::murmur2
    */
   @Test
@@ -232,8 +218,7 @@ public class PartitionFunctionTest {
 
   /**
    * Tests the equivalence of partitioning using org.apache.kafka.common.utils.Utils::partition and
-   * {@link MurmurPartitionFunction
-   * ::getPartition}
+   * {@link MurmurPartitionFunction#getPartition}
    */
   @Test
   public void testMurmurPartitionFunctionEquivalence() {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org