You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ja...@apache.org on 2022/02/21 17:57:03 UTC
[pinot] branch master updated: Ensure partition function never return negative partition (#8221)
This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 21632da Ensure partition function never return negative partition (#8221)
21632da is described below
commit 21632dadb8cd2d8b77aec523a758d73a64f70b07
Author: Xiaotian (Jackie) Jiang <17...@users.noreply.github.com>
AuthorDate: Mon Feb 21 09:56:32 2022 -0800
Ensure partition function never return negative partition (#8221)
Some corner cases are not handled in the partition functions, which can lead to a negative partition:
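- In HashCodePartitionFunction when the hash code is Integer.MIN_VALUE (Math.abs(Integer.MIN_VALUE) is still negative)
- In ModuloPartitionFunction when the value is negative (the Java % operator keeps the sign of the dividend)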
Also enhance the tests to ensure that passing a number and its string representation returns the same result
---
.../spi/partition/ByteArrayPartitionFunction.java | 9 +-
.../spi/partition/HashCodePartitionFunction.java | 10 +-
.../spi/partition/ModuloPartitionFunction.java | 10 +-
.../spi/partition/MurmurPartitionFunction.java | 5 +-
.../spi/partition/PartitionFunctionTest.java | 151 ++++++++++-----------
5 files changed, 90 insertions(+), 95 deletions(-)
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ByteArrayPartitionFunction.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ByteArrayPartitionFunction.java
index fc52fb8..aa97021 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ByteArrayPartitionFunction.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ByteArrayPartitionFunction.java
@@ -21,6 +21,8 @@ package org.apache.pinot.segment.spi.partition;
import com.google.common.base.Preconditions;
import java.util.Arrays;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
/**
* Implementation of {@link Byte array partitioner}
@@ -40,8 +42,8 @@ public class ByteArrayPartitionFunction implements PartitionFunction {
}
@Override
- public int getPartition(Object valueIn) {
- return abs(Arrays.hashCode(valueIn.toString().getBytes())) % _numPartitions;
+ public int getPartition(Object value) {
+ return abs(Arrays.hashCode(value.toString().getBytes(UTF_8))) % _numPartitions;
}
@Override
@@ -60,7 +62,8 @@ public class ByteArrayPartitionFunction implements PartitionFunction {
return NAME;
}
- private int abs(int n) {
+ // NOTE: This matches the Utils.abs() in Kafka
+ private static int abs(int n) {
return (n == Integer.MIN_VALUE) ? 0 : Math.abs(n);
}
}
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/HashCodePartitionFunction.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/HashCodePartitionFunction.java
index f3846e7..11d34e3 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/HashCodePartitionFunction.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/HashCodePartitionFunction.java
@@ -37,9 +37,8 @@ public class HashCodePartitionFunction implements PartitionFunction {
}
@Override
- public int getPartition(Object valueIn) {
- String value = valueIn instanceof String ? (String) valueIn : valueIn.toString();
- return Math.abs(value.hashCode()) % _numPartitions;
+ public int getPartition(Object value) {
+ return abs(value.toString().hashCode()) % _numPartitions;
}
@Override
@@ -57,4 +56,9 @@ public class HashCodePartitionFunction implements PartitionFunction {
public String toString() {
return NAME;
}
+
+ // NOTE: This matches the Utils.abs() in Kafka
+ private static int abs(int n) {
+ return (n == Integer.MIN_VALUE) ? 0 : Math.abs(n);
+ }
}
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ModuloPartitionFunction.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ModuloPartitionFunction.java
index e2e1d32..4d62d10 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ModuloPartitionFunction.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/ModuloPartitionFunction.java
@@ -51,13 +51,13 @@ public class ModuloPartitionFunction implements PartitionFunction {
@Override
public int getPartition(Object value) {
if (value instanceof Integer) {
- return ((Integer) value) % _numPartitions;
+ return toNonNegative((Integer) value % _numPartitions);
} else if (value instanceof Long) {
// Since _numPartitions is int, the modulo should also be int.
- return (int) (((Long) value) % _numPartitions);
+ return toNonNegative((int) ((Long) value % _numPartitions));
} else if (value instanceof String) {
// Parse String as Long, to support both Integer and Long.
- return (int) ((Long.parseLong((String) value)) % _numPartitions);
+ return toNonNegative((int) (Long.parseLong((String) value) % _numPartitions));
} else {
throw new IllegalArgumentException(
"Illegal argument for partitioning, expected Integer, got: " + value.getClass().getName());
@@ -79,4 +79,8 @@ public class ModuloPartitionFunction implements PartitionFunction {
public String toString() {
return NAME;
}
+
+ private int toNonNegative(int partition) {
+ return partition < 0 ? partition + _numPartitions : partition;
+ }
}
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/MurmurPartitionFunction.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/MurmurPartitionFunction.java
index dd566e0..7d2b9b8 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/MurmurPartitionFunction.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/partition/MurmurPartitionFunction.java
@@ -41,9 +41,8 @@ public class MurmurPartitionFunction implements PartitionFunction {
}
@Override
- public int getPartition(Object valueIn) {
- String value = (valueIn instanceof String) ? (String) valueIn : valueIn.toString();
- return (murmur2(value.getBytes(UTF_8)) & 0x7fffffff) % _numPartitions;
+ public int getPartition(Object value) {
+ return (murmur2(value.toString().getBytes(UTF_8)) & Integer.MAX_VALUE) % _numPartitions;
}
@Override
diff --git a/pinot-segment-spi/src/test/java/org/apache/pinot/segment/spi/partition/PartitionFunctionTest.java b/pinot-segment-spi/src/test/java/org/apache/pinot/segment/spi/partition/PartitionFunctionTest.java
index 8e813ec..1f089ad 100644
--- a/pinot-segment-spi/src/test/java/org/apache/pinot/segment/spi/partition/PartitionFunctionTest.java
+++ b/pinot-segment-spi/src/test/java/org/apache/pinot/segment/spi/partition/PartitionFunctionTest.java
@@ -32,6 +32,8 @@ import static org.testng.Assert.assertTrue;
* Unit test for {@link PartitionFunction}
*/
public class PartitionFunctionTest {
+ private static final int NUM_ROUNDS = 1000;
+ private static final int MAX_NUM_PARTITIONS = 100;
/**
* Unit test for {@link ModuloPartitionFunction}.
@@ -46,42 +48,35 @@ public class PartitionFunctionTest {
long seed = System.currentTimeMillis();
Random random = new Random(seed);
- for (int i = 0; i < 1000; i++) {
- int expectedNumPartitions = Math.abs(random.nextInt());
-
- // Avoid divide-by-zero.
- if (expectedNumPartitions == 0) {
- expectedNumPartitions = 1;
- }
+ for (int i = 0; i < NUM_ROUNDS; i++) {
+ int numPartitions = random.nextInt(MAX_NUM_PARTITIONS) + 1;
String functionName = "MoDuLo";
- PartitionFunction partitionFunction =
- PartitionFunctionFactory.getPartitionFunction(functionName, expectedNumPartitions);
-
- testBasicProperties(partitionFunction, functionName, expectedNumPartitions);
-
- // Test int values.
- for (int j = 0; j < 1000; j++) {
- int value = random.nextInt();
- assertEquals(partitionFunction.getPartition(value), (value % expectedNumPartitions));
- }
-
- // Test long values.
- for (int j = 0; j < 1000; j++) {
- long value = random.nextLong();
- assertEquals(partitionFunction.getPartition(value), (value % expectedNumPartitions));
- }
-
- // Test Integer represented as String.
- for (int j = 0; j < 1000; j++) {
- int value = random.nextInt();
- assertEquals(partitionFunction.getPartition(Integer.toString(value)), (value % expectedNumPartitions));
+ PartitionFunction partitionFunction = PartitionFunctionFactory.getPartitionFunction(functionName, numPartitions);
+
+ testBasicProperties(partitionFunction, functionName, numPartitions);
+
+ // Test int values
+ for (int j = 0; j < NUM_ROUNDS; j++) {
+ int value = j == 0 ? Integer.MIN_VALUE : random.nextInt();
+ int expectedPartition = value % numPartitions;
+ if (expectedPartition < 0) {
+ expectedPartition += numPartitions;
+ }
+ assertEquals(partitionFunction.getPartition(value), expectedPartition);
+ assertEquals(partitionFunction.getPartition((long) value), expectedPartition);
+ assertEquals(partitionFunction.getPartition(Integer.toString(value)), expectedPartition);
}
- // Test Long represented as String.
- for (int j = 0; j < 1000; j++) {
- long value = random.nextLong();
- assertEquals(partitionFunction.getPartition(Long.toString(value)), (value % expectedNumPartitions));
+ // Test long values
+ for (int j = 0; j < NUM_ROUNDS; j++) {
+ long value = j == 0 ? Long.MIN_VALUE : random.nextLong();
+ int expectedPartition = (int) (value % numPartitions);
+ if (expectedPartition < 0) {
+ expectedPartition += numPartitions;
+ }
+ assertEquals(partitionFunction.getPartition(value), expectedPartition);
+ assertEquals(partitionFunction.getPartition(Long.toString(value)), expectedPartition);
}
}
}
@@ -98,25 +93,20 @@ public class PartitionFunctionTest {
long seed = System.currentTimeMillis();
Random random = new Random(seed);
- for (int i = 0; i < 1000; i++) {
- int expectedNumPartitions = Math.abs(random.nextInt());
-
- // Avoid divide-by-zero.
- if (expectedNumPartitions == 0) {
- expectedNumPartitions = 1;
- }
+ for (int i = 0; i < NUM_ROUNDS; i++) {
+ int numPartitions = random.nextInt(MAX_NUM_PARTITIONS) + 1;
String functionName = "mUrmur";
- PartitionFunction partitionFunction =
- PartitionFunctionFactory.getPartitionFunction(functionName, expectedNumPartitions);
+ PartitionFunction partitionFunction = PartitionFunctionFactory.getPartitionFunction(functionName, numPartitions);
- testBasicProperties(partitionFunction, functionName, expectedNumPartitions);
+ testBasicProperties(partitionFunction, functionName, numPartitions);
- for (int j = 0; j < 1000; j++) {
- int value = random.nextInt();
- String stringValue = Integer.toString(value);
- assertTrue(partitionFunction.getPartition(stringValue) < expectedNumPartitions,
- "Illegal: " + partitionFunction.getPartition(stringValue) + " " + expectedNumPartitions);
+ for (int j = 0; j < NUM_ROUNDS; j++) {
+ int value = j == 0 ? Integer.MIN_VALUE : random.nextInt();
+ int partition1 = partitionFunction.getPartition(value);
+ int partition2 = partitionFunction.getPartition(Integer.toString(value));
+ assertEquals(partition1, partition2);
+ assertTrue(partition1 >= 0 && partition1 < numPartitions);
}
}
}
@@ -133,24 +123,20 @@ public class PartitionFunctionTest {
long seed = System.currentTimeMillis();
Random random = new Random(seed);
- for (int i = 0; i < 1000; i++) {
- int expectedNumPartitions = Math.abs(random.nextInt());
-
- // Avoid divide-by-zero.
- if (expectedNumPartitions == 0) {
- expectedNumPartitions = 1;
- }
+ for (int i = 0; i < NUM_ROUNDS; i++) {
+ int numPartitions = random.nextInt(MAX_NUM_PARTITIONS) + 1;
String functionName = "bYteArray";
- PartitionFunction partitionFunction =
- PartitionFunctionFactory.getPartitionFunction(functionName, expectedNumPartitions);
+ PartitionFunction partitionFunction = PartitionFunctionFactory.getPartitionFunction(functionName, numPartitions);
- testBasicProperties(partitionFunction, functionName, expectedNumPartitions);
+ testBasicProperties(partitionFunction, functionName, numPartitions);
- for (int j = 0; j < 1000; j++) {
- Integer value = random.nextInt();
- assertTrue(partitionFunction.getPartition(value) < expectedNumPartitions,
- "Illegal: " + partitionFunction.getPartition(value) + " " + expectedNumPartitions);
+ for (int j = 0; j < NUM_ROUNDS; j++) {
+ int value = j == 0 ? Integer.MIN_VALUE : random.nextInt();
+ int partition1 = partitionFunction.getPartition(value);
+ int partition2 = partitionFunction.getPartition(Integer.toString(value));
+ assertEquals(partition1, partition2);
+ assertTrue(partition1 >= 0 && partition1 < numPartitions);
}
}
}
@@ -160,31 +146,30 @@ public class PartitionFunctionTest {
long seed = System.currentTimeMillis();
Random random = new Random(seed);
- for (int i = 0; i < 1000; i++) {
- int expectedNumPartitions = Math.abs(random.nextInt());
-
- // Avoid divide-by-zero.
- if (expectedNumPartitions == 0) {
- expectedNumPartitions = 1;
- }
+ for (int i = 0; i < NUM_ROUNDS; i++) {
+ int numPartitions = random.nextInt(MAX_NUM_PARTITIONS) + 1;
String functionName = "HaShCoDe";
- PartitionFunction partitionFunction =
- PartitionFunctionFactory.getPartitionFunction(functionName, expectedNumPartitions);
+ PartitionFunction partitionFunction = PartitionFunctionFactory.getPartitionFunction(functionName, numPartitions);
- testBasicProperties(partitionFunction, functionName, expectedNumPartitions);
+ testBasicProperties(partitionFunction, functionName, numPartitions);
- // Test Integer values
- for (int j = 0; j < 1000; j++) {
- Integer value = random.nextInt();
- assertEquals(partitionFunction.getPartition(value),
- Math.abs(value.toString().hashCode()) % expectedNumPartitions);
+ // Test int values
+ for (int j = 0; j < NUM_ROUNDS; j++) {
+ int value = j == 0 ? Integer.MIN_VALUE : random.nextInt();
+ int hashCode = Integer.toString(value).hashCode();
+ int expectedPartition = ((hashCode == Integer.MIN_VALUE) ? 0 : Math.abs(hashCode)) % numPartitions;
+ assertEquals(partitionFunction.getPartition(value), expectedPartition);
+ assertEquals(partitionFunction.getPartition(Integer.toString(value)), expectedPartition);
}
- // Test Double values represented as String.
- for (int j = 0; j < 1000; j++) {
- String value = String.valueOf(random.nextDouble());
- assertEquals(partitionFunction.getPartition(value), Math.abs(value.hashCode()) % expectedNumPartitions);
+ // Test double values
+ for (int j = 0; j < NUM_ROUNDS; j++) {
+ double value = j == 0 ? Double.NEGATIVE_INFINITY : random.nextDouble();
+ int hashCode = Double.toString(value).hashCode();
+ int expectedPartition = ((hashCode == Integer.MIN_VALUE) ? 0 : Math.abs(hashCode)) % numPartitions;
+ assertEquals(partitionFunction.getPartition(value), expectedPartition);
+ assertEquals(partitionFunction.getPartition(Double.toString(value)), expectedPartition);
}
}
}
@@ -200,7 +185,8 @@ public class PartitionFunctionTest {
}
/**
- * Tests the equivalence of org.apache.kafka.common.utils.Utils::murmur2 and {@link MurmurPartitionFunction::murmur2}
+ * Tests the equivalence of org.apache.kafka.common.utils.Utils::murmur2 and
+ * {@link MurmurPartitionFunction#getPartition}
* Our implementation of murmur2 has been copied over from Utils::murmur2
*/
@Test
@@ -232,8 +218,7 @@ public class PartitionFunctionTest {
/**
* Tests the equivalence of partitioning using org.apache.kafka.common.utils.Utils::partition and
- * {@link MurmurPartitionFunction
- * ::getPartition}
+ * {@link MurmurPartitionFunction#getPartition}
*/
@Test
public void testMurmurPartitionFunctionEquivalence() {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org