You are viewing a plain-text version of this content. The canonical link for it (originally rendered as "here") was not preserved in this copy.
Posted to commits@lucene.apache.org by jp...@apache.org on 2020/12/18 15:15:15 UTC
[lucene-solr] 01/02: LUCENE-9629: Use computed masks (#2113)
This is an automated email from the ASF dual-hosted git repository.
jpountz pushed a commit to branch branch_8x
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git
commit 1e1cb518251f62b79b5fa84d45366458b34f3ce1
Author: gf2121 <52...@users.noreply.github.com>
AuthorDate: Fri Dec 18 08:59:40 2020 -0600
LUCENE-9629: Use computed masks (#2113)
Co-authored-by: 郭峰 <gu...@bytedance.com>
---
.../org/apache/lucene/codecs/lucene84/ForUtil.java | 131 ++++++++++++---------
.../apache/lucene/codecs/lucene84/gen_ForUtil.py | 44 ++++---
2 files changed, 101 insertions(+), 74 deletions(-)
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java
index eb07ec1..266624f 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java
@@ -250,17 +250,17 @@ final class ForUtil {
final int remainingBitsPerLong = shift + bitsPerValue;
final long maskRemainingBitsPerLong;
if (nextPrimitive == 8) {
- maskRemainingBitsPerLong = mask8(remainingBitsPerLong);
+ maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
} else if (nextPrimitive == 16) {
- maskRemainingBitsPerLong = mask16(remainingBitsPerLong);
+ maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
} else {
- maskRemainingBitsPerLong = mask32(remainingBitsPerLong);
+ maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
}
int tmpIdx = 0;
int remainingBitsPerValue = bitsPerValue;
while (idx < numLongs) {
- if (remainingBitsPerValue > remainingBitsPerLong) {
+ if (remainingBitsPerValue >= remainingBitsPerLong) {
remainingBitsPerValue -= remainingBitsPerLong;
tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
if (remainingBitsPerValue == 0) {
@@ -270,14 +270,14 @@ final class ForUtil {
} else {
final long mask1, mask2;
if (nextPrimitive == 8) {
- mask1 = mask8(remainingBitsPerValue);
- mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue);
+ mask1 = MASKS8[remainingBitsPerValue];
+ mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
} else if (nextPrimitive == 16) {
- mask1 = mask16(remainingBitsPerValue);
- mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue);
+ mask1 = MASKS16[remainingBitsPerValue];
+ mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
} else {
- mask1 = mask32(remainingBitsPerValue);
- mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue);
+ mask1 = MASKS32[remainingBitsPerValue];
+ mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
}
tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
@@ -302,7 +302,7 @@ final class ForUtil {
private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
final int numLongs = bitsPerValue << 1;
in.readLELongs(tmp, 0, numLongs);
- final long mask = mask32(bitsPerValue);
+ final long mask = MASKS32[bitsPerValue];
int longsIdx = 0;
int shift = 32 - bitsPerValue;
for (; shift >= 0; shift -= bitsPerValue) {
@@ -310,18 +310,18 @@ final class ForUtil {
longsIdx += numLongs;
}
final int remainingBitsPerLong = shift + bitsPerValue;
- final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong);
+ final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
int tmpIdx = 0;
int remainingBits = remainingBitsPerLong;
for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
int b = bitsPerValue - remainingBits;
- long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b;
+ long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
while (b >= remainingBitsPerLong) {
b -= remainingBitsPerLong;
l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
}
if (b > 0) {
- l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b);
+ l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
remainingBits = remainingBitsPerLong - b;
} else {
remainingBits = remainingBitsPerLong;
@@ -341,50 +341,65 @@ final class ForUtil {
}
}
- private static final long MASK8_1 = mask8(1);
- private static final long MASK8_2 = mask8(2);
- private static final long MASK8_3 = mask8(3);
- private static final long MASK8_4 = mask8(4);
- private static final long MASK8_5 = mask8(5);
- private static final long MASK8_6 = mask8(6);
- private static final long MASK8_7 = mask8(7);
- private static final long MASK16_1 = mask16(1);
- private static final long MASK16_2 = mask16(2);
- private static final long MASK16_3 = mask16(3);
- private static final long MASK16_4 = mask16(4);
- private static final long MASK16_5 = mask16(5);
- private static final long MASK16_6 = mask16(6);
- private static final long MASK16_7 = mask16(7);
- private static final long MASK16_9 = mask16(9);
- private static final long MASK16_10 = mask16(10);
- private static final long MASK16_11 = mask16(11);
- private static final long MASK16_12 = mask16(12);
- private static final long MASK16_13 = mask16(13);
- private static final long MASK16_14 = mask16(14);
- private static final long MASK16_15 = mask16(15);
- private static final long MASK32_1 = mask32(1);
- private static final long MASK32_2 = mask32(2);
- private static final long MASK32_3 = mask32(3);
- private static final long MASK32_4 = mask32(4);
- private static final long MASK32_5 = mask32(5);
- private static final long MASK32_6 = mask32(6);
- private static final long MASK32_7 = mask32(7);
- private static final long MASK32_8 = mask32(8);
- private static final long MASK32_9 = mask32(9);
- private static final long MASK32_10 = mask32(10);
- private static final long MASK32_11 = mask32(11);
- private static final long MASK32_12 = mask32(12);
- private static final long MASK32_13 = mask32(13);
- private static final long MASK32_14 = mask32(14);
- private static final long MASK32_15 = mask32(15);
- private static final long MASK32_17 = mask32(17);
- private static final long MASK32_18 = mask32(18);
- private static final long MASK32_19 = mask32(19);
- private static final long MASK32_20 = mask32(20);
- private static final long MASK32_21 = mask32(21);
- private static final long MASK32_22 = mask32(22);
- private static final long MASK32_23 = mask32(23);
- private static final long MASK32_24 = mask32(24);
+ private static final long[] MASKS8 = new long[8];
+ private static final long[] MASKS16 = new long[16];
+ private static final long[] MASKS32 = new long[32];
+ static {
+ for (int i = 0; i < 8; ++i) {
+ MASKS8[i] = mask8(i);
+ }
+ for (int i = 0; i < 16; ++i) {
+ MASKS16[i] = mask16(i);
+ }
+ for (int i = 0; i < 32; ++i) {
+ MASKS32[i] = mask32(i);
+ }
+ }
+ //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable
+ private static final long MASK8_1 = MASKS8[1];
+ private static final long MASK8_2 = MASKS8[2];
+ private static final long MASK8_3 = MASKS8[3];
+ private static final long MASK8_4 = MASKS8[4];
+ private static final long MASK8_5 = MASKS8[5];
+ private static final long MASK8_6 = MASKS8[6];
+ private static final long MASK8_7 = MASKS8[7];
+ private static final long MASK16_1 = MASKS16[1];
+ private static final long MASK16_2 = MASKS16[2];
+ private static final long MASK16_3 = MASKS16[3];
+ private static final long MASK16_4 = MASKS16[4];
+ private static final long MASK16_5 = MASKS16[5];
+ private static final long MASK16_6 = MASKS16[6];
+ private static final long MASK16_7 = MASKS16[7];
+ private static final long MASK16_9 = MASKS16[9];
+ private static final long MASK16_10 = MASKS16[10];
+ private static final long MASK16_11 = MASKS16[11];
+ private static final long MASK16_12 = MASKS16[12];
+ private static final long MASK16_13 = MASKS16[13];
+ private static final long MASK16_14 = MASKS16[14];
+ private static final long MASK16_15 = MASKS16[15];
+ private static final long MASK32_1 = MASKS32[1];
+ private static final long MASK32_2 = MASKS32[2];
+ private static final long MASK32_3 = MASKS32[3];
+ private static final long MASK32_4 = MASKS32[4];
+ private static final long MASK32_5 = MASKS32[5];
+ private static final long MASK32_6 = MASKS32[6];
+ private static final long MASK32_7 = MASKS32[7];
+ private static final long MASK32_8 = MASKS32[8];
+ private static final long MASK32_9 = MASKS32[9];
+ private static final long MASK32_10 = MASKS32[10];
+ private static final long MASK32_11 = MASKS32[11];
+ private static final long MASK32_12 = MASKS32[12];
+ private static final long MASK32_13 = MASKS32[13];
+ private static final long MASK32_14 = MASKS32[14];
+ private static final long MASK32_15 = MASKS32[15];
+ private static final long MASK32_17 = MASKS32[17];
+ private static final long MASK32_18 = MASKS32[18];
+ private static final long MASK32_19 = MASKS32[19];
+ private static final long MASK32_20 = MASKS32[20];
+ private static final long MASK32_21 = MASKS32[21];
+ private static final long MASK32_22 = MASKS32[22];
+ private static final long MASK32_23 = MASKS32[23];
+ private static final long MASK32_24 = MASKS32[24];
/**
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py
index 94f31e2..3025618 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py
@@ -21,6 +21,7 @@ from fractions import gcd
MAX_SPECIALIZED_BITS_PER_VALUE = 24
OUTPUT_FILE = "ForUtil.java"
+PRIMITIVE_SIZE = [8, 16, 32]
HEADER = """// This file has been automatically generated, DO NOT EDIT
/*
@@ -273,17 +274,17 @@ final class ForUtil {
final int remainingBitsPerLong = shift + bitsPerValue;
final long maskRemainingBitsPerLong;
if (nextPrimitive == 8) {
- maskRemainingBitsPerLong = mask8(remainingBitsPerLong);
+ maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
} else if (nextPrimitive == 16) {
- maskRemainingBitsPerLong = mask16(remainingBitsPerLong);
+ maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
} else {
- maskRemainingBitsPerLong = mask32(remainingBitsPerLong);
+ maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
}
int tmpIdx = 0;
int remainingBitsPerValue = bitsPerValue;
while (idx < numLongs) {
- if (remainingBitsPerValue > remainingBitsPerLong) {
+ if (remainingBitsPerValue >= remainingBitsPerLong) {
remainingBitsPerValue -= remainingBitsPerLong;
tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
if (remainingBitsPerValue == 0) {
@@ -293,14 +294,14 @@ final class ForUtil {
} else {
final long mask1, mask2;
if (nextPrimitive == 8) {
- mask1 = mask8(remainingBitsPerValue);
- mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue);
+ mask1 = MASKS8[remainingBitsPerValue];
+ mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
} else if (nextPrimitive == 16) {
- mask1 = mask16(remainingBitsPerValue);
- mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue);
+ mask1 = MASKS16[remainingBitsPerValue];
+ mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
} else {
- mask1 = mask32(remainingBitsPerValue);
- mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue);
+ mask1 = MASKS32[remainingBitsPerValue];
+ mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
}
tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
@@ -325,7 +326,7 @@ final class ForUtil {
private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
final int numLongs = bitsPerValue << 1;
in.readLELongs(tmp, 0, numLongs);
- final long mask = mask32(bitsPerValue);
+ final long mask = MASKS32[bitsPerValue];
int longsIdx = 0;
int shift = 32 - bitsPerValue;
for (; shift >= 0; shift -= bitsPerValue) {
@@ -333,18 +334,18 @@ final class ForUtil {
longsIdx += numLongs;
}
final int remainingBitsPerLong = shift + bitsPerValue;
- final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong);
+ final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
int tmpIdx = 0;
int remainingBits = remainingBitsPerLong;
for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
int b = bitsPerValue - remainingBits;
- long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b;
+ long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
while (b >= remainingBitsPerLong) {
b -= remainingBitsPerLong;
l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
}
if (b > 0) {
- l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b);
+ l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
remainingBits = remainingBitsPerLong - b;
} else {
remainingBits = remainingBitsPerLong;
@@ -374,6 +375,7 @@ def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long,
num_values /= 2
iteration *= 2
+
f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long))
f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values))
tmp_idx = 0
@@ -450,11 +452,21 @@ def writeDecode(bpv, f):
if __name__ == '__main__':
f = open(OUTPUT_FILE, 'w')
f.write(HEADER)
- for primitive_size in [8, 16, 32]:
+ for primitive_size in PRIMITIVE_SIZE:
+ f.write(' private static final long[] MASKS%d = new long[%d];\n' %(primitive_size, primitive_size))
+ f.write(' static {\n')
+ for primitive_size in PRIMITIVE_SIZE:
+ f.write(' for (int i = 0; i < %d; ++i) {\n' %primitive_size)
+ f.write(' MASKS%d[i] = mask%d(i);\n' %(primitive_size, primitive_size))
+ f.write(' }\n')
+ f.write(' }\n')
+ f.write(' //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable\n')
+ for primitive_size in PRIMITIVE_SIZE:
for bpv in range(1, min(MAX_SPECIALIZED_BITS_PER_VALUE + 1, primitive_size)):
if bpv * 2 != primitive_size or primitive_size == 8:
- f.write(' private static final long MASK%d_%d = mask%d(%d);\n' %(primitive_size, bpv, primitive_size, bpv))
+ f.write(' private static final long MASK%d_%d = MASKS%d[%d];\n' %(primitive_size, bpv, primitive_size, bpv))
f.write('\n')
+
f.write("""
/**
* Decode 128 integers into {@code longs}.