You are viewing a plain-text version of this content. The canonical link ("here") was not preserved in this plain-text rendering.
Posted to commits@arrow.apache.org by we...@apache.org on 2020/06/13 13:36:19 UTC
[arrow] branch master updated: ARROW-8312: [Java][Gandiva] support TreeNode in IN expression
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new a2b2f1d ARROW-8312: [Java][Gandiva] support TreeNode in IN expression
a2b2f1d is described below
commit a2b2f1dcefd13b268b50d2776b4f43a6c352b2df
Author: Yuan Zhou <yu...@intel.com>
AuthorDate: Sat Jun 13 08:35:48 2020 -0500
ARROW-8312: [Java][Gandiva] support TreeNode in IN expression
Signed-off-by: Yuan Zhou <yu...@intel.com>
Closes #6806 from zhouyuan/wip_gandiva_in_nodeptr
Authored-by: Yuan Zhou <yu...@intel.com>
Signed-off-by: Wes McKinney <we...@apache.org>
---
cpp/src/gandiva/jni/jni_common.cc | 2 +-
cpp/src/gandiva/proto/Types.proto | 2 +-
cpp/src/gandiva/tests/in_expr_test.cc | 1 +
.../apache/arrow/gandiva/expression/InNode.java | 27 +++--
.../arrow/gandiva/expression/TreeBuilder.java | 16 +--
.../apache/arrow/gandiva/evaluator/FilterTest.java | 120 +++++++++++++++++++++
.../arrow/gandiva/evaluator/ProjectorTest.java | 13 ++-
7 files changed, 152 insertions(+), 29 deletions(-)
diff --git a/cpp/src/gandiva/jni/jni_common.cc b/cpp/src/gandiva/jni/jni_common.cc
index 453c0a4..e09daf6 100644
--- a/cpp/src/gandiva/jni/jni_common.cc
+++ b/cpp/src/gandiva/jni/jni_common.cc
@@ -350,7 +350,7 @@ NodePtr ProtoTypeToOrNode(const types::OrNode& node) {
}
NodePtr ProtoTypeToInNode(const types::InNode& node) {
- NodePtr field = ProtoTypeToFieldNode(node.field());
+ NodePtr field = ProtoTypeToNode(node.node());
if (node.has_intvalues()) {
std::unordered_set<int32_t> int_values;
diff --git a/cpp/src/gandiva/proto/Types.proto b/cpp/src/gandiva/proto/Types.proto
index 02ba214..9020ccd 100644
--- a/cpp/src/gandiva/proto/Types.proto
+++ b/cpp/src/gandiva/proto/Types.proto
@@ -216,7 +216,7 @@ message FunctionSignature {
}
message InNode {
- optional FieldNode field = 1;
+ optional TreeNode node = 1;
optional IntConstants intValues = 2;
optional LongConstants longValues = 3;
optional StringConstants stringValues = 4;
diff --git a/cpp/src/gandiva/tests/in_expr_test.cc b/cpp/src/gandiva/tests/in_expr_test.cc
index 2103874..2ff91ae 100644
--- a/cpp/src/gandiva/tests/in_expr_test.cc
+++ b/cpp/src/gandiva/tests/in_expr_test.cc
@@ -85,6 +85,7 @@ TEST_F(TestIn, TestInString) {
auto node_f0 = TreeExprBuilder::MakeField(field0);
std::unordered_set<std::string> in_constants({"test", "me"});
auto in_expr = TreeExprBuilder::MakeInExpressionString(node_f0, in_constants);
+
auto condition = TreeExprBuilder::MakeCondition(in_expr);
std::shared_ptr<Filter> filter;
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
index 2907fd5..007139e 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/InNode.java
@@ -22,7 +22,6 @@ import java.util.Set;
import org.apache.arrow.gandiva.exceptions.GandivaException;
import org.apache.arrow.gandiva.ipc.GandivaTypes;
-import org.apache.arrow.vector.types.pojo.Field;
import com.google.protobuf.ByteString;
@@ -36,40 +35,38 @@ public class InNode implements TreeNode {
private final Set<Long> longValues;
private final Set<String> stringValues;
private final Set<byte[]> binaryValues;
- private final Field field;
+ private final TreeNode input;
private InNode(Set<Integer> values, Set<Long> longValues, Set<String> stringValues, Set<byte[]>
- binaryValues, Field field) {
+ binaryValues, TreeNode node) {
this.intValues = values;
this.longValues = longValues;
this.stringValues = stringValues;
this.binaryValues = binaryValues;
- this.field = field;
+ this.input = node;
}
- public static InNode makeIntInExpr(Field field, Set<Integer> intValues) {
- return new InNode(intValues, null, null, null, field);
+ public static InNode makeIntInExpr(TreeNode node, Set<Integer> intValues) {
+ return new InNode(intValues, null, null, null, node);
}
- public static InNode makeLongInExpr(Field field, Set<Long> longValues) {
- return new InNode(null, longValues, null, null, field);
+ public static InNode makeLongInExpr(TreeNode node, Set<Long> longValues) {
+ return new InNode(null, longValues, null, null, node);
}
- public static InNode makeStringInExpr(Field field, Set<String> stringValues) {
- return new InNode(null, null, stringValues, null, field);
+ public static InNode makeStringInExpr(TreeNode node, Set<String> stringValues) {
+ return new InNode(null, null, stringValues, null, node);
}
- public static InNode makeBinaryInExpr(Field field, Set<byte[]> binaryValues) {
- return new InNode(null, null, null, binaryValues, field);
+ public static InNode makeBinaryInExpr(TreeNode node, Set<byte[]> binaryValues) {
+ return new InNode(null, null, null, binaryValues, node);
}
@Override
public GandivaTypes.TreeNode toProtobuf() throws GandivaException {
GandivaTypes.InNode.Builder inNode = GandivaTypes.InNode.newBuilder();
- GandivaTypes.FieldNode.Builder fieldNode = GandivaTypes.FieldNode.newBuilder();
- fieldNode.setField(ArrowTypeHelper.arrowFieldToProtobuf(field));
- inNode.setField(fieldNode);
+ inNode.setNode(input.toProtobuf());
if (intValues != null) {
GandivaTypes.IntConstants.Builder intConstants = GandivaTypes.IntConstants.newBuilder();
diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
index c20795f..3803db7 100644
--- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
+++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java
@@ -192,23 +192,23 @@ public class TreeBuilder {
return makeCondition(root);
}
- public static TreeNode makeInExpressionInt32(Field resultField,
+ public static TreeNode makeInExpressionInt32(TreeNode resultNode,
Set<Integer> intValues) {
- return InNode.makeIntInExpr(resultField, intValues);
+ return InNode.makeIntInExpr(resultNode, intValues);
}
- public static TreeNode makeInExpressionBigInt(Field resultField,
+ public static TreeNode makeInExpressionBigInt(TreeNode resultNode,
Set<Long> longValues) {
- return InNode.makeLongInExpr(resultField, longValues);
+ return InNode.makeLongInExpr(resultNode, longValues);
}
- public static TreeNode makeInExpressionString(Field resultField,
+ public static TreeNode makeInExpressionString(TreeNode resultNode,
Set<String> stringValues) {
- return InNode.makeStringInExpr(resultField, stringValues);
+ return InNode.makeStringInExpr(resultNode, stringValues);
}
- public static TreeNode makeInExpressionBinary(Field resultField,
+ public static TreeNode makeInExpressionBinary(TreeNode resultNode,
Set<byte[]> binaryValues) {
- return InNode.makeBinaryInExpr(resultField, binaryValues);
+ return InNode.makeBinaryInExpr(resultNode, binaryValues);
}
}
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java
index 3666f23..e7e7e9f 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/FilterTest.java
@@ -17,21 +17,26 @@
package org.apache.arrow.gandiva.evaluator;
+import java.nio.charset.Charset;
+import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import org.apache.arrow.gandiva.exceptions.GandivaException;
import org.apache.arrow.gandiva.expression.Condition;
import org.apache.arrow.gandiva.expression.TreeBuilder;
+import org.apache.arrow.gandiva.expression.TreeNode;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.Assert;
import org.junit.Test;
import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
public class FilterTest extends BaseEvaluatorTest {
@@ -43,6 +48,121 @@ public class FilterTest extends BaseEvaluatorTest {
return actual;
}
+ private Charset utf8Charset = Charset.forName("UTF-8");
+ private Charset utf16Charset = Charset.forName("UTF-16");
+
+ List<ArrowBuf> varBufs(String[] strings, Charset charset) {
+ ArrowBuf offsetsBuffer = allocator.buffer((strings.length + 1) * 4);
+ ArrowBuf dataBuffer = allocator.buffer(strings.length * 8);
+
+ int startOffset = 0;
+ for (int i = 0; i < strings.length; i++) {
+ offsetsBuffer.writeInt(startOffset);
+
+ final byte[] bytes = strings[i].getBytes(charset);
+ dataBuffer = dataBuffer.reallocIfNeeded(dataBuffer.writerIndex() + bytes.length);
+ dataBuffer.setBytes(startOffset, bytes, 0, bytes.length);
+ startOffset += bytes.length;
+ }
+ offsetsBuffer.writeInt(startOffset); // offset for the last element
+
+ return Arrays.asList(offsetsBuffer, dataBuffer);
+ }
+
+ List<ArrowBuf> stringBufs(String[] strings) {
+ return varBufs(strings, utf8Charset);
+ }
+
+ @Test
+ public void testSimpleInString() throws GandivaException, Exception {
+ Field c1 = Field.nullable("c1", new ArrowType.Utf8());
+ TreeNode l1 = TreeBuilder.makeLiteral(1L);
+ TreeNode l2 = TreeBuilder.makeLiteral(3L);
+
+ List<Field> argsSchema = Lists.newArrayList(c1);
+ List<TreeNode> args = Lists.newArrayList(TreeBuilder.makeField(c1), l1, l2);
+ TreeNode substr = TreeBuilder.makeFunction("substr", args, new ArrowType.Utf8());
+ TreeNode inExpr =
+ TreeBuilder.makeInExpressionString(substr, Sets.newHashSet("one", "two", "thr", "fou"));
+
+ Condition condition = TreeBuilder.makeCondition(inExpr);
+
+ Schema schema = new Schema(argsSchema);
+ Filter filter = Filter.make(schema, condition);
+
+ int numRows = 16;
+ byte[] validity = new byte[] {(byte) 255, 0};
+ // second half is "undefined"
+ String[] c1Values = new String[]{"one", "two", "three", "four", "five", "six", "seven",
+ "eight", "nine", "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
+ "sixteen"};
+ int[] expected = {0, 1, 2, 3};
+ ArrowBuf c1Validity = buf(validity);
+ ArrowBuf c2Validity = buf(validity);
+ List<ArrowBuf> dataBufsX = stringBufs(c1Values);
+
+ ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+ ArrowRecordBatch batch =
+ new ArrowRecordBatch(
+ numRows,
+ Lists.newArrayList(fieldNode),
+ Lists.newArrayList(c1Validity, dataBufsX.get(0), dataBufsX.get(1), c2Validity));
+
+ ArrowBuf selectionBuffer = buf(numRows * 2);
+ SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer);
+
+ filter.evaluate(batch, selectionVector);
+
+ int[] actual = selectionVectorToArray(selectionVector);
+ releaseRecordBatch(batch);
+ selectionBuffer.close();
+ filter.close();
+ Assert.assertArrayEquals(expected, actual);
+ }
+
+ @Test
+ public void testSimpleInInt() throws GandivaException, Exception {
+ Field c1 = Field.nullable("c1", int32);
+
+ List<Field> argsSchema = Lists.newArrayList(c1);
+ TreeNode inExpr =
+ TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4));
+
+ Condition condition = TreeBuilder.makeCondition(inExpr);
+
+ Schema schema = new Schema(argsSchema);
+ Filter filter = Filter.make(schema, condition);
+
+ int numRows = 16;
+ byte[] validity = new byte[] {(byte) 255, 0};
+ // second half is "undefined"
+ int[] aValues = new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
+ int[] expected = {0, 1, 2, 3};
+
+ ArrowBuf validitya = buf(validity);
+ ArrowBuf validityb = buf(validity);
+ ArrowBuf valuesa = intBuf(aValues);
+
+ ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
+ ArrowRecordBatch batch =
+ new ArrowRecordBatch(
+ numRows,
+ Lists.newArrayList(fieldNode),
+ Lists.newArrayList(validitya, valuesa, validityb));
+
+ ArrowBuf selectionBuffer = buf(numRows * 2);
+ SelectionVectorInt16 selectionVector = new SelectionVectorInt16(selectionBuffer);
+
+ filter.evaluate(batch, selectionVector);
+
+ // free buffers
+ int[] actual = selectionVectorToArray(selectionVector);
+ releaseRecordBatch(batch);
+ selectionBuffer.close();
+ filter.close();
+ Assert.assertArrayEquals(expected, actual);
+ }
+
@Test
public void testSimpleSV16() throws GandivaException, Exception {
Field a = Field.nullable("a", int32);
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
index 3e47647..753cdf6 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java
@@ -1183,7 +1183,7 @@ public class ProjectorTest extends BaseEvaluatorTest {
Field c1 = Field.nullable("c1", int32);
TreeNode inExpr =
- TreeBuilder.makeInExpressionInt32(c1, Sets.newHashSet(1, 2, 3, 4, 5, 15, 16));
+ TreeBuilder.makeInExpressionInt32(TreeBuilder.makeField(c1), Sets.newHashSet(1, 2, 3, 4, 5, 15, 16));
ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
Schema schema = new Schema(Lists.newArrayList(c1));
Projector eval = Projector.make(schema, Lists.newArrayList(expr));
@@ -1210,7 +1210,7 @@ public class ProjectorTest extends BaseEvaluatorTest {
output.add(bitVector);
eval.evaluate(batch, output);
- for (int i = 0; i < 5; i++) {
+ for (int i = 1; i < 5; i++) {
assertTrue(bitVector.getObject(i).booleanValue());
}
for (int i = 5; i < 16; i++) {
@@ -1226,8 +1226,12 @@ public class ProjectorTest extends BaseEvaluatorTest {
public void testInExprStrings() throws GandivaException, Exception {
Field c1 = Field.nullable("c1", new ArrowType.Utf8());
+ TreeNode l1 = TreeBuilder.makeLiteral(1L);
+ TreeNode l2 = TreeBuilder.makeLiteral(3L);
+ List<TreeNode> args = Lists.newArrayList(TreeBuilder.makeField(c1), l1, l2);
+ TreeNode substr = TreeBuilder.makeFunction("substr", args, new ArrowType.Utf8());
TreeNode inExpr =
- TreeBuilder.makeInExpressionString(c1, Sets.newHashSet("one", "two", "three", "four"));
+ TreeBuilder.makeInExpressionString(substr, Sets.newHashSet("one", "two", "thr", "fou"));
ExpressionTree expr = TreeBuilder.makeExpression(inExpr, Field.nullable("result", boolType));
Schema schema = new Schema(Lists.newArrayList(c1));
Projector eval = Projector.make(schema, Lists.newArrayList(expr));
@@ -1291,12 +1295,13 @@ public class ProjectorTest extends BaseEvaluatorTest {
ArrowBuf aValidity = buf(validity);
ArrowBuf aData = intBuf(aValues);
ArrowBuf bValidity = buf(validity);
+ ArrowBuf b2Validity = buf(validity);
ArrowBuf bData = intBuf(bValues);
ArrowRecordBatch batch =
new ArrowRecordBatch(
numRows,
Lists.newArrayList(new ArrowFieldNode(numRows, 8), new ArrowFieldNode(numRows, 8)),
- Lists.newArrayList(aValidity, aData, bValidity, bData));
+ Lists.newArrayList(aValidity, aData, bValidity, bData, b2Validity));
IntVector intVector = new IntVector(EMPTY_SCHEMA_PATH, allocator);