You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by cu...@apache.org on 2019/07/25 06:08:31 UTC

[arrow] branch master updated: ARROW-1184: [Java] Dictionary.equals is not working correctly

This is an automated email from the ASF dual-hosted git repository.

cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new f4fcf56  ARROW-1184: [Java] Dictionary.equals is not working correctly
f4fcf56 is described below

commit f4fcf568c4386e2249a5f7f991d2ec5cf570108f
Author: tianchen <ni...@alibaba-inc.com>
AuthorDate: Wed Jul 24 23:08:11 2019 -0700

    ARROW-1184: [Java] Dictionary.equals is not working correctly
    
    Related to [ARROW-1184](https://issues.apache.org/jira/browse/ARROW-1184).
    The Dictionary.equals method does not return true when the two dictionaries hold equal data. This is because FieldVector does not override equals, so the comparison falls back to object identity and never looks at the vector data.
    
    Closes #4843 from tianchen92/ARROW-1184 and squashes the following commits:
    
    3511857 <tianchen> revert
    527a9d7 <tianchen> add TODO
    0c69588 <tianchen> fix equals
    1f41366 <tianchen> fix build
    87b7f66 <tianchen> fix
    3fe9389 <tianchen> use Validator logic in Dictionary#equals and add UT
    be74915 <tianchen> move UT
    314d697 <tianchen> ARROW-1184:  Dictionary.equals is not working correctly
    
    Authored-by: tianchen <ni...@alibaba-inc.com>
    Signed-off-by: Bryan Cutler <cu...@gmail.com>
---
 .../src/main/codegen/templates/UnionVector.java    |   6 +-
 .../apache/arrow/vector/dictionary/Dictionary.java |  13 +-
 .../apache/arrow/vector/TestDictionaryVector.java  | 228 +++++++++++++++++++--
 3 files changed, 223 insertions(+), 24 deletions(-)

diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java
index b05005d..c79dfd0 100644
--- a/java/vector/src/main/codegen/templates/UnionVector.java
+++ b/java/vector/src/main/codegen/templates/UnionVector.java
@@ -518,7 +518,11 @@ public class UnionVector implements FieldVector {
     }
 
     public Object getObject(int index) {
-      return getVector(index).getObject(index);
+      ValueVector vector = getVector(index);
+      if (vector != null) { // getVector may return null (e.g. no type set at index) -- TODO confirm
+        return vector.getObject(index);
+      }
+      return null; // no child vector to read from; the old code dereferenced it unconditionally
     }
 
     public byte[] get(int index) {
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java
index ccbdc9c..082d2ba 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java
@@ -22,6 +22,7 @@ import java.util.Objects;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.util.Validator;
 
 /**
  * A dictionary (integer to Value mapping) that is used to facilitate
@@ -63,11 +64,30 @@ public class Dictionary {
       return false;
     }
     Dictionary that = (Dictionary) o;
-    return Objects.equals(encoding, that.encoding) && Objects.equals(dictionary, that.dictionary);
+    return Objects.equals(encoding, that.encoding) && compareFieldVector(dictionary, that.dictionary);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hash(encoding, dictionary);
+    // equals() compares vector contents, but FieldVector presumably keeps Object's identity
+    // hashCode, so hashing the vector would give content-equal dictionaries different hashes.
+    return Objects.hash(encoding);
   }
+
+  //TODO after vector api support compare two vectors, this should be cleaned up
+  /**
+   * Compares two dictionary vectors via {@link Validator#compareFieldVectors}, which
+   * reports a mismatch by throwing {@link IllegalArgumentException}.
+   *
+   * @return true if the vectors compare equal, false otherwise
+   */
+  private boolean compareFieldVector(FieldVector vector1, FieldVector vector2) {
+    try {
+      Validator.compareFieldVectors(vector1, vector2);
+    } catch (IllegalArgumentException ignored) {
+      // the validator signals a mismatch with this exception; translate it to "not equal"
+      return false;
+    }
+    return true;
+  }
 }
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java
index e0bd218..2d6391b 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java
@@ -20,6 +20,7 @@ package org.apache.arrow.vector;
 import static org.apache.arrow.vector.TestUtils.newVarBinaryVector;
 import static org.apache.arrow.vector.TestUtils.newVarCharVector;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
 import java.nio.charset.StandardCharsets;
@@ -67,7 +68,7 @@ public class TestDictionaryVector {
   public void testEncodeStrings() {
     // Create a new value vector
     try (final VarCharVector vector = newVarCharVector("foo", allocator);
-         final VarCharVector dictionaryVector = newVarCharVector("dict", allocator);) {
+        final VarCharVector dictionaryVector = newVarCharVector("dict", allocator);) {
       vector.allocateNew(512, 5);
 
       // set some values
@@ -85,13 +86,14 @@ public class TestDictionaryVector {
       dictionaryVector.setSafe(2, two, 0, two.length);
       dictionaryVector.setValueCount(3);
 
-      Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
+      Dictionary dictionary =
+          new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
 
       try (final ValueVector encoded = (FieldVector) DictionaryEncoder.encode(vector, dictionary)) {
         // verify indices
         assertEquals(IntVector.class, encoded.getClass());
 
-        IntVector index = ((IntVector)encoded);
+        IntVector index = ((IntVector) encoded);
         assertEquals(5, index.getValueCount());
         assertEquals(0, index.get(0));
         assertEquals(1, index.get(1));
@@ -102,9 +104,9 @@ public class TestDictionaryVector {
         // now run through the decoder and verify we get the original back
         try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) {
           assertEquals(vector.getClass(), decoded.getClass());
-          assertEquals(vector.getValueCount(), ((VarCharVector)decoded).getValueCount());
+          assertEquals(vector.getValueCount(), ((VarCharVector) decoded).getValueCount());
           for (int i = 0; i < 5; i++) {
-            assertEquals(vector.getObject(i), ((VarCharVector)decoded).getObject(i));
+            assertEquals(vector.getObject(i), ((VarCharVector) decoded).getObject(i));
           }
         }
       }
@@ -115,7 +117,7 @@ public class TestDictionaryVector {
   public void testEncodeLargeVector() {
     // Create a new value vector
     try (final VarCharVector vector = newVarCharVector("foo", allocator);
-         final VarCharVector dictionaryVector = newVarCharVector("dict", allocator);) {
+        final VarCharVector dictionaryVector = newVarCharVector("dict", allocator);) {
       vector.allocateNew();
 
       int count = 10000;
@@ -131,7 +133,8 @@ public class TestDictionaryVector {
       dictionaryVector.setSafe(2, two, 0, two.length);
       dictionaryVector.setValueCount(3);
 
-      Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
+      Dictionary dictionary =
+          new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
 
 
       try (final ValueVector encoded = (FieldVector) DictionaryEncoder.encode(vector, dictionary)) {
@@ -156,14 +159,6 @@ public class TestDictionaryVector {
     }
   }
 
-  private void writeListVector(UnionListWriter writer, int[] values) {
-    writer.startList();
-    for (int v: values) {
-      writer.integer().writeInt(v);
-    }
-    writer.endList();
-  }
-
   @Test
   public void testEncodeList() {
     // Create a new value vector
@@ -218,13 +213,6 @@ public class TestDictionaryVector {
     }
   }
 
-  private void writeStructVector(NullableStructWriter writer, int value1, long value2) {
-    writer.start();
-    writer.integer("f0").writeInt(value1);
-    writer.bigInt("f1").writeBigInt(value2);
-    writer.end();
-  }
-
   @Test
   public void testEncodeStruct() {
     // Create a new value vector
@@ -406,4 +394,200 @@ public class TestDictionaryVector {
       }
     }
   }
+
+  @Test
+  public void testIntEquals() {
+    // Dictionary.equals must compare the Int vectors' contents, not their identity.
+    try (final IntVector vector1 = new IntVector("", allocator);
+        final IntVector vector2 = new IntVector("", allocator)) {
+
+      Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null));
+      Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null));
+
+      vector1.allocateNew(3);
+      vector1.setValueCount(3);
+      vector2.allocateNew(3);
+      vector2.setValueCount(3);
+
+      vector1.setSafe(0, 1);
+      vector1.setSafe(1, 2);
+      vector1.setSafe(2, 3);
+
+      vector2.setSafe(0, 1);
+      vector2.setSafe(1, 2);
+      vector2.setSafe(2, 0);
+
+      assertFalse(dict1.equals(dict2)); // values differ at index 2 (3 vs 0)
+
+      vector2.setSafe(2, 3);
+      assertTrue(dict1.equals(dict2)); // distinct vectors, now identical values
+    }
+  }
+
+  @Test
+  public void testVarcharEquals() {
+    try (final VarCharVector vector1 = new VarCharVector("", allocator);
+        final VarCharVector vector2 = new VarCharVector("", allocator)) {
+
+      Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null));
+      Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null));
+
+      vector1.allocateNew();
+      vector1.setValueCount(3);
+      vector2.allocateNew();
+      vector2.setValueCount(3);
+
+      // identical except index 2, where vector2 gets `one` instead of `two`
+      vector1.setSafe(0, zero, 0, zero.length);
+      vector1.setSafe(1, one, 0, one.length);
+      vector1.setSafe(2, two, 0, two.length);
+
+      vector2.setSafe(0, zero, 0, zero.length);
+      vector2.setSafe(1, one, 0, one.length);
+      vector2.setSafe(2, one, 0, one.length);
+
+      assertFalse(dict1.equals(dict2)); // contents differ at index 2
+
+      vector2.setSafe(2, two, 0, two.length);
+      assertTrue(dict1.equals(dict2)); // distinct vectors, now identical contents
+    }
+  }
+
+  @Test
+  public void testVarBinaryEquals() {
+    try (final VarBinaryVector vector1 = new VarBinaryVector("", allocator);
+        final VarBinaryVector vector2 = new VarBinaryVector("", allocator)) {
+
+      Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null));
+      Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null));
+
+      vector1.allocateNew();
+      vector1.setValueCount(3);
+      vector2.allocateNew();
+      vector2.setValueCount(3);
+
+      // identical except index 2, where vector2 gets `one` instead of `two`
+      vector1.setSafe(0, zero, 0, zero.length);
+      vector1.setSafe(1, one, 0, one.length);
+      vector1.setSafe(2, two, 0, two.length);
+
+      vector2.setSafe(0, zero, 0, zero.length);
+      vector2.setSafe(1, one, 0, one.length);
+      vector2.setSafe(2, one, 0, one.length);
+
+      assertFalse(dict1.equals(dict2)); // contents differ at index 2
+
+      vector2.setSafe(2, two, 0, two.length);
+      assertTrue(dict1.equals(dict2)); // distinct vectors, now identical contents
+    }
+  }
+
+  @Test
+  public void testListEquals() {
+    try (final ListVector vector1 = ListVector.empty("", allocator);
+        final ListVector vector2 = ListVector.empty("", allocator);) {
+
+      Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null));
+      Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null));
+
+      UnionListWriter writer1 = vector1.getWriter();
+      writer1.allocate();
+
+      // three list entries: [1, 2], [3, 4], [5, 6]
+      writeListVector(writer1, new int[] {1, 2});
+      writeListVector(writer1, new int[] {3, 4});
+      writeListVector(writer1, new int[] {5, 6});
+      writer1.setValueCount(3);
+
+      UnionListWriter writer2 = vector2.getWriter();
+      writer2.allocate();
+
+      // the same three list entries, written into a separate vector
+      writeListVector(writer2, new int[] {1, 2});
+      writeListVector(writer2, new int[] {3, 4});
+      writeListVector(writer2, new int[] {5, 6});
+      writer2.setValueCount(3);
+
+      assertTrue(dict1.equals(dict2)); // NOTE(review): only the equal case is exercised here
+    }
+  }
+
+  @Test
+  public void testStructEquals() {
+    try (final StructVector vector1 = StructVector.empty("", allocator);
+        final StructVector vector2 = StructVector.empty("", allocator);) {
+      vector1.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class);
+      vector1.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class);
+      vector2.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class);
+      vector2.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class);
+
+      Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null));
+      Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null));
+
+      NullableStructWriter writer1 = vector1.getWriter();
+      writer1.allocate();
+
+      writeStructVector(writer1, 1, 10L); // entry 0: f0=1, f1=10
+      writeStructVector(writer1, 2, 20L); // entry 1: f0=2, f1=20
+      writer1.setValueCount(2);
+
+      NullableStructWriter writer2 = vector2.getWriter();
+      writer2.allocate();
+
+      writeStructVector(writer2, 1, 10L); // same entries as vector1
+      writeStructVector(writer2, 2, 20L);
+      writer2.setValueCount(2);
+
+      assertTrue(dict1.equals(dict2)); // NOTE(review): only the equal case is exercised here
+    }
+  }
+
+  @Test
+  public void testUnionEquals() {
+    try (final UnionVector vector1 = new UnionVector("", allocator, null);
+        final UnionVector vector2 = new UnionVector("", allocator, null);) {
+
+      final NullableUInt4Holder uInt4Holder = new NullableUInt4Holder();
+      uInt4Holder.value = 10;
+      uInt4Holder.isSet = 1;
+
+      final NullableIntHolder intHolder = new NullableIntHolder();
+      intHolder.value = 20;
+      intHolder.isSet = 1;
+
+      vector1.setType(0, Types.MinorType.UINT4);
+      vector1.setSafe(0, uInt4Holder);
+
+      vector1.setType(2, Types.MinorType.INT);
+      vector1.setSafe(2, intHolder);
+      vector1.setValueCount(3); // index 1 is never given a type or value
+
+      vector2.setType(0, Types.MinorType.UINT4);
+      vector2.setSafe(0, uInt4Holder);
+
+      vector2.setType(2, Types.MinorType.INT);
+      vector2.setSafe(2, intHolder);
+      vector2.setValueCount(3); // same layout as vector1, including the gap at index 1
+
+      Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null));
+      Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null));
+
+      assertTrue(dict1.equals(dict2)); // UINT4 10 at index 0, INT 20 at index 2 in both
+    }
+  }
+
+  private void writeStructVector(NullableStructWriter writer, int value1, long value2) {
+    writer.start(); // open one struct entry
+    writer.integer("f0").writeInt(value1); // child "f0" <- value1 (int)
+    writer.bigInt("f1").writeBigInt(value2); // child "f1" <- value2 (long)
+    writer.end(); // close the entry
+  }
+
+  private void writeListVector(UnionListWriter writer, int[] values) {
+    writer.startList(); // open one list entry
+    for (int v: values) {
+      writer.integer().writeInt(v); // append each value as an int element of the entry
+    }
+    writer.endList(); // close the entry
+  }
 }