You are viewing a plain text version of this content. The canonical link for it was a hyperlink on the original archive page and is not preserved in this copy.
Posted to commits@iceberg.apache.org by bl...@apache.org on 2022/11/28 20:15:49 UTC
[iceberg] branch master updated: Python: Implement PyArrow row level filtering (#6258)
This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 56309c2402 Python: Implement PyArrow row level filtering (#6258)
56309c2402 is described below
commit 56309c24020eb47cb2e3f0db005afa4563d9a766
Author: Fokko Driesprong <fo...@apache.org>
AuthorDate: Mon Nov 28 21:15:43 2022 +0100
Python: Implement PyArrow row level filtering (#6258)
---
python/pyiceberg/expressions/visitors.py | 10 +-
python/pyiceberg/io/pyarrow.py | 65 +++++++++
python/pyiceberg/table/__init__.py | 23 +++-
python/tests/expressions/test_evaluator.py | 34 ++---
python/tests/expressions/test_visitors.py | 209 +++++++++++++++--------------
python/tests/io/test_pyarrow.py | 162 +++++++++++++++++++++-
6 files changed, 375 insertions(+), 128 deletions(-)
diff --git a/python/pyiceberg/expressions/visitors.py b/python/pyiceberg/expressions/visitors.py
index 9aa841097e..8888b722e0 100644
--- a/python/pyiceberg/expressions/visitors.py
+++ b/python/pyiceberg/expressions/visitors.py
@@ -214,7 +214,7 @@ class BindVisitor(BooleanExpressionVisitor[BooleanExpression]):
schema: Schema
case_sensitive: bool
- def __init__(self, schema: Schema, case_sensitive: bool = True) -> None:
+ def __init__(self, schema: Schema, case_sensitive: bool) -> None:
self.schema = schema
self.case_sensitive = case_sensitive
@@ -421,9 +421,7 @@ class _RewriteNotVisitor(BooleanExpressionVisitor[BooleanExpression]):
return predicate
-def expression_evaluator(
- schema: Schema, unbound: BooleanExpression, case_sensitive: bool = True
-) -> Callable[[StructProtocol], bool]:
+def expression_evaluator(schema: Schema, unbound: BooleanExpression, case_sensitive: bool) -> Callable[[StructProtocol], bool]:
return _ExpressionEvaluator(schema, unbound, case_sensitive).eval
@@ -431,7 +429,7 @@ class _ExpressionEvaluator(BoundBooleanExpressionVisitor[bool]):
bound: BooleanExpression
struct: StructProtocol
- def __init__(self, schema: Schema, unbound: BooleanExpression, case_sensitive: bool = True):
+ def __init__(self, schema: Schema, unbound: BooleanExpression, case_sensitive: bool):
self.bound = bind(schema, unbound, case_sensitive)
def eval(self, struct: StructProtocol) -> bool:
@@ -507,7 +505,7 @@ class _ManifestEvalVisitor(BoundBooleanExpressionVisitor[bool]):
partition_fields: List[PartitionFieldSummary]
partition_filter: BooleanExpression
- def __init__(self, partition_struct_schema: Schema, partition_filter: BooleanExpression, case_sensitive: bool = True):
+ def __init__(self, partition_struct_schema: Schema, partition_filter: BooleanExpression, case_sensitive):
self.partition_filter = bind(partition_struct_schema, rewrite_not(partition_filter), case_sensitive)
def eval(self, manifest: ManifestFile) -> bool:
diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py
index d5e23e5c9d..751b2815e3 100644
--- a/python/pyiceberg/io/pyarrow.py
+++ b/python/pyiceberg/io/pyarrow.py
@@ -25,14 +25,17 @@ with the pyarrow library.
import os
from functools import lru_cache, singledispatch
from typing import (
+ Any,
Callable,
List,
+ Set,
Tuple,
Union,
)
from urllib.parse import urlparse
import pyarrow as pa
+import pyarrow.compute as pc
from pyarrow.fs import (
FileInfo,
FileSystem,
@@ -41,6 +44,9 @@ from pyarrow.fs import (
S3FileSystem,
)
+from pyiceberg.expressions import BooleanExpression, BoundTerm, Literal
+from pyiceberg.expressions.visitors import BoundBooleanExpressionVisitor
+from pyiceberg.expressions.visitors import visit as boolean_expression_visit
from pyiceberg.io import (
FileIO,
InputFile,
@@ -379,3 +385,62 @@ def _(_: StringType) -> pa.DataType:
def _(_: BinaryType) -> pa.DataType:
# Variable length by default
return pa.binary()
+
+
+class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
+ def visit_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression:
+ return pc.field(term.ref().field.name).isin(literals)
+
+ def visit_not_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression:
+ return ~pc.field(term.ref().field.name).isin(literals)
+
+ def visit_is_nan(self, term: BoundTerm[pc.Expression]) -> pc.Expression:
+ ref = pc.field(term.ref().field.name)
+ return ref.is_null(nan_is_null=True) & ref.is_valid()
+
+ def visit_not_nan(self, term: BoundTerm[pc.Expression]) -> pc.Expression:
+ ref = pc.field(term.ref().field.name)
+ return ~(ref.is_null(nan_is_null=True) & ref.is_valid())
+
+ def visit_is_null(self, term: BoundTerm[pc.Expression]) -> pc.Expression:
+ return pc.field(term.ref().field.name).is_null(nan_is_null=False)
+
+ def visit_not_null(self, term: BoundTerm[Any]) -> pc.Expression:
+ return pc.field(term.ref().field.name).is_valid()
+
+ def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
+ return pc.field(term.ref().field.name) == literal.value
+
+ def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
+ return pc.field(term.ref().field.name) != literal.value
+
+ def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
+ return pc.field(term.ref().field.name) >= literal.value
+
+ def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
+ return pc.field(term.ref().field.name) > literal.value
+
+ def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
+ return pc.field(term.ref().field.name) < literal.value
+
+ def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> pc.Expression:
+ return pc.field(term.ref().field.name) <= literal.value
+
+ def visit_true(self) -> pc.Expression:
+ return pc.scalar(True)
+
+ def visit_false(self) -> pc.Expression:
+ return pc.scalar(False)
+
+ def visit_not(self, child_result: pc.Expression) -> pc.Expression:
+ return ~child_result
+
+ def visit_and(self, left_result: pc.Expression, right_result: pc.Expression) -> pc.Expression:
+ return left_result & right_result
+
+ def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> pc.Expression:
+ return left_result | right_result
+
+
+def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
+ return boolean_expression_visit(expr, _ConvertToArrowExpression())
diff --git a/python/pyiceberg/table/__init__.py b/python/pyiceberg/table/__init__.py
index dcc14c03c8..e989a3c911 100644
--- a/python/pyiceberg/table/__init__.py
+++ b/python/pyiceberg/table/__init__.py
@@ -39,8 +39,9 @@ from pyiceberg.expressions import (
BooleanExpression,
visitors,
)
-from pyiceberg.expressions.visitors import inclusive_projection
+from pyiceberg.expressions.visitors import bind, inclusive_projection
from pyiceberg.io import FileIO
+from pyiceberg.io.pyarrow import expression_to_pyarrow, schema_to_pyarrow
from pyiceberg.manifest import DataFile, ManifestFile, files
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import Schema
@@ -340,8 +341,6 @@ class DataScan(TableScan["DataScan"]):
scheme, path = PyArrowFileIO.parse_location(self.table.location())
fs = self.table.io.get_fs(scheme)
- import pyarrow.parquet as pq
-
locations = []
for task in self.plan_files():
if isinstance(task, FileScanTask):
@@ -354,7 +353,23 @@ class DataScan(TableScan["DataScan"]):
if "*" not in self.selected_fields:
columns = list(self.selected_fields)
- return pq.read_table(source=locations, filesystem=fs, columns=columns)
+ pyarrow_filter = None
+ if self.row_filter is not AlwaysTrue():
+ bound_row_filter = bind(self.table.schema(), self.row_filter, case_sensitive=self.case_sensitive)
+ pyarrow_filter = expression_to_pyarrow(bound_row_filter)
+
+ from pyarrow.dataset import dataset
+
+ ds = dataset(
+ source=locations,
+ filesystem=fs,
+ # Optionally provide the Schema for the Dataset,
+ # in which case it will not be inferred from the source.
+ # https://arrow.apache.org/docs/python/generated/pyarrow.dataset.dataset.html#pyarrow.dataset.dataset
+ schema=schema_to_pyarrow(self.table.schema()),
+ )
+
+ return ds.to_table(filter=pyarrow_filter, columns=columns)
def to_duckdb(self, table_name: str, connection=None):
import duckdb
diff --git a/python/tests/expressions/test_evaluator.py b/python/tests/expressions/test_evaluator.py
index 07a553947e..d2e473e38f 100644
--- a/python/tests/expressions/test_evaluator.py
+++ b/python/tests/expressions/test_evaluator.py
@@ -69,111 +69,111 @@ FLOAT_SCHEMA = Schema(
def test_true():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, AlwaysTrue())
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, AlwaysTrue(), case_sensitive=True)
assert evaluate(Record(1, "a"))
def test_false():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, AlwaysFalse())
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, AlwaysFalse(), case_sensitive=True)
assert not evaluate(Record(1, "a"))
def test_less_than():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, LessThan("id", 3))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, LessThan("id", 3), case_sensitive=True)
assert evaluate(Record(2, "a"))
assert not evaluate(Record(3, "a"))
def test_less_than_or_equal():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, LessThanOrEqual("id", 3))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, LessThanOrEqual("id", 3), case_sensitive=True)
assert evaluate(Record(1, "a"))
assert evaluate(Record(3, "a"))
assert not evaluate(Record(4, "a"))
def test_greater_than():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, GreaterThan("id", 3))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, GreaterThan("id", 3), case_sensitive=True)
assert not evaluate(Record(1, "a"))
assert not evaluate(Record(3, "a"))
assert evaluate(Record(4, "a"))
def test_greater_than_or_equal():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, GreaterThanOrEqual("id", 3))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, GreaterThanOrEqual("id", 3), case_sensitive=True)
assert not evaluate(Record(2, "a"))
assert evaluate(Record(3, "a"))
assert evaluate(Record(4, "a"))
def test_equal_to():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, EqualTo("id", 3))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, EqualTo("id", 3), case_sensitive=True)
assert not evaluate(Record(2, "a"))
assert evaluate(Record(3, "a"))
assert not evaluate(Record(4, "a"))
def test_not_equal_to():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, NotEqualTo("id", 3))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, NotEqualTo("id", 3), case_sensitive=True)
assert evaluate(Record(2, "a"))
assert not evaluate(Record(3, "a"))
assert evaluate(Record(4, "a"))
def test_in():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, In("id", [1, 2, 3]))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, In("id", [1, 2, 3]), case_sensitive=True)
assert evaluate(Record(2, "a"))
assert evaluate(Record(3, "a"))
assert not evaluate(Record(4, "a"))
def test_not_in():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, NotIn("id", [1, 2, 3]))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, NotIn("id", [1, 2, 3]), case_sensitive=True)
assert not evaluate(Record(2, "a"))
assert not evaluate(Record(3, "a"))
assert evaluate(Record(4, "a"))
def test_is_null():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, IsNull("data"))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, IsNull("data"), case_sensitive=True)
assert not evaluate(Record(2, "a"))
assert evaluate(Record(3, None))
def test_not_null():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, NotNull("data"))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, NotNull("data"), case_sensitive=True)
assert evaluate(Record(2, "a"))
assert not evaluate(Record(3, None))
def test_is_nan():
- evaluate = expression_evaluator(FLOAT_SCHEMA, IsNaN("f"))
+ evaluate = expression_evaluator(FLOAT_SCHEMA, IsNaN("f"), case_sensitive=True)
assert not evaluate(Record(2, 0.0))
assert not evaluate(Record(3, float("infinity")))
assert evaluate(Record(4, float("nan")))
def test_not_nan():
- evaluate = expression_evaluator(FLOAT_SCHEMA, NotNaN("f"))
+ evaluate = expression_evaluator(FLOAT_SCHEMA, NotNaN("f"), case_sensitive=True)
assert evaluate(Record(2, 0.0))
assert evaluate(Record(3, float("infinity")))
assert not evaluate(Record(4, float("nan")))
def test_not():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, Not(LessThan("id", 3)))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, Not(LessThan("id", 3)), case_sensitive=True)
assert not evaluate(Record(2, "a"))
assert evaluate(Record(3, "a"))
def test_and():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, And(LessThan("id", 3), GreaterThan("id", 1)))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, And(LessThan("id", 3), GreaterThan("id", 1)), case_sensitive=True)
assert not evaluate(Record(1, "a"))
assert evaluate(Record(2, "a"))
assert not evaluate(Record(3, "a"))
def test_or():
- evaluate = expression_evaluator(SIMPLE_SCHEMA, Or(LessThan("id", 2), GreaterThan("id", 2)))
+ evaluate = expression_evaluator(SIMPLE_SCHEMA, Or(LessThan("id", 2), GreaterThan("id", 2)), case_sensitive=True)
assert evaluate(Record(1, "a"))
assert not evaluate(Record(2, "a"))
assert evaluate(Record(3, "a"))
diff --git a/python/tests/expressions/test_visitors.py b/python/tests/expressions/test_visitors.py
index 0f07156be4..0552843413 100644
--- a/python/tests/expressions/test_visitors.py
+++ b/python/tests/expressions/test_visitors.py
@@ -289,7 +289,7 @@ def test_bind_visitor_already_bound(table_schema_simple: Schema):
literal=literal("hello"),
)
with pytest.raises(TypeError) as exc_info:
- visit(bound, visitor=BindVisitor(schema=table_schema_simple))
+ visit(bound, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert (
"Found already bound predicate: BoundEqualTo(term=BoundReference(field=NestedField(field_id=1, name='foo', field_type=StringType(), required=False), accessor=Accessor(position=0,inner=None)), literal=literal('hello'))"
== str(exc_info.value)
@@ -305,28 +305,28 @@ def test_visit_bound_visitor_unknown_predicate():
def test_always_true_expression_binding(table_schema_simple: Schema):
"""Test that visiting an always-true expression returns always-true"""
unbound_expression = AlwaysTrue()
- bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple))
+ bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == AlwaysTrue()
def test_always_false_expression_binding(table_schema_simple: Schema):
"""Test that visiting an always-false expression returns always-false"""
unbound_expression = AlwaysFalse()
- bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple))
+ bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == AlwaysFalse()
def test_always_false_and_always_true_expression_binding(table_schema_simple: Schema):
"""Test that visiting both an always-true AND always-false expression returns always-false"""
unbound_expression = And(AlwaysTrue(), AlwaysFalse())
- bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple))
+ bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == AlwaysFalse()
def test_always_false_or_always_true_expression_binding(table_schema_simple: Schema):
"""Test that visiting always-true OR always-false expression returns always-true"""
unbound_expression = Or(AlwaysTrue(), AlwaysFalse())
- bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple))
+ bound_expression = visit(unbound_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == AlwaysTrue()
@@ -397,7 +397,7 @@ def test_always_false_or_always_true_expression_binding(table_schema_simple: Sch
)
def test_and_expression_binding(unbound_and_expression, expected_bound_expression, table_schema_simple):
"""Test that visiting an unbound AND expression with a bind-visitor returns the expected bound expression"""
- bound_expression = visit(unbound_and_expression, visitor=BindVisitor(schema=table_schema_simple))
+ bound_expression = visit(unbound_and_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == expected_bound_expression
@@ -489,7 +489,7 @@ def test_and_expression_binding(unbound_and_expression, expected_bound_expressio
)
def test_or_expression_binding(unbound_or_expression, expected_bound_expression, table_schema_simple):
"""Test that visiting an unbound OR expression with a bind-visitor returns the expected bound expression"""
- bound_expression = visit(unbound_or_expression, visitor=BindVisitor(schema=table_schema_simple))
+ bound_expression = visit(unbound_or_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == expected_bound_expression
@@ -533,7 +533,7 @@ def test_or_expression_binding(unbound_or_expression, expected_bound_expression,
)
def test_in_expression_binding(unbound_in_expression, expected_bound_expression, table_schema_simple):
"""Test that visiting an unbound IN expression with a bind-visitor returns the expected bound expression"""
- bound_expression = visit(unbound_in_expression, visitor=BindVisitor(schema=table_schema_simple))
+ bound_expression = visit(unbound_in_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == expected_bound_expression
@@ -582,7 +582,7 @@ def test_in_expression_binding(unbound_in_expression, expected_bound_expression,
)
def test_not_expression_binding(unbound_not_expression, expected_bound_expression, table_schema_simple):
"""Test that visiting an unbound NOT expression with a bind-visitor returns the expected bound expression"""
- bound_expression = visit(unbound_not_expression, visitor=BindVisitor(schema=table_schema_simple))
+ bound_expression = visit(unbound_not_expression, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True))
assert bound_expression == expected_bound_expression
@@ -950,93 +950,93 @@ def manifest() -> ManifestFile:
def test_all_nulls(schema: Schema, manifest: ManifestFile) -> None:
- assert not _ManifestEvalVisitor(schema, NotNull(Reference("all_nulls_missing_nan"))).eval(
+ assert not _ManifestEvalVisitor(schema, NotNull(Reference("all_nulls_missing_nan")), case_sensitive=True).eval(
manifest
), "Should skip: all nulls column with non-floating type contains all null"
- assert _ManifestEvalVisitor(schema, NotNull(Reference("all_nulls_missing_nan_float"))).eval(
+ assert _ManifestEvalVisitor(schema, NotNull(Reference("all_nulls_missing_nan_float")), case_sensitive=True).eval(
manifest
), "Should read: no NaN information may indicate presence of NaN value"
- assert _ManifestEvalVisitor(schema, NotNull(Reference("some_nulls"))).eval(
+ assert _ManifestEvalVisitor(schema, NotNull(Reference("some_nulls")), case_sensitive=True).eval(
manifest
), "Should read: column with some nulls contains a non-null value"
- assert _ManifestEvalVisitor(schema, NotNull(Reference("no_nulls"))).eval(
+ assert _ManifestEvalVisitor(schema, NotNull(Reference("no_nulls")), case_sensitive=True).eval(
manifest
), "Should read: non-null column contains a non-null value"
def test_no_nulls(schema: Schema, manifest: ManifestFile) -> None:
- assert _ManifestEvalVisitor(schema, IsNull(Reference("all_nulls_missing_nan"))).eval(
+ assert _ManifestEvalVisitor(schema, IsNull(Reference("all_nulls_missing_nan")), case_sensitive=True).eval(
manifest
), "Should read: at least one null value in all null column"
- assert _ManifestEvalVisitor(schema, IsNull(Reference("some_nulls"))).eval(
+ assert _ManifestEvalVisitor(schema, IsNull(Reference("some_nulls")), case_sensitive=True).eval(
manifest
), "Should read: column with some nulls contains a null value"
- assert not _ManifestEvalVisitor(schema, IsNull(Reference("no_nulls"))).eval(
+ assert not _ManifestEvalVisitor(schema, IsNull(Reference("no_nulls")), case_sensitive=True).eval(
manifest
), "Should skip: non-null column contains no null values"
- assert _ManifestEvalVisitor(schema, IsNull(Reference("both_nan_and_null"))).eval(
+ assert _ManifestEvalVisitor(schema, IsNull(Reference("both_nan_and_null")), case_sensitive=True).eval(
manifest
), "Should read: both_nan_and_null column contains no null values"
def test_is_nan(schema: Schema, manifest: ManifestFile) -> None:
- assert _ManifestEvalVisitor(schema, IsNaN(Reference("float"))).eval(
+ assert _ManifestEvalVisitor(schema, IsNaN(Reference("float")), case_sensitive=True).eval(
manifest
), "Should read: no information on if there are nan value in float column"
- assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_double"))).eval(
+ assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_double")), case_sensitive=True).eval(
manifest
), "Should read: no NaN information may indicate presence of NaN value"
- assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_missing_nan_float"))).eval(
+ assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_missing_nan_float")), case_sensitive=True).eval(
manifest
), "Should read: no NaN information may indicate presence of NaN value"
- assert not _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_no_nans"))).eval(
+ assert not _ManifestEvalVisitor(schema, IsNaN(Reference("all_nulls_no_nans")), case_sensitive=True).eval(
manifest
), "Should skip: no nan column doesn't contain nan value"
- assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nans"))).eval(
+ assert _ManifestEvalVisitor(schema, IsNaN(Reference("all_nans")), case_sensitive=True).eval(
manifest
), "Should read: all_nans column contains nan value"
- assert _ManifestEvalVisitor(schema, IsNaN(Reference("both_nan_and_null"))).eval(
+ assert _ManifestEvalVisitor(schema, IsNaN(Reference("both_nan_and_null")), case_sensitive=True).eval(
manifest
), "Should read: both_nan_and_null column contains nan value"
- assert not _ManifestEvalVisitor(schema, IsNaN(Reference("no_nan_or_null"))).eval(
+ assert not _ManifestEvalVisitor(schema, IsNaN(Reference("no_nan_or_null")), case_sensitive=True).eval(
manifest
), "Should skip: no_nan_or_null column doesn't contain nan value"
def test_not_nan(schema: Schema, manifest: ManifestFile) -> None:
- assert _ManifestEvalVisitor(schema, NotNaN(Reference("float"))).eval(
+ assert _ManifestEvalVisitor(schema, NotNaN(Reference("float")), case_sensitive=True).eval(
manifest
), "Should read: no information on if there are nan value in float column"
- assert _ManifestEvalVisitor(schema, NotNaN(Reference("all_nulls_double"))).eval(
+ assert _ManifestEvalVisitor(schema, NotNaN(Reference("all_nulls_double")), case_sensitive=True).eval(
manifest
), "Should read: all null column contains non nan value"
- assert _ManifestEvalVisitor(schema, NotNaN(Reference("all_nulls_no_nans"))).eval(
+ assert _ManifestEvalVisitor(schema, NotNaN(Reference("all_nulls_no_nans")), case_sensitive=True).eval(
manifest
), "Should read: no_nans column contains non nan value"
- assert not _ManifestEvalVisitor(schema, NotNaN(Reference("all_nans"))).eval(
+ assert not _ManifestEvalVisitor(schema, NotNaN(Reference("all_nans")), case_sensitive=True).eval(
manifest
), "Should skip: all nans column doesn't contain non nan value"
- assert _ManifestEvalVisitor(schema, NotNaN(Reference("both_nan_and_null"))).eval(
+ assert _ManifestEvalVisitor(schema, NotNaN(Reference("both_nan_and_null")), case_sensitive=True).eval(
manifest
), "Should read: both_nan_and_null nans column contains non nan value"
- assert _ManifestEvalVisitor(schema, NotNaN(Reference("no_nan_or_null"))).eval(
+ assert _ManifestEvalVisitor(schema, NotNaN(Reference("no_nan_or_null")), case_sensitive=True).eval(
manifest
), "Should read: no_nan_or_null column contains non nan value"
@@ -1056,15 +1056,17 @@ def test_missing_stats(schema: Schema, manifest_no_stats: ManifestFile):
]
for expr in expressions:
- assert _ManifestEvalVisitor(schema, expr).eval(manifest_no_stats), f"Should read when missing stats for expr: {expr}"
+ assert _ManifestEvalVisitor(schema, expr, case_sensitive=True).eval(
+ manifest_no_stats
+ ), f"Should read when missing stats for expr: {expr}"
def test_not(schema: Schema, manifest: ManifestFile):
- assert _ManifestEvalVisitor(schema, Not(LessThan(Reference("id"), INT_MIN_VALUE - 25))).eval(
+ assert _ManifestEvalVisitor(schema, Not(LessThan(Reference("id"), INT_MIN_VALUE - 25)), case_sensitive=True).eval(
manifest
), "Should read: not(false)"
- assert not _ManifestEvalVisitor(schema, Not(GreaterThan(Reference("id"), INT_MIN_VALUE - 25))).eval(
+ assert not _ManifestEvalVisitor(schema, Not(GreaterThan(Reference("id"), INT_MIN_VALUE - 25)), case_sensitive=True).eval(
manifest
), "Should skip: not(true)"
@@ -1076,6 +1078,7 @@ def test_and(schema: Schema, manifest: ManifestFile):
LessThan(Reference("id"), INT_MIN_VALUE - 25),
GreaterThanOrEqual(Reference("id"), INT_MIN_VALUE - 30),
),
+ case_sensitive=True,
).eval(manifest), "Should skip: and(false, true)"
assert not _ManifestEvalVisitor(
@@ -1084,6 +1087,7 @@ def test_and(schema: Schema, manifest: ManifestFile):
LessThan(Reference("id"), INT_MIN_VALUE - 25),
GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE + 1),
),
+ case_sensitive=True,
).eval(manifest), "Should skip: and(false, false)"
assert _ManifestEvalVisitor(
@@ -1092,6 +1096,7 @@ def test_and(schema: Schema, manifest: ManifestFile):
GreaterThan(Reference("id"), INT_MIN_VALUE - 25),
LessThanOrEqual(Reference("id"), INT_MIN_VALUE),
),
+ case_sensitive=True,
).eval(manifest), "Should read: and(true, true)"
@@ -1102,6 +1107,7 @@ def test_or(schema: Schema, manifest: ManifestFile):
LessThan(Reference("id"), INT_MIN_VALUE - 25),
GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE + 1),
),
+ case_sensitive=True,
).eval(manifest), "Should skip: or(false, false)"
assert _ManifestEvalVisitor(
@@ -1110,165 +1116,168 @@ def test_or(schema: Schema, manifest: ManifestFile):
LessThan(Reference("id"), INT_MIN_VALUE - 25),
GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE - 19),
),
+ case_sensitive=True,
).eval(manifest), "Should read: or(false, true)"
def test_integer_lt(schema: Schema, manifest: ManifestFile):
- assert not _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE - 25)).eval(
+ assert not _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval(
manifest
), "Should not read: id range below lower bound (5 < 30)"
- assert not _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE)).eval(
+ assert not _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(
manifest
), "Should not read: id range below lower bound (30 is not < 30)"
- assert _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE + 1)).eval(
+ assert _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MIN_VALUE + 1), case_sensitive=True).eval(
manifest
), "Should read: one possible id"
- assert _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MAX_VALUE)).eval(manifest), "Should read: may possible ids"
+ assert _ManifestEvalVisitor(schema, LessThan(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
+ manifest
+ ), "Should read: may possible ids"
def test_integer_lt_eq(schema: Schema, manifest: ManifestFile):
- assert not _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MIN_VALUE - 25)).eval(
+ assert not _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval(
manifest
), "Should not read: id range below lower bound (5 < 30)"
- assert not _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MIN_VALUE - 1)).eval(
+ assert not _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MIN_VALUE - 1), case_sensitive=True).eval(
manifest
), "Should not read: id range below lower bound (29 < 30)"
- assert _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MIN_VALUE)).eval(
+ assert _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(
manifest
), "Should read: one possible id"
- assert _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MAX_VALUE)).eval(
+ assert _ManifestEvalVisitor(schema, LessThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should read: many possible ids"
def test_integer_gt(schema: Schema, manifest: ManifestFile):
- assert not _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE + 6)).eval(
+ assert not _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval(
manifest
), "Should not read: id range above upper bound (85 < 79)"
- assert not _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE)).eval(
+ assert not _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should not read: id range above upper bound (79 is not > 79)"
- assert _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE - 1)).eval(
+ assert _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE - 1), case_sensitive=True).eval(
manifest
), "Should read: one possible id"
- assert _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE - 4)).eval(
+ assert _ManifestEvalVisitor(schema, GreaterThan(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval(
manifest
), "Should read: may possible ids"
def test_integer_gt_eq(schema: Schema, manifest: ManifestFile):
- assert not _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE + 6)).eval(
+ assert not _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval(
manifest
), "Should not read: id range above upper bound (85 < 79)"
- assert not _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE + 1)).eval(
+ assert not _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE + 1), case_sensitive=True).eval(
manifest
), "Should not read: id range above upper bound (80 > 79)"
- assert _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE)).eval(
+ assert _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should read: one possible id"
- assert _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE)).eval(
+ assert _ManifestEvalVisitor(schema, GreaterThanOrEqual(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should read: may possible ids"
def test_integer_eq(schema: Schema, manifest: ManifestFile):
+ # EqualTo on ``id``: only literals between INT_MIN_VALUE and INT_MAX_VALUE (inclusive) are expected to read the manifest.
- assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE - 25)).eval(
+ assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval(
manifest
), "Should not read: id below lower bound"
- assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE - 1)).eval(
+ assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE - 1), case_sensitive=True).eval(
manifest
), "Should not read: id below lower bound"
- assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE)).eval(
+ assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(
manifest
), "Should read: id equal to lower bound"
- assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE - 4)).eval(
+ assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval(
manifest
), "Should read: id between lower and upper bounds"
- assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE)).eval(
+ assert _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should read: id equal to upper bound"
- assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE + 1)).eval(
+ assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE + 1), case_sensitive=True).eval(
manifest
), "Should not read: id above upper bound"
- assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE + 6)).eval(
+ assert not _ManifestEvalVisitor(schema, EqualTo(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval(
manifest
), "Should not read: id above upper bound"
def test_integer_not_eq(schema: Schema, manifest: ManifestFile):
+ # NotEqualTo on ``id``: every literal, inside or outside the bounds, is expected to read the manifest.
- assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE - 25)).eval(
+ assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE - 25), case_sensitive=True).eval(
manifest
), "Should read: id below lower bound"
- assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE - 1)).eval(
+ assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE - 1), case_sensitive=True).eval(
manifest
), "Should read: id below lower bound"
- assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE)).eval(
+ assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MIN_VALUE), case_sensitive=True).eval(
manifest
), "Should read: id equal to lower bound"
- assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE - 4)).eval(
+ assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE - 4), case_sensitive=True).eval(
manifest
), "Should read: id between lower and upper bounds"
- assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE)).eval(
+ assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE), case_sensitive=True).eval(
manifest
), "Should read: id equal to upper bound"
- assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE + 1)).eval(
+ assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE + 1), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound"
- assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE + 6)).eval(
+ assert _ManifestEvalVisitor(schema, NotEqualTo(Reference("id"), INT_MAX_VALUE + 6), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound"
def test_integer_not_eq_rewritten(schema: Schema, manifest: ManifestFile):
+ # Same probes as test_integer_not_eq but expressed as Not(EqualTo(...)); every case is expected to read.
- assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE - 25))).eval(
+ assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE - 25)), case_sensitive=True).eval(
manifest
), "Should read: id below lower bound"
- assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE - 1))).eval(
+ assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE - 1)), case_sensitive=True).eval(
manifest
), "Should read: id below lower bound"
- assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE))).eval(
+ assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MIN_VALUE)), case_sensitive=True).eval(
manifest
), "Should read: id equal to lower bound"
- assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE - 4))).eval(
+ assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE - 4)), case_sensitive=True).eval(
manifest
), "Should read: id between lower and upper bounds"
- assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE))).eval(
+ assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE)), case_sensitive=True).eval(
manifest
), "Should read: id equal to upper bound"
- assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE + 1))).eval(
+ assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE + 1)), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound"
- assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE + 6))).eval(
+ assert _ManifestEvalVisitor(schema, Not(EqualTo(Reference("id"), INT_MAX_VALUE + 6)), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound"
@@ -1304,85 +1313,85 @@ def test_integer_not_eq_rewritten_case_insensitive(schema: Schema, manifest: Man
def test_integer_in(schema: Schema, manifest: ManifestFile):
+ # In on ``id`` is expected to read only when some literal lies within the bounds; In on the all-nulls column is skipped.
- assert not _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MIN_VALUE - 25, INT_MIN_VALUE - 24))).eval(
- manifest
- ), "Should not read: id below lower bound (5 < 30, 6 < 30)"
+ assert not _ManifestEvalVisitor(
+ schema, In(Reference("id"), (INT_MIN_VALUE - 25, INT_MIN_VALUE - 24)), case_sensitive=True
+ ).eval(manifest), "Should not read: id below lower bound (5 < 30, 6 < 30)"
- assert not _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MIN_VALUE - 2, INT_MIN_VALUE - 1))).eval(
- manifest
- ), "Should not read: id below lower bound (28 < 30, 29 < 30)"
+ assert not _ManifestEvalVisitor(
+ schema, In(Reference("id"), (INT_MIN_VALUE - 2, INT_MIN_VALUE - 1)), case_sensitive=True
+ ).eval(manifest), "Should not read: id below lower bound (28 < 30, 29 < 30)"
- assert _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MIN_VALUE - 1, INT_MIN_VALUE))).eval(
+ assert _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MIN_VALUE - 1, INT_MIN_VALUE)), case_sensitive=True).eval(
manifest
), "Should read: id equal to lower bound (30 == 30)"
- assert _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MAX_VALUE - 4, INT_MAX_VALUE - 3))).eval(
+ assert _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MAX_VALUE - 4, INT_MAX_VALUE - 3)), case_sensitive=True).eval(
manifest
), "Should read: id between lower and upper bounds (30 < 75 < 79, 30 < 76 < 79)"
- assert _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MAX_VALUE, INT_MAX_VALUE + 1))).eval(
+ assert _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MAX_VALUE, INT_MAX_VALUE + 1)), case_sensitive=True).eval(
manifest
), "Should read: id equal to upper bound (79 == 79)"
- assert not _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MAX_VALUE + 1, INT_MAX_VALUE + 2))).eval(
- manifest
- ), "Should not read: id above upper bound (80 > 79, 81 > 79)"
+ assert not _ManifestEvalVisitor(
+ schema, In(Reference("id"), (INT_MAX_VALUE + 1, INT_MAX_VALUE + 2)), case_sensitive=True
+ ).eval(manifest), "Should not read: id above upper bound (80 > 79, 81 > 79)"
- assert not _ManifestEvalVisitor(schema, In(Reference("id"), (INT_MAX_VALUE + 6, INT_MAX_VALUE + 7))).eval(
- manifest
- ), "Should not read: id above upper bound (85 > 79, 86 > 79)"
+ assert not _ManifestEvalVisitor(
+ schema, In(Reference("id"), (INT_MAX_VALUE + 6, INT_MAX_VALUE + 7)), case_sensitive=True
+ ).eval(manifest), "Should not read: id above upper bound (85 > 79, 86 > 79)"
- assert not _ManifestEvalVisitor(schema, In(Reference("all_nulls_missing_nan"), ("abc", "def"))).eval(
+ assert not _ManifestEvalVisitor(schema, In(Reference("all_nulls_missing_nan"), ("abc", "def")), case_sensitive=True).eval(
manifest
), "Should skip: in on all nulls column"
- assert _ManifestEvalVisitor(schema, In(Reference("some_nulls"), ("abc", "def"))).eval(
+ assert _ManifestEvalVisitor(schema, In(Reference("some_nulls"), ("abc", "def")), case_sensitive=True).eval(
manifest
), "Should read: in on some nulls column"
- assert _ManifestEvalVisitor(schema, In(Reference("no_nulls"), ("abc", "def"))).eval(
+ assert _ManifestEvalVisitor(schema, In(Reference("no_nulls"), ("abc", "def")), case_sensitive=True).eval(
manifest
), "Should read: in on no nulls column"
def test_integer_not_in(schema: Schema, manifest: ManifestFile):
+ # NotIn is expected to read the manifest in every case, including on all-null, some-null and no-null columns.
- assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MIN_VALUE - 25, INT_MIN_VALUE - 24))).eval(
- manifest
- ), "Should read: id below lower bound (5 < 30, 6 < 30)"
+ assert _ManifestEvalVisitor(
+ schema, NotIn(Reference("id"), (INT_MIN_VALUE - 25, INT_MIN_VALUE - 24)), case_sensitive=True
+ ).eval(manifest), "Should read: id below lower bound (5 < 30, 6 < 30)"
- assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MIN_VALUE - 2, INT_MIN_VALUE - 1))).eval(
+ assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MIN_VALUE - 2, INT_MIN_VALUE - 1)), case_sensitive=True).eval(
manifest
), "Should read: id below lower bound (28 < 30, 29 < 30)"
- assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MIN_VALUE - 1, INT_MIN_VALUE))).eval(
+ assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MIN_VALUE - 1, INT_MIN_VALUE)), case_sensitive=True).eval(
manifest
), "Should read: id equal to lower bound (30 == 30)"
- assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE - 4, INT_MAX_VALUE - 3))).eval(
+ assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE - 4, INT_MAX_VALUE - 3)), case_sensitive=True).eval(
manifest
), "Should read: id between lower and upper bounds (30 < 75 < 79, 30 < 76 < 79)"
- assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE, INT_MAX_VALUE + 1))).eval(
+ assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE, INT_MAX_VALUE + 1)), case_sensitive=True).eval(
manifest
), "Should read: id equal to upper bound (79 == 79)"
- assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE + 1, INT_MAX_VALUE + 2))).eval(
+ assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE + 1, INT_MAX_VALUE + 2)), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound (80 > 79, 81 > 79)"
- assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE + 6, INT_MAX_VALUE + 7))).eval(
+ assert _ManifestEvalVisitor(schema, NotIn(Reference("id"), (INT_MAX_VALUE + 6, INT_MAX_VALUE + 7)), case_sensitive=True).eval(
manifest
), "Should read: id above upper bound (85 > 79, 86 > 79)"
- assert _ManifestEvalVisitor(schema, NotIn(Reference("all_nulls_missing_nan"), ("abc", "def"))).eval(
+ assert _ManifestEvalVisitor(schema, NotIn(Reference("all_nulls_missing_nan"), ("abc", "def")), case_sensitive=True).eval(
manifest
- ), "Should read: notIn on no nulls column"
+ ), "Should read: notIn on all nulls column"
- assert _ManifestEvalVisitor(schema, NotIn(Reference("some_nulls"), ("abc", "def"))).eval(
+ assert _ManifestEvalVisitor(schema, NotIn(Reference("some_nulls"), ("abc", "def")), case_sensitive=True).eval(
manifest
- ), "Should read: in on some nulls column"
+ ), "Should read: notIn on some nulls column"
- assert _ManifestEvalVisitor(schema, NotIn(Reference("no_nulls"), ("abc", "def"))).eval(
+ assert _ManifestEvalVisitor(schema, NotIn(Reference("no_nulls"), ("abc", "def")), case_sensitive=True).eval(
manifest
- ), "Should read: in on no nulls column"
+ ), "Should read: notIn on no nulls column"
diff --git a/python/tests/io/test_pyarrow.py b/python/tests/io/test_pyarrow.py
index db25c0d134..79136be1c2 100644
--- a/python/tests/io/test_pyarrow.py
+++ b/python/tests/io/test_pyarrow.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=protected-access,unused-argument
+# pylint: disable=protected-access,unused-argument,redefined-outer-name
import os
import tempfile
@@ -24,11 +24,33 @@ import pyarrow as pa
import pytest
from pyarrow.fs import FileType
+from pyiceberg.expressions import (
+ AlwaysFalse,
+ AlwaysTrue,
+ And,
+ BoundEqualTo,
+ BoundGreaterThan,
+ BoundGreaterThanOrEqual,
+ BoundIn,
+ BoundIsNaN,
+ BoundIsNull,
+ BoundLessThan,
+ BoundLessThanOrEqual,
+ BoundNotEqualTo,
+ BoundNotIn,
+ BoundNotNaN,
+ BoundNotNull,
+ BoundReference,
+ Not,
+ Or,
+ literal,
+)
from pyiceberg.io import InputStream, OutputStream
from pyiceberg.io.pyarrow import (
PyArrowFile,
PyArrowFileIO,
_ConvertToArrowSchema,
+ expression_to_pyarrow,
schema_to_pyarrow,
)
from pyiceberg.schema import Schema, visit
@@ -44,6 +66,7 @@ from pyiceberg.types import (
ListType,
LongType,
MapType,
+ NestedField,
StringType,
TimestampType,
TimestamptzType,
@@ -411,3 +434,140 @@ def test_list_type_to_pyarrow():
element_required=True,
)
assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.list_(pa.int32())
+
+
+@pytest.fixture
+def bound_reference(table_schema_simple: Schema) -> BoundReference[str]:
+ # Binds field id 1 of the ``table_schema_simple`` fixture; the tests below expect it to render as ``foo``.
+ return BoundReference(table_schema_simple.find_field(1), table_schema_simple.accessor_for_field(1))
+
+
+@pytest.fixture
+def bound_double_reference() -> BoundReference[float]:
+ # Binds an optional double column ``foo`` (field id 1) for the NaN tests.
+ # No identifier fields: the schema has no field id 2, and field 1 is optional.
+ schema = Schema(
+ NestedField(field_id=1, name="foo", field_type=DoubleType(), required=False),
+ schema_id=1,
+ )
+ return BoundReference(schema.find_field(1), schema.accessor_for_field(1))
+
+
+def test_expr_is_null_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # BoundIsNull is expected to map to pyarrow is_null with nan_is_null disabled.
+ assert (
+ repr(expression_to_pyarrow(BoundIsNull(bound_reference)))
+ == "<pyarrow.compute.Expression is_null(foo, {nan_is_null=false})>"
+ )
+
+
+def test_expr_not_null_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # BoundNotNull is expected to map to pyarrow is_valid.
+ assert repr(expression_to_pyarrow(BoundNotNull(bound_reference))) == "<pyarrow.compute.Expression is_valid(foo)>"
+
+
+def test_expr_is_nan_to_pyarrow(bound_double_reference: BoundReference[float]) -> None:
+ # IsNaN is expressed as "null-or-NaN" AND "not null", which leaves exactly the NaN values.
+ assert (
+ repr(expression_to_pyarrow(BoundIsNaN(bound_double_reference)))
+ == "<pyarrow.compute.Expression (is_null(foo, {nan_is_null=true}) and is_valid(foo))>"
+ )
+
+
+def test_expr_not_nan_to_pyarrow(bound_double_reference: BoundReference[float]) -> None:
+ # NotNaN is expected to be the inversion of the IsNaN expression.
+ assert (
+ repr(expression_to_pyarrow(BoundNotNaN(bound_double_reference)))
+ == "<pyarrow.compute.Expression invert((is_null(foo, {nan_is_null=true}) and is_valid(foo)))>"
+ )
+
+
+def test_expr_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # BoundEqualTo is expected to become a pyarrow ``==`` comparison against the literal.
+ assert (
+ repr(expression_to_pyarrow(BoundEqualTo(bound_reference, literal("hello"))))
+ == '<pyarrow.compute.Expression (foo == "hello")>'
+ )
+
+
+def test_expr_not_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # BoundNotEqualTo is expected to become a pyarrow ``!=`` comparison against the literal.
+ assert (
+ repr(expression_to_pyarrow(BoundNotEqualTo(bound_reference, literal("hello"))))
+ == '<pyarrow.compute.Expression (foo != "hello")>'
+ )
+
+
+def test_expr_greater_than_or_equal_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # BoundGreaterThanOrEqual is expected to become a pyarrow ``>=`` comparison.
+ assert (
+ repr(expression_to_pyarrow(BoundGreaterThanOrEqual(bound_reference, literal("hello"))))
+ == '<pyarrow.compute.Expression (foo >= "hello")>'
+ )
+
+
+def test_expr_greater_than_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # BoundGreaterThan is expected to become a pyarrow ``>`` comparison.
+ assert (
+ repr(expression_to_pyarrow(BoundGreaterThan(bound_reference, literal("hello"))))
+ == '<pyarrow.compute.Expression (foo > "hello")>'
+ )
+
+
+def test_expr_less_than_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # BoundLessThan is expected to become a pyarrow ``<`` comparison.
+ assert (
+ repr(expression_to_pyarrow(BoundLessThan(bound_reference, literal("hello"))))
+ == '<pyarrow.compute.Expression (foo < "hello")>'
+ )
+
+
+def test_expr_less_than_or_equal_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # BoundLessThanOrEqual is expected to become a pyarrow ``<=`` comparison.
+ assert (
+ repr(expression_to_pyarrow(BoundLessThanOrEqual(bound_reference, literal("hello"))))
+ == '<pyarrow.compute.Expression (foo <= "hello")>'
+ )
+
+
+def test_expr_in_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # The literals come from a Python set, so the value_set order is not fixed; both orderings are accepted.
+ assert repr(expression_to_pyarrow(BoundIn(bound_reference, {literal("hello"), literal("world")}))) in (
+ """<pyarrow.compute.Expression is_in(foo, {value_set=string:[
+ "world",
+ "hello"
+], skip_nulls=false})>""",
+ """<pyarrow.compute.Expression is_in(foo, {value_set=string:[
+ "hello",
+ "world"
+], skip_nulls=false})>""",
+ )
+
+
+def test_expr_not_in_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # NotIn is expected to be invert(is_in(...)); set iteration order is not fixed, so both orderings are accepted.
+ assert repr(expression_to_pyarrow(BoundNotIn(bound_reference, {literal("hello"), literal("world")}))) in (
+ """<pyarrow.compute.Expression invert(is_in(foo, {value_set=string:[
+ "world",
+ "hello"
+], skip_nulls=false}))>""",
+ """<pyarrow.compute.Expression invert(is_in(foo, {value_set=string:[
+ "hello",
+ "world"
+], skip_nulls=false}))>""",
+ )
+
+
+def test_and_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # And is expected to combine the two translated children with pyarrow ``and``.
+ assert (
+ repr(expression_to_pyarrow(And(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference))))
+ == '<pyarrow.compute.Expression ((foo == "hello") and is_null(foo, {nan_is_null=false}))>'
+ )
+
+
+def test_or_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # Or is expected to combine the two translated children with pyarrow ``or``.
+ assert (
+ repr(expression_to_pyarrow(Or(BoundEqualTo(bound_reference, literal("hello")), BoundIsNull(bound_reference))))
+ == '<pyarrow.compute.Expression ((foo == "hello") or is_null(foo, {nan_is_null=false}))>'
+ )
+
+
+def test_not_to_pyarrow(bound_reference: BoundReference[str]) -> None:
+ # Not is expected to wrap the translated child in pyarrow invert().
+ assert (
+ repr(expression_to_pyarrow(Not(BoundEqualTo(bound_reference, literal("hello")))))
+ == '<pyarrow.compute.Expression invert((foo == "hello"))>'
+ )
+
+
+def test_always_true_to_pyarrow() -> None:
+ # AlwaysTrue binds no column, so no reference fixture is needed; it maps to the literal ``true`` expression.
+ assert repr(expression_to_pyarrow(AlwaysTrue())) == "<pyarrow.compute.Expression true>"
+
+
+def test_always_false_to_pyarrow() -> None:
+ # AlwaysFalse binds no column, so no reference fixture is needed; it maps to the literal ``false`` expression.
+ assert repr(expression_to_pyarrow(AlwaysFalse())) == "<pyarrow.compute.Expression false>"