You are viewing a plain-text version of this content. The canonical version is on the original mailing-list archive page; its hyperlink is not preserved in this plain-text copy.
Posted to commits@iceberg.apache.org by dw...@apache.org on 2022/08/01 18:49:20 UTC

[iceberg] branch master updated: Python: Refactor expression hierarchy (#5389)

This is an automated email from the ASF dual-hosted git repository.

dweeks pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 5f2ce6e8fc Python: Refactor expression hierarchy (#5389)
5f2ce6e8fc is described below

commit 5f2ce6e8fc86be0f785fb0bb6656f2afac0d1d80
Author: Ryan Blue <bl...@apache.org>
AuthorDate: Mon Aug 1 11:49:15 2022 -0700

    Python: Refactor expression hierarchy (#5389)
    
    * Convert And, Or, and Not to dataclasses.
    
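    * Refactor base expression types: make Term generic over its value
      type, make Bound non-generic and move eval() to an abstract method on
      BoundTerm, reduce Unbound to a single type parameter (with bind()
      defaulting case_sensitive to True), remove BaseReference, and drop the
      redundant In registration from the visit() dispatcher.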
---
 python/pyiceberg/expressions/base.py              | 139 ++++++++--------------
 python/tests/expressions/test_expressions_base.py |  12 +-
 2 files changed, 55 insertions(+), 96 deletions(-)

diff --git a/python/pyiceberg/expressions/base.py b/python/pyiceberg/expressions/base.py
index f91093b17a..f4584fdb4c 100644
--- a/python/pyiceberg/expressions/base.py
+++ b/python/pyiceberg/expressions/base.py
@@ -32,50 +32,43 @@ B = TypeVar("B")
 
 
 class BooleanExpression(ABC):
-    """Represents a boolean expression tree."""
+    """An expression that evaluates to a boolean"""
 
     @abstractmethod
     def __invert__(self) -> BooleanExpression:
         """Transform the Expression into its negated version."""
 
 
-class Bound(Generic[T], ABC):
-    """Represents a bound value expression."""
+class Term(Generic[T], ABC):
+    """A simple expression that evaluates to a value"""
 
-    def eval(self, struct: StructProtocol):  # pylint: disable=W0613
-        ...  # pragma: no cover
+
+class Bound(ABC):
+    """Represents a bound value expression"""
 
 
-class Unbound(Generic[T, B], ABC):
-    """Represents an unbound expression node."""
+class Unbound(Generic[B], ABC):
+    """Represents an unbound value expression"""
 
     @abstractmethod
-    def bind(self, schema: Schema, case_sensitive: bool) -> B:
+    def bind(self, schema: Schema, case_sensitive: bool = True) -> B:
         ...  # pragma: no cover
 
 
-class Term(ABC):
-    """An expression that evaluates to a value."""
-
-
-class BaseReference(Generic[T], Term, ABC):
-    """Represents a variable reference in an expression."""
-
-
-class BoundTerm(Bound[T], Term):
-    """Represents a bound term."""
+class BoundTerm(Term[T], Bound, ABC):
+    """Represents a bound term"""
 
     @abstractmethod
     def ref(self) -> BoundReference[T]:
         ...
 
-
-class UnboundTerm(Unbound[T, BoundTerm[T]], Term):
-    """Represents an unbound term."""
+    @abstractmethod
+    def eval(self, struct: StructProtocol):  # pylint: disable=W0613
+        ...  # pragma: no cover
 
 
 @dataclass(frozen=True)
-class BoundReference(BoundTerm[T], BaseReference[T]):
+class BoundReference(BoundTerm[T]):
     """A reference bound to a field in a schema
 
     Args:
@@ -88,6 +81,7 @@ class BoundReference(BoundTerm[T], BaseReference[T]):
 
     def eval(self, struct: StructProtocol) -> T:
         """Returns the value at the referenced field's position in an object that abides by the StructProtocol
+
         Args:
             struct (StructProtocol): A row object that abides by the StructProtocol and returns values given a position
         Returns:
@@ -99,8 +93,12 @@ class BoundReference(BoundTerm[T], BaseReference[T]):
         return self
 
 
+class UnboundTerm(Term[T], Unbound[BoundTerm[T]], ABC):
+    """Represents an unbound term."""
+
+
 @dataclass(frozen=True)
-class Reference(UnboundTerm[T], BaseReference[T]):
+class Reference(UnboundTerm[T]):
     """A reference not yet bound to a field in a schema
 
     Args:
@@ -112,7 +110,7 @@ class Reference(UnboundTerm[T], BaseReference[T]):
 
     name: str
 
-    def bind(self, schema: Schema, case_sensitive: bool) -> BoundReference[T]:
+    def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundReference[T]:
         """Bind the reference to an Iceberg schema
 
         Args:
@@ -125,22 +123,24 @@ class Reference(UnboundTerm[T], BaseReference[T]):
         Returns:
             BoundReference: A reference bound to the specific field in the Iceberg schema
         """
-        field = schema.find_field(name_or_id=self.name, case_sensitive=case_sensitive)  # pylint: disable=redefined-outer-name
-
+        field = schema.find_field(name_or_id=self.name, case_sensitive=case_sensitive)
         if not field:
             raise ValueError(f"Cannot find field '{self.name}' in schema: {schema}")
 
         accessor = schema.accessor_for_field(field.field_id)
-
         if not accessor:
             raise ValueError(f"Cannot find accessor for field '{self.name}' in schema: {schema}")
 
         return BoundReference(field=field, accessor=accessor)
 
 
+@dataclass(frozen=True, init=False)
 class And(BooleanExpression):
     """AND operation expression - logical conjunction"""
 
+    left: BooleanExpression
+    right: BooleanExpression
+
     def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression):
         if rest:
             return reduce(And, (left, right, *rest))
@@ -150,35 +150,23 @@ class And(BooleanExpression):
             return right
         elif right is AlwaysTrue():
             return left
-        self = super().__new__(cls)
-        self._left = left  # type: ignore
-        self._right = right  # type: ignore
-        return self
-
-    @property
-    def left(self) -> BooleanExpression:
-        return self._left  # type: ignore
-
-    @property
-    def right(self) -> BooleanExpression:
-        return self._right  # type: ignore
-
-    def __eq__(self, other) -> bool:
-        return id(self) == id(other) or (isinstance(other, And) and self.left == other.left and self.right == other.right)
+        else:
+            result = super().__new__(cls)
+            object.__setattr__(result, "left", left)
+            object.__setattr__(result, "right", right)
+            return result
 
     def __invert__(self) -> Or:
         return Or(~self.left, ~self.right)
 
-    def __repr__(self) -> str:
-        return f"And({repr(self.left)}, {repr(self.right)})"
-
-    def __str__(self) -> str:
-        return f"And({str(self.left)}, {str(self.right)})"
-
 
+@dataclass(frozen=True, init=False)
 class Or(BooleanExpression):
     """OR operation expression - logical disjunction"""
 
+    left: BooleanExpression
+    right: BooleanExpression
+
     def __new__(cls, left: BooleanExpression, right: BooleanExpression, *rest: BooleanExpression):
         if rest:
             return reduce(Or, (left, right, *rest))
@@ -188,35 +176,22 @@ class Or(BooleanExpression):
             return right
         elif right is AlwaysFalse():
             return left
-        self = super().__new__(cls)
-        self._left = left  # type: ignore
-        self._right = right  # type: ignore
-        return self
-
-    @property
-    def left(self) -> BooleanExpression:
-        return self._left  # type: ignore
-
-    @property
-    def right(self) -> BooleanExpression:
-        return self._right  # type: ignore
-
-    def __eq__(self, other) -> bool:
-        return id(self) == id(other) or (isinstance(other, Or) and self.left == other.left and self.right == other.right)
+        else:
+            result = super().__new__(cls)
+            object.__setattr__(result, "left", left)
+            object.__setattr__(result, "right", right)
+            return result
 
     def __invert__(self) -> And:
         return And(~self.left, ~self.right)
 
-    def __repr__(self) -> str:
-        return f"Or({repr(self.left)}, {repr(self.right)})"
-
-    def __str__(self) -> str:
-        return f"Or({str(self.left)}, {str(self.right)})"
-
 
+@dataclass(frozen=True, init=False)
 class Not(BooleanExpression):
     """NOT operation expression - logical negation"""
 
+    child: BooleanExpression
+
     def __new__(cls, child: BooleanExpression):
         if child is AlwaysTrue():
             return AlwaysFalse()
@@ -224,23 +199,13 @@ class Not(BooleanExpression):
             return AlwaysTrue()
         elif isinstance(child, Not):
             return child.child
-        return super().__new__(cls)
-
-    def __init__(self, child):
-        self.child = child
-
-    def __eq__(self, other) -> bool:
-        return id(self) == id(other) or (isinstance(other, Not) and self.child == other.child)
+        result = super().__new__(cls)
+        object.__setattr__(result, "child", child)
+        return result
 
     def __invert__(self) -> BooleanExpression:
         return self.child
 
-    def __repr__(self) -> str:
-        return f"Not({repr(self.child)})"
-
-    def __str__(self) -> str:
-        return f"Not({str(self.child)})"
-
 
 @dataclass(frozen=True)
 class AlwaysTrue(BooleanExpression, Singleton):
@@ -259,7 +224,7 @@ class AlwaysFalse(BooleanExpression, Singleton):
 
 
 @dataclass(frozen=True)
-class BoundPredicate(Bound[T], BooleanExpression):
+class BoundPredicate(Generic[T], Bound, BooleanExpression):
     term: BoundTerm[T]
 
     def __invert__(self) -> BoundPredicate[T]:
@@ -267,7 +232,7 @@ class BoundPredicate(Bound[T], BooleanExpression):
 
 
 @dataclass(frozen=True)
-class UnboundPredicate(Unbound[T, BooleanExpression], BooleanExpression):
+class UnboundPredicate(Generic[T], Unbound[BooleanExpression], BooleanExpression):
     as_bound: ClassVar[type]
     term: UnboundTerm[T]
 
@@ -661,12 +626,6 @@ def _(obj: And, visitor: BooleanExpressionVisitor[T]) -> T:
     return visitor.visit_and(left_result=left_result, right_result=right_result)
 
 
-@visit.register(In)
-def _(obj: In, visitor: BooleanExpressionVisitor[T]) -> T:
-    """Visit an In boolean expression with a concrete BooleanExpressionVisitor"""
-    return visitor.visit_unbound_predicate(predicate=obj)
-
-
 @visit.register(UnboundPredicate)
 def _(obj: UnboundPredicate, visitor: BooleanExpressionVisitor[T]) -> T:
     """Visit an In boolean expression with a concrete BooleanExpressionVisitor"""
diff --git a/python/tests/expressions/test_expressions_base.py b/python/tests/expressions/test_expressions_base.py
index ba2850133b..cf74298296 100644
--- a/python/tests/expressions/test_expressions_base.py
+++ b/python/tests/expressions/test_expressions_base.py
@@ -120,13 +120,13 @@ def _(obj: ExpressionB, visitor: BooleanExpressionVisitor) -> List:
     [
         (
             base.And(ExpressionA(), ExpressionB()),
-            "And(ExpressionA(), ExpressionB())",
+            "And(left=ExpressionA(), right=ExpressionB())",
         ),
         (
             base.Or(ExpressionA(), ExpressionB()),
-            "Or(ExpressionA(), ExpressionB())",
+            "Or(left=ExpressionA(), right=ExpressionB())",
         ),
-        (base.Not(ExpressionA()), "Not(ExpressionA())"),
+        (base.Not(ExpressionA()), "Not(child=ExpressionA())"),
     ],
 )
 def test_reprs(op, rep):
@@ -208,9 +208,9 @@ def test_notnan_bind_nonfloat():
 @pytest.mark.parametrize(
     "op, string",
     [
-        (base.And(ExpressionA(), ExpressionB()), "And(testexpra, testexprb)"),
-        (base.Or(ExpressionA(), ExpressionB()), "Or(testexpra, testexprb)"),
-        (base.Not(ExpressionA()), "Not(testexpra)"),
+        (base.And(ExpressionA(), ExpressionB()), "And(left=ExpressionA(), right=ExpressionB())"),
+        (base.Or(ExpressionA(), ExpressionB()), "Or(left=ExpressionA(), right=ExpressionB())"),
+        (base.Not(ExpressionA()), "Not(child=ExpressionA())"),
     ],
 )
 def test_strs(op, string):