You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jp...@apache.org on 2020/12/18 15:15:15 UTC

[lucene-solr] 01/02: LUCENE-9629: Use computed masks (#2113)

This is an automated email from the ASF dual-hosted git repository.

jpountz pushed a commit to branch branch_8x
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git

commit 1e1cb518251f62b79b5fa84d45366458b34f3ce1
Author: gf2121 <52...@users.noreply.github.com>
AuthorDate: Fri Dec 18 08:59:40 2020 -0600

    LUCENE-9629: Use computed masks (#2113)
    
    Co-authored-by: 郭峰 <gu...@bytedance.com>
---
 .../org/apache/lucene/codecs/lucene84/ForUtil.java | 131 ++++++++++++---------
 .../apache/lucene/codecs/lucene84/gen_ForUtil.py   |  44 ++++---
 2 files changed, 101 insertions(+), 74 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java
index eb07ec1..266624f 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java
@@ -250,17 +250,17 @@ final class ForUtil {
     final int remainingBitsPerLong = shift + bitsPerValue;
     final long maskRemainingBitsPerLong;
     if (nextPrimitive == 8) {
-      maskRemainingBitsPerLong = mask8(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
     } else if (nextPrimitive == 16) {
-      maskRemainingBitsPerLong = mask16(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
     } else {
-      maskRemainingBitsPerLong = mask32(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
     }
 
     int tmpIdx = 0;
     int remainingBitsPerValue = bitsPerValue;
     while (idx < numLongs) {
-      if (remainingBitsPerValue > remainingBitsPerLong) {
+      if (remainingBitsPerValue >= remainingBitsPerLong) {
         remainingBitsPerValue -= remainingBitsPerLong;
         tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
         if (remainingBitsPerValue == 0) {
@@ -270,14 +270,14 @@ final class ForUtil {
       } else {
         final long mask1, mask2;
         if (nextPrimitive == 8) {
-          mask1 = mask8(remainingBitsPerValue);
-          mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS8[remainingBitsPerValue];
+          mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
         } else if (nextPrimitive == 16) {
-          mask1 = mask16(remainingBitsPerValue);
-          mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS16[remainingBitsPerValue];
+          mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
         } else {
-          mask1 = mask32(remainingBitsPerValue);
-          mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS32[remainingBitsPerValue];
+          mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
         }
         tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
         remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
@@ -302,7 +302,7 @@ final class ForUtil {
   private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
     final int numLongs = bitsPerValue << 1;
     in.readLELongs(tmp, 0, numLongs);
-    final long mask = mask32(bitsPerValue);
+    final long mask = MASKS32[bitsPerValue];
     int longsIdx = 0;
     int shift = 32 - bitsPerValue;
     for (; shift >= 0; shift -= bitsPerValue) {
@@ -310,18 +310,18 @@ final class ForUtil {
       longsIdx += numLongs;
     }
     final int remainingBitsPerLong = shift + bitsPerValue;
-    final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong);
+    final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
     int tmpIdx = 0;
     int remainingBits = remainingBitsPerLong;
     for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
       int b = bitsPerValue - remainingBits;
-      long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b;
+      long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
       while (b >= remainingBitsPerLong) {
         b -= remainingBitsPerLong;
         l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
       }
       if (b > 0) {
-        l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b);
+        l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
         remainingBits = remainingBitsPerLong - b;
       } else {
         remainingBits = remainingBitsPerLong;
@@ -341,50 +341,65 @@ final class ForUtil {
     }
   }
 
-  private static final long MASK8_1 = mask8(1);
-  private static final long MASK8_2 = mask8(2);
-  private static final long MASK8_3 = mask8(3);
-  private static final long MASK8_4 = mask8(4);
-  private static final long MASK8_5 = mask8(5);
-  private static final long MASK8_6 = mask8(6);
-  private static final long MASK8_7 = mask8(7);
-  private static final long MASK16_1 = mask16(1);
-  private static final long MASK16_2 = mask16(2);
-  private static final long MASK16_3 = mask16(3);
-  private static final long MASK16_4 = mask16(4);
-  private static final long MASK16_5 = mask16(5);
-  private static final long MASK16_6 = mask16(6);
-  private static final long MASK16_7 = mask16(7);
-  private static final long MASK16_9 = mask16(9);
-  private static final long MASK16_10 = mask16(10);
-  private static final long MASK16_11 = mask16(11);
-  private static final long MASK16_12 = mask16(12);
-  private static final long MASK16_13 = mask16(13);
-  private static final long MASK16_14 = mask16(14);
-  private static final long MASK16_15 = mask16(15);
-  private static final long MASK32_1 = mask32(1);
-  private static final long MASK32_2 = mask32(2);
-  private static final long MASK32_3 = mask32(3);
-  private static final long MASK32_4 = mask32(4);
-  private static final long MASK32_5 = mask32(5);
-  private static final long MASK32_6 = mask32(6);
-  private static final long MASK32_7 = mask32(7);
-  private static final long MASK32_8 = mask32(8);
-  private static final long MASK32_9 = mask32(9);
-  private static final long MASK32_10 = mask32(10);
-  private static final long MASK32_11 = mask32(11);
-  private static final long MASK32_12 = mask32(12);
-  private static final long MASK32_13 = mask32(13);
-  private static final long MASK32_14 = mask32(14);
-  private static final long MASK32_15 = mask32(15);
-  private static final long MASK32_17 = mask32(17);
-  private static final long MASK32_18 = mask32(18);
-  private static final long MASK32_19 = mask32(19);
-  private static final long MASK32_20 = mask32(20);
-  private static final long MASK32_21 = mask32(21);
-  private static final long MASK32_22 = mask32(22);
-  private static final long MASK32_23 = mask32(23);
-  private static final long MASK32_24 = mask32(24);
+  private static final long[] MASKS8 = new long[8];
+  private static final long[] MASKS16 = new long[16];
+  private static final long[] MASKS32 = new long[32];
+  static {
+    for (int i = 0; i < 8; ++i) {
+      MASKS8[i] = mask8(i);
+    }
+    for (int i = 0; i < 16; ++i) {
+      MASKS16[i] = mask16(i);
+    }
+    for (int i = 0; i < 32; ++i) {
+      MASKS32[i] = mask32(i);
+    }
+  }
+  //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable
+  private static final long MASK8_1 = MASKS8[1];
+  private static final long MASK8_2 = MASKS8[2];
+  private static final long MASK8_3 = MASKS8[3];
+  private static final long MASK8_4 = MASKS8[4];
+  private static final long MASK8_5 = MASKS8[5];
+  private static final long MASK8_6 = MASKS8[6];
+  private static final long MASK8_7 = MASKS8[7];
+  private static final long MASK16_1 = MASKS16[1];
+  private static final long MASK16_2 = MASKS16[2];
+  private static final long MASK16_3 = MASKS16[3];
+  private static final long MASK16_4 = MASKS16[4];
+  private static final long MASK16_5 = MASKS16[5];
+  private static final long MASK16_6 = MASKS16[6];
+  private static final long MASK16_7 = MASKS16[7];
+  private static final long MASK16_9 = MASKS16[9];
+  private static final long MASK16_10 = MASKS16[10];
+  private static final long MASK16_11 = MASKS16[11];
+  private static final long MASK16_12 = MASKS16[12];
+  private static final long MASK16_13 = MASKS16[13];
+  private static final long MASK16_14 = MASKS16[14];
+  private static final long MASK16_15 = MASKS16[15];
+  private static final long MASK32_1 = MASKS32[1];
+  private static final long MASK32_2 = MASKS32[2];
+  private static final long MASK32_3 = MASKS32[3];
+  private static final long MASK32_4 = MASKS32[4];
+  private static final long MASK32_5 = MASKS32[5];
+  private static final long MASK32_6 = MASKS32[6];
+  private static final long MASK32_7 = MASKS32[7];
+  private static final long MASK32_8 = MASKS32[8];
+  private static final long MASK32_9 = MASKS32[9];
+  private static final long MASK32_10 = MASKS32[10];
+  private static final long MASK32_11 = MASKS32[11];
+  private static final long MASK32_12 = MASKS32[12];
+  private static final long MASK32_13 = MASKS32[13];
+  private static final long MASK32_14 = MASKS32[14];
+  private static final long MASK32_15 = MASKS32[15];
+  private static final long MASK32_17 = MASKS32[17];
+  private static final long MASK32_18 = MASKS32[18];
+  private static final long MASK32_19 = MASKS32[19];
+  private static final long MASK32_20 = MASKS32[20];
+  private static final long MASK32_21 = MASKS32[21];
+  private static final long MASK32_22 = MASKS32[22];
+  private static final long MASK32_23 = MASKS32[23];
+  private static final long MASK32_24 = MASKS32[24];
 
 
   /**
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py
index 94f31e2..3025618 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py
@@ -21,6 +21,7 @@ from fractions import gcd
 
 MAX_SPECIALIZED_BITS_PER_VALUE = 24
 OUTPUT_FILE = "ForUtil.java"
+PRIMITIVE_SIZE = [8, 16, 32]
 HEADER = """// This file has been automatically generated, DO NOT EDIT
 
 /*
@@ -273,17 +274,17 @@ final class ForUtil {
     final int remainingBitsPerLong = shift + bitsPerValue;
     final long maskRemainingBitsPerLong;
     if (nextPrimitive == 8) {
-      maskRemainingBitsPerLong = mask8(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
     } else if (nextPrimitive == 16) {
-      maskRemainingBitsPerLong = mask16(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
     } else {
-      maskRemainingBitsPerLong = mask32(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
     }
 
     int tmpIdx = 0;
     int remainingBitsPerValue = bitsPerValue;
     while (idx < numLongs) {
-      if (remainingBitsPerValue > remainingBitsPerLong) {
+      if (remainingBitsPerValue >= remainingBitsPerLong) {
         remainingBitsPerValue -= remainingBitsPerLong;
         tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
         if (remainingBitsPerValue == 0) {
@@ -293,14 +294,14 @@ final class ForUtil {
       } else {
         final long mask1, mask2;
         if (nextPrimitive == 8) {
-          mask1 = mask8(remainingBitsPerValue);
-          mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS8[remainingBitsPerValue];
+          mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
         } else if (nextPrimitive == 16) {
-          mask1 = mask16(remainingBitsPerValue);
-          mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS16[remainingBitsPerValue];
+          mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
         } else {
-          mask1 = mask32(remainingBitsPerValue);
-          mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS32[remainingBitsPerValue];
+          mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
         }
         tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
         remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
@@ -325,7 +326,7 @@ final class ForUtil {
   private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
     final int numLongs = bitsPerValue << 1;
     in.readLELongs(tmp, 0, numLongs);
-    final long mask = mask32(bitsPerValue);
+    final long mask = MASKS32[bitsPerValue];
     int longsIdx = 0;
     int shift = 32 - bitsPerValue;
     for (; shift >= 0; shift -= bitsPerValue) {
@@ -333,18 +334,18 @@ final class ForUtil {
       longsIdx += numLongs;
     }
     final int remainingBitsPerLong = shift + bitsPerValue;
-    final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong);
+    final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
     int tmpIdx = 0;
     int remainingBits = remainingBitsPerLong;
     for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
       int b = bitsPerValue - remainingBits;
-      long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b;
+      long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
       while (b >= remainingBitsPerLong) {
         b -= remainingBitsPerLong;
         l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
       }
       if (b > 0) {
-        l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b);
+        l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
         remainingBits = remainingBitsPerLong - b;
       } else {
         remainingBits = remainingBitsPerLong;
@@ -374,6 +375,7 @@ def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long,
     num_values /= 2
     iteration *= 2
 
+
   f.write('    shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long))
   f.write('    for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values))
   tmp_idx = 0
@@ -450,11 +452,21 @@ def writeDecode(bpv, f):
 if __name__ == '__main__':
   f = open(OUTPUT_FILE, 'w')
   f.write(HEADER)
-  for primitive_size in [8, 16, 32]:
+  for primitive_size in PRIMITIVE_SIZE:
+    f.write('  private static final long[] MASKS%d = new long[%d];\n' %(primitive_size, primitive_size))
+  f.write('  static {\n')
+  for primitive_size in PRIMITIVE_SIZE:
+    f.write('    for (int i = 0; i < %d; ++i) {\n' %primitive_size)
+    f.write('      MASKS%d[i] = mask%d(i);\n' %(primitive_size, primitive_size))
+    f.write('    }\n')
+  f.write('  }\n')
+  f.write('  //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable\n')
+  for primitive_size in PRIMITIVE_SIZE:
     for bpv in range(1, min(MAX_SPECIALIZED_BITS_PER_VALUE + 1, primitive_size)):
       if bpv * 2 != primitive_size or primitive_size == 8:
-        f.write('  private static final long MASK%d_%d = mask%d(%d);\n' %(primitive_size, bpv, primitive_size, bpv))
+        f.write('  private static final long MASK%d_%d = MASKS%d[%d];\n' %(primitive_size, bpv, primitive_size, bpv))
   f.write('\n')
+
   f.write("""
   /**
    * Decode 128 integers into {@code longs}.