You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2020/06/13 13:36:19 UTC

[arrow] branch master updated: ARROW-8312: [Java][Gandiva] support TreeNode in IN expression

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new a2b2f1d  ARROW-8312: [Java][Gandiva] support TreeNode in IN expression
a2b2f1d is described below

commit a2b2f1dcefd13b268b50d2776b4f43a6c352b2df
Author: Yuan Zhou <yu...@intel.com>
AuthorDate: Sat Jun 13 08:35:48 2020 -0500

    ARROW-8312: [Java][Gandiva] support TreeNode in IN expression
    
    Signed-off-by: Yuan Zhou <yu...@intel.com>
    
    Closes #6806 from zhouyuan/wip_gandiva_in_nodeptr
    
    Authored-by: Yuan Zhou <yu...@intel.com>
    Signed-off-by: Wes McKinney <we...@apache.org>
---
 cpp/src/gandiva/jni/jni_common.cc                  |   2 +-
 cpp/src/gandiva/proto/Types.proto                  |   2 +-
 cpp/src/gandiva/tests/in_expr_test.cc              |   1 +
 .../apache/arrow/gandiva/expression/InNode.java    |  27 +++--
 .../arrow/gandiva/expression/TreeBuilder.java      |  16 +--
 .../apache/arrow/gandiva/evaluator/FilterTest.java | 120 +++++++++++++++++++++
 .../arrow/gandiva/evaluator/ProjectorTest.java     |  13 ++-
 7 files changed, 152 insertions(+), 29 deletions(-)

diff --git a/cpp/src/gandiva/jni/jni_common.cc b/cpp/src/gandiva/jni/jni_common.cc
index 453c0a4..e09daf6 100644
--- a/cpp/src/gandiva/jni/jni_common.cc
+++ b/cpp/src/gandiva/jni/jni_common.cc
@@ -350,7 +350,7 @@ NodePtr ProtoTypeToOrNode(const types::OrNode& node) {
 }
 
 NodePtr ProtoTypeToInNode(const types::InNode& node) {
-  NodePtr field = ProtoTypeToFieldNode(node.field());
+  NodePtr field = ProtoTypeToNode(node.node());
 
   if (node.has_intvalues()) {
     std::unordered_set<int32_t> int_values;
diff --git a/cpp/src/gandiva/proto/Types.proto b/cpp/src/gandiva/proto/Types.proto
index 02ba214..9020ccd 100644
--- a/cpp/src/gandiva/proto/Types.proto
+++ b/cpp/src/gandiva/proto/Types.proto
@@ -216,7 +216,7 @@ message FunctionSignature {
 }
 
 message InNode {
-  optional FieldNode field = 1;
+  optional TreeNode node = 1;
   optional IntConstants intValues = 2;
   optional LongConstants longValues = 3;
   optional StringConstants stringValues = 4;
diff --git a/cpp/src/gandiva/tests/in_expr_test.cc b/cpp/src/gandiva/tests/in_expr_test.cc
index 2103874..2ff91ae 100644
--- a/cpp/src/gandiva/tests/in_expr_test.cc
+++ b/cpp/src/gandiva/tests/in_expr_test.cc
@@ -85,6 +85,7 @@ TEST_F(TestIn, TestInString) {
   auto node_f0 = TreeExprBuilder::MakeField(field0);
   std::unordered_set<std::string> in_constants({"test", "me"});
   auto in_expr = TreeExprBuilder::MakeInExpressionString(node_f0, in_constants);
+
   auto condition = TreeExprBuilder::MakeCondition(in_expr);
 
   std::shared_ptr<Filter> filter;
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
index 2907fd5..007139e 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
@@ -22,7 +22,6 @@ import java.util.Set;
 
 import org.apache.arrow.gandiva.exceptions.GandivaException;
 import org.apache.arrow.gandiva.ipc.GandivaTypes;
-import org.apache.arrow.vector.types.pojo.Field;
 
 import com.google.protobuf.ByteString;
 
@@ -36,40 +35,38 @@ public class InNode implements TreeNode {
   private final Set<Long> longValues;
   private final Set<String> stringValues;
   private final Set<byte[]> binaryValues;
-  private final Field field;
+  private final TreeNode input;
 
   private InNode(Set<Integer> values, Set<Long> longValues, Set<String> stringValues, Set<byte[]>
-          binaryValues, Field field) {
+          binaryValues, TreeNode node) {
     this.intValues = values;
     this.longValues = longValues;
     this.stringValues = stringValues;
     this.binaryValues = binaryValues;
-    this.field = field;
+    this.input = node;
   }
 
-  public static InNode makeIntInExpr(Field field, Set<Integer> intValues) {
-    return new InNode(intValues, null, null, null, field);
+  public static InNode makeIntInExpr(TreeNode node, Set<Integer> intValues) {
+    return new InNode(intValues, null, null, null, node);
   }
 
-  public static InNode makeLongInExpr(Field field, Set<Long> longValues) {
-    return new InNode(null, longValues, null, null, field);
+  public static InNode makeLongInExpr(TreeNode node, Set<Long> longValues) {
+    return new InNode(null, longValues, null, null, node);
   }
 
-  public static InNode makeStringInExpr(Field field, Set<String> stringValues) {
-    return new InNode(null, null, stringValues, null, field);
+  public static InNode makeStringInExpr(TreeNode node, Set<String> stringValues) {
+    return new InNode(null, null, stringValues, null, node);
   }
 
-  public static InNode makeBinaryInExpr(Field field, Set<byte[]> binaryValues) {
-    return new InNode(null, null, null, binaryValues, field);
+  public static InNode makeBinaryInExpr(TreeNode node, Set<byte[]> binaryValues) {
+    return new InNode(null, null, null, binaryValues, node);
   }
 
   @Override
   public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
     GandivaTypes.InNode.Builder inNode = GandivaTypes.InNode.newBuilder();
 
-    GandivaTypes.FieldNode.Builder fieldNode = GandivaTypes.FieldNode.newBuilder();
-    fieldNode.setField(ArrowTypeHelper.arrowFieldToProtobuf(field));
-    inNode.setField(fieldNode);
+    inNode.setNode(input.toProtobuf());
 
     if (intValues != null) {
       GandivaTypes.IntConstants.Builder intConstants = GandivaTypes.IntConstants.newBuilder();
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
index c20795f..3803db7 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
@@ -192,23 +192,23 @@ public class TreeBuilder {
     return makeCondition(root);
   }
 
-  public static TreeNode makeInExpressionInt32(Field resultField,
+  public static TreeNode makeInExpressionInt32(TreeNode resultNode,
                                                Set<Integer> intValues) {
-    return InNode.makeIntInExpr(resultField, intValues);
+    return InNode.makeIntInExpr(resultNode, intValues);
   }
 
-  public static TreeNode makeInExpressionBigInt(Field resultField,
+  public static TreeNode makeInExpressionBigInt(TreeNode resultNode,
                                                Set<Long> longValues) {
-    return InNode.makeLongInExpr(resultField, longValues);
+    return InNode.makeLongInExpr(resultNode, longValues);
   }
 
-  public static TreeNode makeInExpressionString(Field resultField,
+  public static TreeNode makeInExpressionString(TreeNode resultNode,
                                                 Set<String> stringValues) {
-    return InNode.makeStringInExpr(resultField, stringValues);
+    return InNode.makeStringInExpr(resultNode, stringValues);
   }
 
-  public static TreeNode makeInExpressionBinary(Field resultField,
+  public static TreeNode makeInExpressionBinary(TreeNode resultNode,
                                                 Set<byte[]> binaryValues) {
-    return InNode.makeBinaryInExpr(resultField, binaryValues);
+    return InNode.makeBinaryInExpr(resultNode, binaryValues);
   }
 }
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java
index 3666f23..e7e7e9f 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java
@@ -17,21 +17,26 @@
 
 package org.apache.arrow.gandiva.evaluator;
 
+import java.nio.charset.Charset;
+import java.util.Arrays;
 import java.util.List;
 import java.util.stream.IntStream;
 
 import org.apache.arrow.gandiva.exceptions.GandivaException;
 import org.apache.arrow.gandiva.expression.Condition;
 import org.apache.arrow.gandiva.expression.TreeBuilder;
+import org.apache.arrow.gandiva.expression.TreeNode;
 import org.apache.arrow.memory.ArrowBuf;
 import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.junit.Assert;
 import org.junit.Test;
 
 import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 
 public class FilterTest extends BaseEvaluatorTest {
 
@@ -43,6 +48,121 @@ public class FilterTest extends BaseEvaluatorTest {
     return actual;
   }
 
+  private Charset utf8Charset = Charset.forName("UTF-8");
+  private Charset utf16Charset = Charset.forName("UTF-16");
+
+  List<ArrowBuf> varBufs(String[] strings, Charset charset) {
+    ArrowBuf offsetsBuffer = allocator.buffer((strings.length + 1) * 4);
+    ArrowBuf dataBuffer = allocator.buffer(strings.length * 8);
+
+    int startOffset = 0;
+    for (int i = 0; i < strings.length; i++) {
+      offsetsBuffer.writeInt(startOffset);
+
+      final byte[] bytes = strings[i].getBytes(charset);
+      dataBuffer = dataBuffer.reallocIfNeeded(dataBuffer.writerIndex() + bytes.length);
+      dataBuffer.setBytes(startOffset, bytes, 0, bytes.length);
+      startOffset += bytes.length;
+    }
+    offsetsBuffer.writeInt(startOffset); // offset for the last element
+
+    return Arrays.asList(offsetsBuffer, dataBuffer);
+  }
+
+  List<ArrowBuf> stringBufs(String[] strings) {
+    return varBufs(strings, utf8Charset);
+  }
+
+  @Test
+  public void testSimpleInString() throws GandivaException, Exception {
+    Field c1 = Field.nullable("c1", new ArrowType.Utf8());
+    TreeNode l1 = TreeBuilder.makeLiteral(1L);
+    TreeNode l2 = TreeBuilder.makeLiteral(3L);
+
+    List<Field> argsSchema = Lists.newArrayList(c1);
+    List<TreeNode> args = Lists.newArrayList(TreeBuilder.makeField(c1), l1, l2);
+    TreeNode substr = TreeBuilder.makeFunction("substr", args, new ArrowType.Utf8());
+    TreeNode inExpr =
+            TreeBuilder.makeInExpressionString(substr, Sets.newHashSet("one", "two", "thr", "fou"));
+
+    Condition condition = TreeBuilder.makeCondition(inExpr);
+
+    Schema schema = new Schema(argsSchema);
+    Filter filter = Filter.make(schema, condition);
+
+    int numRows = 16;
+    byte[] validity = new byte[] {(byte) 255, 0};
+    // second half is "undefined"
+    String[] c1Values = new String[]{"one", "two", "three", "four", "five", "six", "seven",
+      "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
+      "sixteen"};
+    int[] expected = {0, 1, 2, 3};
+    ArrowBuf c1Validity = buf(validity);
+    ArrowBuf c2Validity = buf(validity);
+    List<ArrowBuf> dataBufsX = stringBufs(c1Values);
+
+    ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+    ArrowRecordBatch batch =
+            new ArrowRecordBatch(
+                    numRows,
+                    Lists.newArrayList(fieldNode),
+                    Lists.newArrayList(c1Validity, dataBufsX.get(0), dataBufsX.get(1), c2Validity));
+
+    ArrowBuf selectionBuffer = buf(numRows * 2);
+    SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer);
+
+    filter.evaluate(batch, selectionVector);
+
+    int[] actual = selectionVectorToArray(selectionVector);
+    releaseRecordBatch(batch);
+    selectionBuffer.close();
+    filter.close();
+    Assert.assertArrayEquals(expected, actual);
+  }
+
+  @Test
+  public void testSimpleInInt() throws GandivaException, Exception {
+    Field c1 = Field.nullable("c1", int32);
+
+    List<Field> argsSchema = Lists.newArrayList(c1);
+    TreeNode inExpr =
+            TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4));
+
+    Condition condition = TreeBuilder.makeCondition(inExpr);
+
+    Schema schema = new Schema(argsSchema);
+    Filter filter = Filter.make(schema, condition);
+
+    int numRows = 16;
+    byte[] validity = new byte[] {(byte) 255, 0};
+    // second half is "undefined"
+    int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+    int[] expected = {0, 1, 2, 3};
+
+    ArrowBuf validitya = buf(validity);
+    ArrowBuf validityb = buf(validity);
+    ArrowBuf valuesa = intBuf(aValues);
+
+    ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+    ArrowRecordBatch batch =
+            new ArrowRecordBatch(
+                    numRows,
+                    Lists.newArrayList(fieldNode),
+                    Lists.newArrayList(validitya, valuesa, validityb));
+
+    ArrowBuf selectionBuffer = buf(numRows * 2);
+    SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer);
+
+    filter.evaluate(batch, selectionVector);
+
+    // free buffers
+    int[] actual = selectionVectorToArray(selectionVector);
+    releaseRecordBatch(batch);
+    selectionBuffer.close();
+    filter.close();
+    Assert.assertArrayEquals(expected, actual);
+  }
+
   @Test
   public void testSimpleSV16() throws GandivaException, Exception {
     Field a = Field.nullable("a", int32);
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
index 3e47647..753cdf6 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
@@ -1183,7 +1183,7 @@ public class ProjectorTest extends BaseEvaluatorTest {
     Field c1 = Field.nullable("c1", int32);
 
     TreeNode inExpr =
-            TreeBuilder.makeInExpressionInt32(c1, Sets.newHashSet(1, 2, 3, 4, 5, 15, 16));
+            TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4, 5, 15, 16));
     ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
     Schema schema = new Schema(Lists.newArrayList(c1));
     Projector eval = Projector.make(schema, Lists.newArrayList(expr));
@@ -1210,7 +1210,7 @@ public class ProjectorTest extends BaseEvaluatorTest {
     output.add(bitVector);
     eval.evaluate(batch, output);
 
-    for (int i = 0; i < 5; i++) {
+    for (int i = 1; i < 5; i++) {
       assertTrue(bitVector.getObject(i).booleanValue());
     }
     for (int i = 5; i < 16; i++) {
@@ -1226,8 +1226,12 @@ public class ProjectorTest extends BaseEvaluatorTest {
   public void testInExprStrings() throws GandivaException, Exception {
     Field c1 = Field.nullable("c1", new ArrowType.Utf8());
 
+    TreeNode l1 = TreeBuilder.makeLiteral(1L);
+    TreeNode l2 = TreeBuilder.makeLiteral(3L);
+    List<TreeNode> args = Lists.newArrayList(TreeBuilder.makeField(c1), l1, l2);
+    TreeNode substr = TreeBuilder.makeFunction("substr", args, new ArrowType.Utf8());
     TreeNode inExpr =
-            TreeBuilder.makeInExpressionString(c1, Sets.newHashSet("one", "two", "three", "four"));
+            TreeBuilder.makeInExpressionString(substr, Sets.newHashSet("one", "two", "thr", "fou"));
     ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
     Schema schema = new Schema(Lists.newArrayList(c1));
     Projector eval = Projector.make(schema, Lists.newArrayList(expr));
@@ -1291,12 +1295,13 @@ public class ProjectorTest extends BaseEvaluatorTest {
     ArrowBuf aValidity = buf(validity);
     ArrowBuf aData = intBuf(aValues);
     ArrowBuf bValidity = buf(validity);
+    ArrowBuf b2Validity = buf(validity);
     ArrowBuf bData = intBuf(bValues);
     ArrowRecordBatch batch =
         new ArrowRecordBatch(
             numRows,
             Lists.newArrayList(new ArrowFieldNode(numRows, 8), new ArrowFieldNode(numRows, 8)),
-            Lists.newArrayList(aValidity, aData, bValidity, bData));
+            Lists.newArrayList(aValidity, aData, bValidity, bData, b2Validity));
 
     IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator);