You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/08/09 01:28:15 UTC

[arrow] branch master updated: ARROW-6160: [Java] AbstractStructVector#getPrimitiveVectors fails to work with complex child vectors

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 062cc70  ARROW-6160: [Java] AbstractStructVector#getPrimitiveVectors fails to work with complex child vectors
062cc70 is described below

commit 062cc7013947cf255ad3860e5ff51d64c9d43584
Author: tianchen <ni...@alibaba-inc.com>
AuthorDate: Thu Aug 8 20:28:03 2019 -0500

    ARROW-6160: [Java] AbstractStructVector#getPrimitiveVectors fails to work with complex child vectors
    
    Related to [ARROW-6160](https://issues.apache.org/jira/browse/ARROW-6160).
    
    Currently, AbstractStructVector#getPrimitiveVectors recurses only into struct-type child vectors; other complex types such as ListVector and UnionVector are treated as primitive and returned directly.
    
    For example, for Struct(List(Int), Struct(Int, Varchar)), getPrimitiveVectors should return [IntVector, IntVector, VarCharVector] instead of [ListVector, IntVector, VarCharVector].
    
    Closes #5031 from tianchen92/ARROW-6160 and squashes the following commits:
    
    182630ee2 <tianchen> fix
    aef8c4a31 <tianchen> ARROW-XXXX:  AbstractStructVector#getPrimitiveVectors fails to work with complex child vectors
    
    Authored-by: tianchen <ni...@alibaba-inc.com>
    Signed-off-by: Wes McKinney <we...@apache.org>
---
 .../arrow/vector/complex/AbstractStructVector.java | 31 ++++++++++++++-----
 .../org/apache/arrow/vector/TestStructVector.java  | 36 +++++++++++++++++++---
 2 files changed, 56 insertions(+), 11 deletions(-)

diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java
index ba837a2..dc9b1a1 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java
@@ -252,17 +252,34 @@ public abstract class AbstractStructVector extends AbstractContainerVector {
    */
   public List<ValueVector> getPrimitiveVectors() {
     final List<ValueVector> primitiveVectors = new ArrayList<>();
-    for (final ValueVector v : vectors.values()) {
-      if (v instanceof AbstractStructVector) {
-        AbstractStructVector structVector = (AbstractStructVector) v;
-        primitiveVectors.addAll(structVector.getPrimitiveVectors());
-      } else {
-        primitiveVectors.add(v);
-      }
+    for (final FieldVector v : vectors.values()) {
+      primitiveVectors.addAll(getPrimitiveVectors(v));
     }
     return primitiveVectors;
   }
 
+  /**
+   * Recursively collects the primitive (leaf) vectors under {@code v}; a non-complex vector yields itself.
+   */
+  private List<ValueVector> getPrimitiveVectors(FieldVector v) {
+    final List<ValueVector> primitives = new ArrayList<>();
+    if (v instanceof AbstractStructVector) {
+      primitives.addAll(((AbstractStructVector) v).getPrimitiveVectors());
+    } else if (v instanceof ListVector) {
+      primitives.addAll(getPrimitiveVectors(((ListVector) v).getDataVector()));
+    } else if (v instanceof FixedSizeListVector) {
+      // Cast to the tested type; a (ListVector) cast here would throw ClassCastException.
+      primitives.addAll(getPrimitiveVectors(((FixedSizeListVector) v).getDataVector()));
+    } else if (v instanceof UnionVector) {
+      for (final FieldVector child : ((UnionVector) v).getChildrenFromFields()) {
+        primitives.addAll(getPrimitiveVectors(child));
+      }
+    } else {
+      primitives.add(v);
+    }
+    return primitives;
+  }
+
   /**
    * Get a child vector by name.
    * @param name the name of the child to return
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java
index 9d156eb..272b8ba 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/TestStructVector.java
@@ -17,16 +17,16 @@
 
 package org.apache.arrow.vector;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertNull;
+import static org.junit.Assert.*;
 
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.UnionVector;
 import org.apache.arrow.vector.holders.ComplexHolder;
 import org.apache.arrow.vector.types.Types.MinorType;
 import org.apache.arrow.vector.types.pojo.ArrowType.Struct;
@@ -132,4 +132,32 @@ public class TestStructVector {
       assertNull(holder.reader);
     }
   }
+
+  @Test
+  public void testGetPrimitiveVectors() {
+    FieldType type = new FieldType(true, Struct.INSTANCE, null, null);
+    try (StructVector vector = new StructVector("struct", allocator, type, null)) {
+      // Checks that getPrimitiveVectors() reports the leaf vectors of list and union children.
+      // add list child whose data vector is INT (expected at index 0)
+      vector.addOrGet("list", FieldType.nullable(MinorType.LIST.getType()), ListVector.class);
+      ListVector listVector = vector.addOrGetList("list");
+      listVector.addOrGetVector(FieldType.nullable(MinorType.INT.getType()));
+
+      // add union child with BIGINT and SMALLINT members (expected at indices 1 and 2)
+      vector.addOrGet("union", FieldType.nullable(MinorType.UNION.getType()), UnionVector.class);
+      UnionVector unionVector = vector.addOrGetUnion("union");
+      unionVector.addVector(new BigIntVector("bigInt", allocator));
+      unionVector.addVector(new SmallIntVector("smallInt", allocator));
+
+      // add varchar child, a direct child of the struct (expected at index 3)
+      vector.addOrGet("varchar", FieldType.nullable(MinorType.VARCHAR.getType()), VarCharVector.class);
+
+      List<ValueVector> primitiveVectors = vector.getPrimitiveVectors();
+      assertEquals(4, primitiveVectors.size());
+      assertEquals(MinorType.INT, primitiveVectors.get(0).getMinorType());
+      assertEquals(MinorType.BIGINT, primitiveVectors.get(1).getMinorType());
+      assertEquals(MinorType.SMALLINT, primitiveVectors.get(2).getMinorType());
+      assertEquals(MinorType.VARCHAR, primitiveVectors.get(3).getMinorType());
+    }
+  }
 }