You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by bl...@apache.org on 2022/10/05 20:33:27 UTC

[iceberg] branch master updated: Python: Add rewriteNot (#5925)

This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 5b91984029 Python: Add rewriteNot (#5925)
5b91984029 is described below

commit 5b919840295dae181f14c7acd2e40ad9a91bf5b3
Author: Fokko Driesprong <fo...@apache.org>
AuthorDate: Wed Oct 5 22:33:18 2022 +0200

    Python: Add rewriteNot (#5925)
---
 python/pyiceberg/expressions/base.py              | 29 ++++++++++
 python/tests/expressions/test_expressions_base.py | 65 +++++++++++++++++++++++
 2 files changed, 94 insertions(+)

diff --git a/python/pyiceberg/expressions/base.py b/python/pyiceberg/expressions/base.py
index 82b7a414cf..ea0873ede0 100644
--- a/python/pyiceberg/expressions/base.py
+++ b/python/pyiceberg/expressions/base.py
@@ -835,3 +835,32 @@ def _(expr: BoundLessThan, visitor: BoundBooleanExpressionVisitor[T]) -> T:
 @visit_bound_predicate.register(BoundLessThanOrEqual)
 def _(expr: BoundLessThanOrEqual, visitor: BoundBooleanExpressionVisitor[T]) -> T:
     return visitor.visit_less_than_or_equal(term=expr.term, literal=expr.literal)
+
+
+def rewrite_not(expr: BooleanExpression) -> BooleanExpression:
+    return visit(expr, visitor=_RewriteNotVisitor())  # rebuild bottom-up, replacing each Not(child) with ~child
+
+
+class _RewriteNotVisitor(BooleanExpressionVisitor[BooleanExpression]):
+    """Rebuilds an expression bottom-up, replacing every Not node with the inverse (~) of its rewritten child"""
+
+    def visit_unbound_predicate(self, predicate) -> BooleanExpression:
+        return predicate
+
+    def visit_bound_predicate(self, predicate) -> BooleanExpression:
+        return predicate
+
+    def visit_true(self) -> BooleanExpression:
+        return AlwaysTrue()
+
+    def visit_false(self) -> BooleanExpression:
+        return AlwaysFalse()
+
+    def visit_and(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
+        return And(left_result, right_result)
+
+    def visit_or(self, left_result: BooleanExpression, right_result: BooleanExpression) -> BooleanExpression:
+        return Or(left_result, right_result)
+
+    def visit_not(self, child_result: BooleanExpression) -> BooleanExpression:
+        return ~child_result
diff --git a/python/tests/expressions/test_expressions_base.py b/python/tests/expressions/test_expressions_base.py
index 94bf94b72a..02997eaa18 100644
--- a/python/tests/expressions/test_expressions_base.py
+++ b/python/tests/expressions/test_expressions_base.py
@@ -22,6 +22,7 @@ from typing import List, Set
 import pytest
 
 from pyiceberg.expressions import base
+from pyiceberg.expressions.base import rewrite_not
 from pyiceberg.expressions.literals import (
     Literal,
     LongLiteral,
@@ -1151,3 +1152,67 @@ def test_bound_boolean_expression_visitor_raise_on_unbound_predicate():
     with pytest.raises(TypeError) as exc_info:
         base.visit(bound_expression, visitor=visitor)
     assert "Not a bound predicate" in str(exc_info.value)
+
+
+def test_rewrite_not_equal_to():
+    negated = base.Not(base.EqualTo(base.Reference("x"), literal(34.56)))
+    expected = base.NotEqualTo(base.Reference("x"), literal(34.56))
+    assert rewrite_not(negated) == expected
+
+
+def test_rewrite_not_not_equal_to():
+    negated = base.Not(base.NotEqualTo(base.Reference("x"), literal(34.56)))
+    expected = base.EqualTo(base.Reference("x"), literal(34.56))
+    assert rewrite_not(negated) == expected
+
+
+def test_rewrite_not_in():
+    negated = base.Not(base.In(base.Reference("x"), (literal(34.56),)))
+    expected = base.NotIn(base.Reference("x"), (literal(34.56),))
+    assert rewrite_not(negated) == expected
+
+
+def test_rewrite_not_and():
+    # De Morgan: NOT(x = 34.56 AND y = 34.56) must become (x != 34.56 OR y != 34.56)
+    expression = base.Not(
+        base.And(
+            base.EqualTo(base.Reference("x"), literal(34.56)),
+            base.EqualTo(base.Reference("y"), literal(34.56)),
+        )
+    )
+    assert rewrite_not(expression) == base.Or(
+        base.NotEqualTo(term=base.Reference(name="x"), literal=literal(34.56)),
+        base.NotEqualTo(term=base.Reference(name="y"), literal=literal(34.56)),
+    )
+
+
+def test_rewrite_not_or():
+    # De Morgan: NOT(x = 34.56 OR y = 34.56) must become (x != 34.56 AND y != 34.56)
+    expression = base.Not(
+        base.Or(
+            base.EqualTo(base.Reference("x"), literal(34.56)),
+            base.EqualTo(base.Reference("y"), literal(34.56)),
+        )
+    )
+    assert rewrite_not(expression) == base.And(
+        base.NotEqualTo(term=base.Reference(name="x"), literal=literal(34.56)),
+        base.NotEqualTo(term=base.Reference(name="y"), literal=literal(34.56)),
+    )
+
+
+def test_rewrite_not_always_false():
+    assert rewrite_not(base.Not(base.AlwaysFalse())) == base.AlwaysTrue()  # NOT FALSE collapses to TRUE
+
+
+def test_rewrite_not_always_true():
+    assert rewrite_not(base.Not(base.AlwaysTrue())) == base.AlwaysFalse()  # NOT TRUE collapses to FALSE
+
+
+def test_rewrite_bound():
+    schema = Schema(NestedField(2, "a", IntegerType(), required=False), schema_id=1)
+    bound_is_null = base.IsNull(base.Reference("a")).bind(schema)
+    expected_reference = base.BoundReference(
+        field=NestedField(field_id=2, name="a", field_type=IntegerType(), required=False),
+        accessor=Accessor(position=0, inner=None),
+    )
+    assert rewrite_not(bound_is_null) == base.BoundIsNull(term=expected_reference)