You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain text copy.
Posted to commits@lucene.apache.org by jp...@apache.org on 2020/12/18 15:15:14 UTC

[lucene-solr] branch branch_8x updated (453555c -> 9a53155)

This is an automated email from the ASF dual-hosted git repository.

jpountz pushed a change to branch branch_8x
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git.


    from 453555c  Clean up IDEA config generation (#2156)
     new 1e1cb51  LUCENE-9629: Use computed masks (#2113)
     new 9a53155  LUCENE-9635: BM25FQuery - Mask encoded norm long value in array lookup to avoid negative norms in long documents (#2138)

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 lucene/CHANGES.txt                                 |   3 +
 .../org/apache/lucene/codecs/lucene84/ForUtil.java | 131 ++++++++++++---------
 .../apache/lucene/codecs/lucene84/gen_ForUtil.py   |  44 ++++---
 .../lucene/search/MultiNormsLeafSimScorer.java     |   2 +-
 .../org/apache/lucene/search/TestBM25FQuery.java   |  79 +++++++++++++
 5 files changed, 184 insertions(+), 75 deletions(-)


[lucene-solr] 02/02: LUCENE-9635: BM25FQuery - Mask encoded norm long value in array lookup to avoid negative norms in long documents (#2138)

Posted by jp...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jpountz pushed a commit to branch branch_8x
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git

commit 9a5315532b1ba1bcb479dcd7f55a16c444bc7308
Author: yiluncui <75...@users.noreply.github.com>
AuthorDate: Fri Dec 18 06:56:31 2020 -0800

    LUCENE-9635: BM25FQuery - Mask encoded norm long value in array lookup to avoid negative norms in long documents (#2138)
---
 lucene/CHANGES.txt                                 |  3 +
 .../lucene/search/MultiNormsLeafSimScorer.java     |  2 +-
 .../org/apache/lucene/search/TestBM25FQuery.java   | 79 ++++++++++++++++++++++
 3 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt
index 8faeab5..6e39b2e 100644
--- a/lucene/CHANGES.txt
+++ b/lucene/CHANGES.txt
@@ -59,6 +59,9 @@ Bug Fixes
   (Ignacio Vera)   
 
 * LUCENE-9606: Wrap boolean queries generated by shape fields with a Constant score query. (Ignacio Vera)  
+
+* LUCENE-9635: BM25FQuery - Mask encoded norm long value in array lookup.
+  (Yilun Cui)
   
 Other
 ---------------------
diff --git a/lucene/sandbox/src/java/org/apache/lucene/search/MultiNormsLeafSimScorer.java b/lucene/sandbox/src/java/org/apache/lucene/search/MultiNormsLeafSimScorer.java
index 75c9801..9a37d3d 100644
--- a/lucene/sandbox/src/java/org/apache/lucene/search/MultiNormsLeafSimScorer.java
+++ b/lucene/sandbox/src/java/org/apache/lucene/search/MultiNormsLeafSimScorer.java
@@ -126,7 +126,7 @@ final class MultiNormsLeafSimScorer {
       for (int i = 0; i < normsArr.length; i++) {
         boolean found = normsArr[i].advanceExact(target);
         assert found;
-        normValue += weightArr[i] * LENGTH_TABLE[(byte) normsArr[i].longValue()];
+        normValue += weightArr[i] * LENGTH_TABLE[Byte.toUnsignedInt((byte) normsArr[i].longValue())];
       }
       current = SmallFloat.intToByte4(Math.round(normValue));
       return true;
diff --git a/lucene/sandbox/src/test/org/apache/lucene/search/TestBM25FQuery.java b/lucene/sandbox/src/test/org/apache/lucene/search/TestBM25FQuery.java
index 36c0054..5900bbb 100644
--- a/lucene/sandbox/src/test/org/apache/lucene/search/TestBM25FQuery.java
+++ b/lucene/sandbox/src/test/org/apache/lucene/search/TestBM25FQuery.java
@@ -23,11 +23,14 @@ import org.apache.lucene.document.Document;
 import org.apache.lucene.document.Field.Store;
 import org.apache.lucene.document.StringField;
 import org.apache.lucene.document.TextField;
+import org.apache.lucene.index.FieldInvertState;
 import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.MultiReader;
 import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.similarities.BM25Similarity;
+import org.apache.lucene.search.similarities.Similarity;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.LuceneTestCase;
@@ -166,4 +169,80 @@ public class TestBM25FQuery extends LuceneTestCase {
     w.close();
     dir.close();
   }
+
+  public void testDocWithNegativeNorms() throws IOException {
+    Directory dir = newDirectory();
+    IndexWriterConfig iwc = new IndexWriterConfig();
+    iwc.setSimilarity(new NegativeNormSimilarity());
+    RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
+
+    String queryString = "foo";
+
+    Document doc = new Document();
+    //both fields must contain tokens that match the query string "foo"
+    doc.add(new TextField("f", "foo", Store.NO));
+    doc.add(new TextField("g", "foo baz", Store.NO));
+    w.addDocument(doc);
+
+    IndexReader reader = w.getReader();
+    IndexSearcher searcher = newSearcher(reader);
+    BM25FQuery query = new BM25FQuery.Builder()
+            .addField("f")
+            .addField("g")
+            .addTerm(new BytesRef(queryString))
+            .build();
+    TopDocs topDocs = searcher.search(query, 10);
+    CheckHits.checkDocIds("queried docs do not match", new int[]{0}, topDocs.scoreDocs);
+
+    reader.close();
+    w.close();
+    dir.close();
+  }
+
+  public void testMultipleDocsNegativeNorms() throws IOException {
+    Directory dir = newDirectory();
+    IndexWriterConfig iwc = new IndexWriterConfig();
+    iwc.setSimilarity(new NegativeNormSimilarity());
+    RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
+
+    String queryString = "foo";
+
+    Document doc0 = new Document();
+    doc0.add(new TextField("f", "foo", Store.NO));
+    doc0.add(new TextField("g", "foo baz", Store.NO));
+    w.addDocument(doc0);
+
+    Document doc1 = new Document();
+    // add another match on the query string to the second doc
+    doc1.add(new TextField("f", "foo is foo", Store.NO));
+    doc1.add(new TextField("g", "foo baz", Store.NO));
+    w.addDocument(doc1);
+
+    IndexReader reader = w.getReader();
+    IndexSearcher searcher = newSearcher(reader);
+    BM25FQuery query = new BM25FQuery.Builder()
+            .addField("f")
+            .addField("g")
+            .addTerm(new BytesRef(queryString))
+            .build();
+    TopDocs topDocs = searcher.search(query, 10);
+    //Return doc1 ahead of doc0 since its tf is higher
+    CheckHits.checkDocIds("queried docs do not match", new int[]{1,0}, topDocs.scoreDocs);
+
+    reader.close();
+    w.close();
+    dir.close();
+  }
+
+  private static final class NegativeNormSimilarity extends Similarity {
+    @Override
+    public long computeNorm(FieldInvertState state) {
+      return -128;
+    }
+
+    @Override
+    public SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
+      return new BM25Similarity().scorer(boost, collectionStats, termStats);
+    }
+  }
 }


[lucene-solr] 01/02: LUCENE-9629: Use computed masks (#2113)

Posted by jp...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

jpountz pushed a commit to branch branch_8x
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git

commit 1e1cb518251f62b79b5fa84d45366458b34f3ce1
Author: gf2121 <52...@users.noreply.github.com>
AuthorDate: Fri Dec 18 08:59:40 2020 -0600

    LUCENE-9629: Use computed masks (#2113)
    
    Co-authored-by: 郭峰 <gu...@bytedance.com>
---
 .../org/apache/lucene/codecs/lucene84/ForUtil.java | 131 ++++++++++++---------
 .../apache/lucene/codecs/lucene84/gen_ForUtil.py   |  44 ++++---
 2 files changed, 101 insertions(+), 74 deletions(-)

diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java
index eb07ec1..266624f 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java
@@ -250,17 +250,17 @@ final class ForUtil {
     final int remainingBitsPerLong = shift + bitsPerValue;
     final long maskRemainingBitsPerLong;
     if (nextPrimitive == 8) {
-      maskRemainingBitsPerLong = mask8(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
     } else if (nextPrimitive == 16) {
-      maskRemainingBitsPerLong = mask16(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
     } else {
-      maskRemainingBitsPerLong = mask32(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
     }
 
     int tmpIdx = 0;
     int remainingBitsPerValue = bitsPerValue;
     while (idx < numLongs) {
-      if (remainingBitsPerValue > remainingBitsPerLong) {
+      if (remainingBitsPerValue >= remainingBitsPerLong) {
         remainingBitsPerValue -= remainingBitsPerLong;
         tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
         if (remainingBitsPerValue == 0) {
@@ -270,14 +270,14 @@ final class ForUtil {
       } else {
         final long mask1, mask2;
         if (nextPrimitive == 8) {
-          mask1 = mask8(remainingBitsPerValue);
-          mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS8[remainingBitsPerValue];
+          mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
         } else if (nextPrimitive == 16) {
-          mask1 = mask16(remainingBitsPerValue);
-          mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS16[remainingBitsPerValue];
+          mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
         } else {
-          mask1 = mask32(remainingBitsPerValue);
-          mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS32[remainingBitsPerValue];
+          mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
         }
         tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
         remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
@@ -302,7 +302,7 @@ final class ForUtil {
   private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
     final int numLongs = bitsPerValue << 1;
     in.readLELongs(tmp, 0, numLongs);
-    final long mask = mask32(bitsPerValue);
+    final long mask = MASKS32[bitsPerValue];
     int longsIdx = 0;
     int shift = 32 - bitsPerValue;
     for (; shift >= 0; shift -= bitsPerValue) {
@@ -310,18 +310,18 @@ final class ForUtil {
       longsIdx += numLongs;
     }
     final int remainingBitsPerLong = shift + bitsPerValue;
-    final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong);
+    final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
     int tmpIdx = 0;
     int remainingBits = remainingBitsPerLong;
     for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
       int b = bitsPerValue - remainingBits;
-      long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b;
+      long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
       while (b >= remainingBitsPerLong) {
         b -= remainingBitsPerLong;
         l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
       }
       if (b > 0) {
-        l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b);
+        l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
         remainingBits = remainingBitsPerLong - b;
       } else {
         remainingBits = remainingBitsPerLong;
@@ -341,50 +341,65 @@ final class ForUtil {
     }
   }
 
-  private static final long MASK8_1 = mask8(1);
-  private static final long MASK8_2 = mask8(2);
-  private static final long MASK8_3 = mask8(3);
-  private static final long MASK8_4 = mask8(4);
-  private static final long MASK8_5 = mask8(5);
-  private static final long MASK8_6 = mask8(6);
-  private static final long MASK8_7 = mask8(7);
-  private static final long MASK16_1 = mask16(1);
-  private static final long MASK16_2 = mask16(2);
-  private static final long MASK16_3 = mask16(3);
-  private static final long MASK16_4 = mask16(4);
-  private static final long MASK16_5 = mask16(5);
-  private static final long MASK16_6 = mask16(6);
-  private static final long MASK16_7 = mask16(7);
-  private static final long MASK16_9 = mask16(9);
-  private static final long MASK16_10 = mask16(10);
-  private static final long MASK16_11 = mask16(11);
-  private static final long MASK16_12 = mask16(12);
-  private static final long MASK16_13 = mask16(13);
-  private static final long MASK16_14 = mask16(14);
-  private static final long MASK16_15 = mask16(15);
-  private static final long MASK32_1 = mask32(1);
-  private static final long MASK32_2 = mask32(2);
-  private static final long MASK32_3 = mask32(3);
-  private static final long MASK32_4 = mask32(4);
-  private static final long MASK32_5 = mask32(5);
-  private static final long MASK32_6 = mask32(6);
-  private static final long MASK32_7 = mask32(7);
-  private static final long MASK32_8 = mask32(8);
-  private static final long MASK32_9 = mask32(9);
-  private static final long MASK32_10 = mask32(10);
-  private static final long MASK32_11 = mask32(11);
-  private static final long MASK32_12 = mask32(12);
-  private static final long MASK32_13 = mask32(13);
-  private static final long MASK32_14 = mask32(14);
-  private static final long MASK32_15 = mask32(15);
-  private static final long MASK32_17 = mask32(17);
-  private static final long MASK32_18 = mask32(18);
-  private static final long MASK32_19 = mask32(19);
-  private static final long MASK32_20 = mask32(20);
-  private static final long MASK32_21 = mask32(21);
-  private static final long MASK32_22 = mask32(22);
-  private static final long MASK32_23 = mask32(23);
-  private static final long MASK32_24 = mask32(24);
+  private static final long[] MASKS8 = new long[8];
+  private static final long[] MASKS16 = new long[16];
+  private static final long[] MASKS32 = new long[32];
+  static {
+    for (int i = 0; i < 8; ++i) {
+      MASKS8[i] = mask8(i);
+    }
+    for (int i = 0; i < 16; ++i) {
+      MASKS16[i] = mask16(i);
+    }
+    for (int i = 0; i < 32; ++i) {
+      MASKS32[i] = mask32(i);
+    }
+  }
+  //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable
+  private static final long MASK8_1 = MASKS8[1];
+  private static final long MASK8_2 = MASKS8[2];
+  private static final long MASK8_3 = MASKS8[3];
+  private static final long MASK8_4 = MASKS8[4];
+  private static final long MASK8_5 = MASKS8[5];
+  private static final long MASK8_6 = MASKS8[6];
+  private static final long MASK8_7 = MASKS8[7];
+  private static final long MASK16_1 = MASKS16[1];
+  private static final long MASK16_2 = MASKS16[2];
+  private static final long MASK16_3 = MASKS16[3];
+  private static final long MASK16_4 = MASKS16[4];
+  private static final long MASK16_5 = MASKS16[5];
+  private static final long MASK16_6 = MASKS16[6];
+  private static final long MASK16_7 = MASKS16[7];
+  private static final long MASK16_9 = MASKS16[9];
+  private static final long MASK16_10 = MASKS16[10];
+  private static final long MASK16_11 = MASKS16[11];
+  private static final long MASK16_12 = MASKS16[12];
+  private static final long MASK16_13 = MASKS16[13];
+  private static final long MASK16_14 = MASKS16[14];
+  private static final long MASK16_15 = MASKS16[15];
+  private static final long MASK32_1 = MASKS32[1];
+  private static final long MASK32_2 = MASKS32[2];
+  private static final long MASK32_3 = MASKS32[3];
+  private static final long MASK32_4 = MASKS32[4];
+  private static final long MASK32_5 = MASKS32[5];
+  private static final long MASK32_6 = MASKS32[6];
+  private static final long MASK32_7 = MASKS32[7];
+  private static final long MASK32_8 = MASKS32[8];
+  private static final long MASK32_9 = MASKS32[9];
+  private static final long MASK32_10 = MASKS32[10];
+  private static final long MASK32_11 = MASKS32[11];
+  private static final long MASK32_12 = MASKS32[12];
+  private static final long MASK32_13 = MASKS32[13];
+  private static final long MASK32_14 = MASKS32[14];
+  private static final long MASK32_15 = MASKS32[15];
+  private static final long MASK32_17 = MASKS32[17];
+  private static final long MASK32_18 = MASKS32[18];
+  private static final long MASK32_19 = MASKS32[19];
+  private static final long MASK32_20 = MASKS32[20];
+  private static final long MASK32_21 = MASKS32[21];
+  private static final long MASK32_22 = MASKS32[22];
+  private static final long MASK32_23 = MASKS32[23];
+  private static final long MASK32_24 = MASKS32[24];
 
 
   /**
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py
index 94f31e2..3025618 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py
@@ -21,6 +21,7 @@ from fractions import gcd
 
 MAX_SPECIALIZED_BITS_PER_VALUE = 24
 OUTPUT_FILE = "ForUtil.java"
+PRIMITIVE_SIZE = [8, 16, 32]
 HEADER = """// This file has been automatically generated, DO NOT EDIT
 
 /*
@@ -273,17 +274,17 @@ final class ForUtil {
     final int remainingBitsPerLong = shift + bitsPerValue;
     final long maskRemainingBitsPerLong;
     if (nextPrimitive == 8) {
-      maskRemainingBitsPerLong = mask8(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
     } else if (nextPrimitive == 16) {
-      maskRemainingBitsPerLong = mask16(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
     } else {
-      maskRemainingBitsPerLong = mask32(remainingBitsPerLong);
+      maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
     }
 
     int tmpIdx = 0;
     int remainingBitsPerValue = bitsPerValue;
     while (idx < numLongs) {
-      if (remainingBitsPerValue > remainingBitsPerLong) {
+      if (remainingBitsPerValue >= remainingBitsPerLong) {
         remainingBitsPerValue -= remainingBitsPerLong;
         tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
         if (remainingBitsPerValue == 0) {
@@ -293,14 +294,14 @@ final class ForUtil {
       } else {
         final long mask1, mask2;
         if (nextPrimitive == 8) {
-          mask1 = mask8(remainingBitsPerValue);
-          mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS8[remainingBitsPerValue];
+          mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
         } else if (nextPrimitive == 16) {
-          mask1 = mask16(remainingBitsPerValue);
-          mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS16[remainingBitsPerValue];
+          mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
         } else {
-          mask1 = mask32(remainingBitsPerValue);
-          mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue);
+          mask1 = MASKS32[remainingBitsPerValue];
+          mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
         }
         tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
         remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
@@ -325,7 +326,7 @@ final class ForUtil {
   private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
     final int numLongs = bitsPerValue << 1;
     in.readLELongs(tmp, 0, numLongs);
-    final long mask = mask32(bitsPerValue);
+    final long mask = MASKS32[bitsPerValue];
     int longsIdx = 0;
     int shift = 32 - bitsPerValue;
     for (; shift >= 0; shift -= bitsPerValue) {
@@ -333,18 +334,18 @@ final class ForUtil {
       longsIdx += numLongs;
     }
     final int remainingBitsPerLong = shift + bitsPerValue;
-    final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong);
+    final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
     int tmpIdx = 0;
     int remainingBits = remainingBitsPerLong;
     for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
       int b = bitsPerValue - remainingBits;
-      long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b;
+      long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
       while (b >= remainingBitsPerLong) {
         b -= remainingBitsPerLong;
         l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
       }
       if (b > 0) {
-        l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b);
+        l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
         remainingBits = remainingBitsPerLong - b;
       } else {
         remainingBits = remainingBitsPerLong;
@@ -374,6 +375,7 @@ def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long,
     num_values /= 2
     iteration *= 2
 
+
   f.write('    shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long))
   f.write('    for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values))
   tmp_idx = 0
@@ -450,11 +452,21 @@ def writeDecode(bpv, f):
 if __name__ == '__main__':
   f = open(OUTPUT_FILE, 'w')
   f.write(HEADER)
-  for primitive_size in [8, 16, 32]:
+  for primitive_size in PRIMITIVE_SIZE:
+    f.write('  private static final long[] MASKS%d = new long[%d];\n' %(primitive_size, primitive_size))
+  f.write('  static {\n')
+  for primitive_size in PRIMITIVE_SIZE:
+    f.write('    for (int i = 0; i < %d; ++i) {\n' %primitive_size)
+    f.write('      MASKS%d[i] = mask%d(i);\n' %(primitive_size, primitive_size))
+    f.write('    }\n')
+  f.write('  }\n')
+  f.write('  //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable\n')
+  for primitive_size in PRIMITIVE_SIZE:
     for bpv in range(1, min(MAX_SPECIALIZED_BITS_PER_VALUE + 1, primitive_size)):
       if bpv * 2 != primitive_size or primitive_size == 8:
-        f.write('  private static final long MASK%d_%d = mask%d(%d);\n' %(primitive_size, bpv, primitive_size, bpv))
+        f.write('  private static final long MASK%d_%d = MASKS%d[%d];\n' %(primitive_size, bpv, primitive_size, bpv))
   f.write('\n')
+
   f.write("""
   /**
    * Decode 128 integers into {@code longs}.