You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by bl...@apache.org on 2022/10/05 20:33:27 UTC
[iceberg] branch master updated: Python: Add rewriteNot (#5925)
This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 5b91984029 Python: Add rewriteNot (#5925)
5b91984029 is described below
commit 5b919840295dae181f14c7acd2e40ad9a91bf5b3
Author: Fokko Driesprong <fo...@apache.org>
AuthorDate: Wed Oct 5 22:33:18 2022 +0200
Python: Add rewriteNot (#5925)
---
python/pyiceberg/expressions/base.py | 29 ++++++++++
python/tests/expressions/test_expressions_base.py | 65 +++++++++++++++++++++++
2 files changed, 94 insertions(+)
diff --git a/python/pyiceberg/expressions/base.py b/python/pyiceberg/expressions/base.py
index 82b7a414cf..ea0873ede0 100644
--- a/python/pyiceberg/expressions/base.py
+++ b/python/pyiceberg/expressions/base.py
@@ -835,3 +835,32 @@ def _(expr: BoundLessThan, visitor: BoundBooleanExpressionVisitor[T]) -> T:
@visit_bound_predicate.register(BoundLessThanOrEqual)
def _(expr: BoundLessThanOrEqual, visitor: BoundBooleanExpressionVisitor[T]) -> T:
    """Forward a bound less-than-or-equal predicate to the visitor.

    Reads ``expr.term`` and ``expr.literal`` and hands them, by keyword, to
    ``visitor.visit_less_than_or_equal``; whatever that call returns is returned.
    """
    bound_term = expr.term
    bound_literal = expr.literal
    return visitor.visit_less_than_or_equal(term=bound_term, literal=bound_literal)
+
+
+def rewrite_not(expr: BooleanExpression) -> BooleanExpression:
+    """Rewrite the negations in a boolean expression.
+
+    The expression is walked with a ``_RewriteNotVisitor``, which replaces every
+    negation by ``~`` applied to the already-rewritten operand.
+
+    Args:
+        expr: The boolean expression to rewrite.
+
+    Returns:
+        The boolean expression produced by the visitor.
+    """
+    negation_rewriter = _RewriteNotVisitor()
+    return visit(expr, negation_rewriter)
+
+
+class _RewriteNotVisitor(BooleanExpressionVisitor[BooleanExpression]):
+    """Visitor whose result for every node is itself a boolean expression.
+
+    Constants, conjunctions, disjunctions and predicates are rebuilt or passed
+    through as they are; only a negation is treated specially, by returning the
+    inverse (``~``) of the result computed for its operand.
+    """
+
+    def visit_true(self) -> BooleanExpression:
+        """Return ``AlwaysTrue()`` for a constant-true node."""
+        return AlwaysTrue()
+
+    def visit_false(self) -> BooleanExpression:
+        """Return ``AlwaysFalse()`` for a constant-false node."""
+        return AlwaysFalse()
+
+    def visit_not(self, child_result: BooleanExpression) -> BooleanExpression:
+        """Return the inverse of the result obtained for the negated operand.
+
+        The inversion is delegated to the operand's own ``~`` operator, so the
+        shape of the returned expression is decided by ``child_result``.
+        """
+        inverted = ~child_result
+        return inverted
+
+    def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
+        """Combine the results of both operands into an ``And``."""
+        conjunction = And(left=left_result, right=right_result)
+        return conjunction
+
+    def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
+        """Combine the results of both operands into an ``Or``."""
+        disjunction = Or(left=left_result, right=right_result)
+        return disjunction
+
+    def visit_unbound_predicate(self, predicate) -> BooleanExpression:
+        """Return an unbound predicate unchanged."""
+        return predicate
+
+    def visit_bound_predicate(self, predicate) -> BooleanExpression:
+        """Return a bound predicate unchanged."""
+        return predicate
diff --git a/python/tests/expressions/test_expressions_base.py b/python/tests/expressions/test_expressions_base.py
index 94bf94b72a..02997eaa18 100644
--- a/python/tests/expressions/test_expressions_base.py
+++ b/python/tests/expressions/test_expressions_base.py
@@ -22,6 +22,7 @@ from typing import List, Set
import pytest
from pyiceberg.expressions import base
+from pyiceberg.expressions.base import rewrite_not
from pyiceberg.expressions.literals import (
Literal,
LongLiteral,
@@ -1151,3 +1152,67 @@ def test_bound_boolean_expression_visitor_raise_on_unbound_predicate():
with pytest.raises(TypeError) as exc_info:
base.visit(bound_expression, visitor=visitor)
assert "Not a bound predicate" in str(exc_info.value)
+
+
+def test_rewrite_not_equal_to():
+    """``Not(EqualTo)`` must be rewritten to ``NotEqualTo`` on the same term and literal."""
+    actual = rewrite_not(base.Not(base.EqualTo(base.Reference("x"), literal(34.56))))
+    expected = base.NotEqualTo(base.Reference("x"), literal(34.56))
+    assert actual == expected
+
+
+def test_rewrite_not_not_equal_to():
+    """``Not(NotEqualTo)`` must be rewritten to ``EqualTo`` on the same term and literal."""
+    actual = rewrite_not(base.Not(base.NotEqualTo(base.Reference("x"), literal(34.56))))
+    expected = base.EqualTo(base.Reference("x"), literal(34.56))
+    assert actual == expected
+
+
+def test_rewrite_not_in():
+    """``Not(In)`` must be rewritten to ``NotIn`` on the same term and literals."""
+    actual = rewrite_not(base.Not(base.In(base.Reference("x"), (literal(34.56),))))
+    expected = base.NotIn(base.Reference("x"), (literal(34.56),))
+    assert actual == expected
+
+
+def test_rewrite_and():
+    """A negated ``And`` must become an ``Or`` of the negated operands."""
+    negated_conjunction = base.Not(
+        base.And(
+            base.EqualTo(base.Reference("x"), literal(34.56)),
+            base.EqualTo(base.Reference("y"), literal(34.56)),
+        )
+    )
+    actual = rewrite_not(negated_conjunction)
+    expected = base.Or(
+        base.NotEqualTo(term=base.Reference(name="x"), literal=literal(34.56)),
+        base.NotEqualTo(term=base.Reference(name="y"), literal=literal(34.56)),
+    )
+    assert actual == expected
+
+
+def test_rewrite_or():
+    """A negated ``Or`` must become an ``And`` of the negated operands."""
+    negated_disjunction = base.Not(
+        base.Or(
+            base.EqualTo(base.Reference("x"), literal(34.56)),
+            base.EqualTo(base.Reference("y"), literal(34.56)),
+        )
+    )
+    actual = rewrite_not(negated_disjunction)
+    expected = base.And(
+        base.NotEqualTo(term=base.Reference(name="x"), literal=literal(34.56)),
+        base.NotEqualTo(term=base.Reference(name="y"), literal=literal(34.56)),
+    )
+    assert actual == expected
+
+
+def test_rewrite_always_false():
+    """Negating ``AlwaysFalse`` must rewrite to ``AlwaysTrue``."""
+    actual = rewrite_not(base.Not(base.AlwaysFalse()))
+    expected = base.AlwaysTrue()
+    assert actual == expected
+
+
+def test_rewrite_always_true():
+    """Negating ``AlwaysTrue`` must rewrite to ``AlwaysFalse``."""
+    actual = rewrite_not(base.Not(base.AlwaysTrue()))
+    expected = base.AlwaysFalse()
+    assert actual == expected
+
+
+def test_rewrite_bound():
+    """``rewrite_not`` on a bound ``IsNull`` must equal the matching ``BoundIsNull``."""
+    schema = Schema(NestedField(2, "a", IntegerType(), required=False), schema_id=1)
+    bound_predicate = base.IsNull(base.Reference("a")).bind(schema)
+    actual = rewrite_not(bound_predicate)
+    # Expected value: the predicate bound to field id 2 of the schema above,
+    # read through an accessor at position 0.
+    expected_term = base.BoundReference(
+        field=NestedField(field_id=2, name="a", field_type=IntegerType(), required=False),
+        accessor=Accessor(position=0, inner=None),
+    )
+    expected = base.BoundIsNull(term=expected_term)
+    assert actual == expected