You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@arrow.apache.org by ks...@apache.org on 2018/11/09 19:35:23 UTC
[arrow] branch master updated: ARROW-3721: [Gandiva] [Python] Support all Gandiva literals
This is an automated email from the ASF dual-hosted git repository.
kszucs pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 0a2ce9d ARROW-3721: [Gandiva] [Python] Support all Gandiva literals
0a2ce9d is described below
commit 0a2ce9d67b8dccc25dc41aeb190deadfbd5d3167
Author: Philipp Moritz <pc...@gmail.com>
AuthorDate: Fri Nov 9 20:35:12 2018 +0100
ARROW-3721: [Gandiva] [Python] Support all Gandiva literals
Author: Philipp Moritz <pc...@gmail.com>
Author: Krisztián Szűcs <sz...@gmail.com>
Closes #2920 from pcmoritz/gandiva-python-literals and squashes the following commits:
b5a5c7c3 <Krisztián Szűcs> fix omitted assertion
6789cd73 <Krisztián Szűcs> don't expose _as_type in lib.pxd
f674c988 <Philipp Moritz> add tests
c951efc1 <Philipp Moritz> fix string handling
e35eb4e5 <Philipp Moritz> lint
681e201a <Philipp Moritz> support more literals in gandiva cython wrapper
---
python/pyarrow/gandiva.pyx | 54 ++++++++++++++++++++++++----
python/pyarrow/includes/libgandiva.pxd | 38 +++++++++++++++++++-
python/pyarrow/tests/test_gandiva.py | 66 +++++++++++++++++++++++++++++++++-
python/pyarrow/types.pxi | 21 ++++++-----
4 files changed, 160 insertions(+), 19 deletions(-)
diff --git a/python/pyarrow/gandiva.pyx b/python/pyarrow/gandiva.pyx
index 7bc462f..7a6c09e 100644
--- a/python/pyarrow/gandiva.pyx
+++ b/python/pyarrow/gandiva.pyx
@@ -27,14 +27,28 @@ from libc.stdint cimport int64_t, uint8_t, uintptr_t
from pyarrow.includes.libarrow cimport *
from pyarrow.compat import frombytes
-from pyarrow.lib cimport check_status, pyarrow_wrap_array
+from pyarrow.types import _as_type
+from pyarrow.lib cimport (Array, DataType, Field, MemoryPool, RecordBatch,
+ Schema, check_status, pyarrow_wrap_array)
from pyarrow.includes.libgandiva cimport (CCondition, CExpression,
CNode, CProjector, CFilter,
CSelectionVector,
TreeExprBuilder_MakeExpression,
TreeExprBuilder_MakeFunction,
- TreeExprBuilder_MakeLiteral,
+ TreeExprBuilder_MakeBoolLiteral,
+ TreeExprBuilder_MakeUInt8Literal,
+ TreeExprBuilder_MakeUInt16Literal,
+ TreeExprBuilder_MakeUInt32Literal,
+ TreeExprBuilder_MakeUInt64Literal,
+ TreeExprBuilder_MakeInt8Literal,
+ TreeExprBuilder_MakeInt16Literal,
+ TreeExprBuilder_MakeInt32Literal,
+ TreeExprBuilder_MakeInt64Literal,
+ TreeExprBuilder_MakeFloatLiteral,
+ TreeExprBuilder_MakeDoubleLiteral,
+ TreeExprBuilder_MakeStringLiteral,
+ TreeExprBuilder_MakeBinaryLiteral,
TreeExprBuilder_MakeField,
TreeExprBuilder_MakeIf,
TreeExprBuilder_MakeCondition,
@@ -42,8 +56,6 @@ from pyarrow.includes.libgandiva cimport (CCondition, CExpression,
Projector_Make,
Filter_Make)
-from pyarrow.lib cimport (Array, DataType, Field, MemoryPool,
- RecordBatch, Schema)
cdef class Node:
cdef:
@@ -150,10 +162,40 @@ cdef class Filter:
batch.sp_batch.get()[0], selection))
return SelectionVector.create(selection)
+
cdef class TreeExprBuilder:
- def make_literal(self, value):
- cdef shared_ptr[CNode] r = TreeExprBuilder_MakeLiteral(value)
+ def make_literal(self, value, dtype):
+ cdef shared_ptr[CNode] r
+ cdef DataType type = _as_type(dtype)
+ if type.id == _Type_BOOL:
+ r = TreeExprBuilder_MakeBoolLiteral(value)
+ elif type.id == _Type_UINT8:
+ r = TreeExprBuilder_MakeUInt8Literal(value)
+ elif type.id == _Type_UINT16:
+ r = TreeExprBuilder_MakeUInt16Literal(value)
+ elif type.id == _Type_UINT32:
+ r = TreeExprBuilder_MakeUInt32Literal(value)
+ elif type.id == _Type_UINT64:
+ r = TreeExprBuilder_MakeUInt64Literal(value)
+ elif type.id == _Type_INT8:
+ r = TreeExprBuilder_MakeInt8Literal(value)
+ elif type.id == _Type_INT16:
+ r = TreeExprBuilder_MakeInt16Literal(value)
+ elif type.id == _Type_INT32:
+ r = TreeExprBuilder_MakeInt32Literal(value)
+ elif type.id == _Type_INT64:
+ r = TreeExprBuilder_MakeInt64Literal(value)
+ elif type.id == _Type_FLOAT:
+ r = TreeExprBuilder_MakeFloatLiteral(value)
+ elif type.id == _Type_DOUBLE:
+ r = TreeExprBuilder_MakeDoubleLiteral(value)
+ elif type.id == _Type_STRING:
+ r = TreeExprBuilder_MakeStringLiteral(value.encode('UTF-8'))
+ elif type.id == _Type_BINARY:
+ r = TreeExprBuilder_MakeBinaryLiteral(value)
+ else:
+ raise TypeError("Didn't recognize dtype " + str(dtype))
return Node.create(r)
def make_expression(self, Node root_node, Field return_field):
diff --git a/python/pyarrow/includes/libgandiva.pxd b/python/pyarrow/includes/libgandiva.pxd
index b1e45af..f8106bc 100644
--- a/python/pyarrow/includes/libgandiva.pxd
+++ b/python/pyarrow/includes/libgandiva.pxd
@@ -56,9 +56,45 @@ cdef extern from "gandiva/arrow.h" namespace "gandiva" nogil:
cdef extern from "gandiva/tree_expr_builder.h" namespace "gandiva" nogil:
- cdef shared_ptr[CNode] TreeExprBuilder_MakeLiteral \
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeBoolLiteral \
+ "gandiva::TreeExprBuilder::MakeLiteral"(c_bool value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt8Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(uint8_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt16Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(uint16_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt32Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(uint32_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeUInt64Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(uint64_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInt8Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(int8_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInt16Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(int16_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInt32Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(int32_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeInt64Literal \
+ "gandiva::TreeExprBuilder::MakeLiteral"(int64_t value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeFloatLiteral \
+ "gandiva::TreeExprBuilder::MakeLiteral"(float value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeDoubleLiteral \
"gandiva::TreeExprBuilder::MakeLiteral"(double value)
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeStringLiteral \
+ "gandiva::TreeExprBuilder::MakeStringLiteral"(const c_string& value)
+
+ cdef shared_ptr[CNode] TreeExprBuilder_MakeBinaryLiteral \
+ "gandiva::TreeExprBuilder::MakeBinaryLiteral"(const c_string& value)
+
cdef shared_ptr[CExpression] TreeExprBuilder_MakeExpression\
"gandiva::TreeExprBuilder::MakeExpression"(
shared_ptr[CNode] root_node, shared_ptr[CField] result_field)
diff --git a/python/pyarrow/tests/test_gandiva.py b/python/pyarrow/tests/test_gandiva.py
index f5874e4..579f88d 100644
--- a/python/pyarrow/tests/test_gandiva.py
+++ b/python/pyarrow/tests/test_gandiva.py
@@ -91,10 +91,74 @@ def test_filter():
builder = gandiva.TreeExprBuilder()
node_a = builder.make_field(table.schema.field_by_name("a"))
- thousand = builder.make_literal(1000.0)
+ thousand = builder.make_literal(1000.0, pa.float64())
cond = builder.make_function("less_than", [node_a, thousand], pa.bool_())
condition = builder.make_condition(cond)
filter = gandiva.make_filter(table.schema, condition)
result = filter.evaluate(table.to_batches()[0], pa.default_memory_pool())
assert result.to_array().equals(pa.array(range(1000), type=pa.uint32()))
+
+
+@pytest.mark.gandiva
+def test_literals():
+ import pyarrow.gandiva as gandiva
+
+ builder = gandiva.TreeExprBuilder()
+
+ builder.make_literal(True, pa.bool_())
+ builder.make_literal(0, pa.uint8())
+ builder.make_literal(1, pa.uint16())
+ builder.make_literal(2, pa.uint32())
+ builder.make_literal(3, pa.uint64())
+ builder.make_literal(4, pa.int8())
+ builder.make_literal(5, pa.int16())
+ builder.make_literal(6, pa.int32())
+ builder.make_literal(7, pa.int64())
+ builder.make_literal(8.0, pa.float32())
+ builder.make_literal(9.0, pa.float64())
+ builder.make_literal("hello", pa.string())
+ builder.make_literal(b"world", pa.binary())
+
+ builder.make_literal(True, "bool")
+ builder.make_literal(0, "uint8")
+ builder.make_literal(1, "uint16")
+ builder.make_literal(2, "uint32")
+ builder.make_literal(3, "uint64")
+ builder.make_literal(4, "int8")
+ builder.make_literal(5, "int16")
+ builder.make_literal(6, "int32")
+ builder.make_literal(7, "int64")
+ builder.make_literal(8.0, "float32")
+ builder.make_literal(9.0, "float64")
+ builder.make_literal("hello", "string")
+ builder.make_literal(b"world", "binary")
+
+ with pytest.raises(TypeError):
+ builder.make_literal("hello", pa.int64())
+ with pytest.raises(TypeError):
+ builder.make_literal(True, None)
+
+
+@pytest.mark.gandiva
+def test_regex():
+ import pyarrow.gandiva as gandiva
+
+ elements = ["park", "sparkle", "bright spark and fire", "spark"]
+ data = pa.array(elements, type=pa.string())
+ table = pa.Table.from_arrays([data], names=['a'])
+
+ builder = gandiva.TreeExprBuilder()
+ node_a = builder.make_field(table.schema.field_by_name("a"))
+ regex = builder.make_literal("%spark%", pa.string())
+ like = builder.make_function("like", [node_a, regex], pa.bool_())
+
+ field_result = pa.field("b", pa.bool_())
+ expr = builder.make_expression(like, field_result)
+
+ projector = gandiva.make_projector(
+ table.schema, [expr], pa.default_memory_pool())
+
+ r, = projector.evaluate(table.to_batches()[0])
+ b = pa.array([False, True, True, True], type=pa.bool_())
+ assert r.equals(b)
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 92ef0f3..51a5659 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -830,13 +830,11 @@ def field(name, type, bint nullable=True, dict metadata=None):
cdef:
shared_ptr[CKeyValueMetadata] c_meta
Field result = Field.__new__(Field)
- DataType _type
+ DataType _type = _as_type(type)
if metadata is not None:
convert_metadata(metadata, &c_meta)
- _type = _as_type(type)
-
result.sp_field.reset(new CField(tobytes(name), _type.sp_type,
nullable == 1, c_meta))
result.field = result.sp_field.get()
@@ -844,14 +842,6 @@ def field(name, type, bint nullable=True, dict metadata=None):
return result
-cdef _as_type(type):
- if isinstance(type, DataType):
- return type
- if not isinstance(type, six.string_types):
- raise TypeError(type)
- return type_for_alias(type)
-
-
cdef set PRIMITIVE_TYPES = set([
_Type_NA, _Type_BOOL,
_Type_UINT8, _Type_INT8,
@@ -1431,6 +1421,15 @@ def type_for_alias(name):
return alias()
+def _as_type(type):
+ if isinstance(type, DataType):
+ return type
+ elif isinstance(type, six.string_types):
+ return type_for_alias(type)
+ else:
+ raise TypeError(type)
+
+
def schema(fields, dict metadata=None):
"""
Construct pyarrow.Schema from collection of fields