You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by bl...@apache.org on 2022/11/28 21:05:16 UTC

[iceberg] branch master updated: Python: Add boolean expression parser (#6259)

This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 51b5f08b84 Python: Add boolean expression parser (#6259)
51b5f08b84 is described below

commit 51b5f08b842b5319ac490174e69feae66e3c912f
Author: Ryan Blue <bl...@apache.org>
AuthorDate: Mon Nov 28 13:05:09 2022 -0800

    Python: Add boolean expression parser (#6259)
    
    Co-authored-by: Fokko Driesprong <fo...@apache.org>
---
 python/pyiceberg/expressions/literals.py  |  36 +++++
 python/pyiceberg/expressions/parser.py    | 237 ++++++++++++++++++++++++++++++
 python/pyproject.toml                     |   4 +
 python/tests/expressions/test_literals.py |   4 -
 python/tests/expressions/test_parser.py   | 151 +++++++++++++++++++
 5 files changed, 428 insertions(+), 4 deletions(-)

diff --git a/python/pyiceberg/expressions/literals.py b/python/pyiceberg/expressions/literals.py
index c59c6bcf8d..44ddc1331e 100644
--- a/python/pyiceberg/expressions/literals.py
+++ b/python/pyiceberg/expressions/literals.py
@@ -86,6 +86,8 @@ class Literal(Generic[L], ABC):
         return hash(self.value)
 
     def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, Literal):
+            return False
         return self.value == other.value
 
     def __ne__(self, other) -> bool:
@@ -401,6 +403,40 @@ class DecimalLiteral(Literal[Decimal]):
             return self
         raise ValueError(f"Could not convert {self.value} into a {type_var}")
 
+    @to.register(IntegerType)
+    def _(self, _: IntegerType) -> Literal[int]:
+        value_int = int(self.value.to_integral_value())
+        if value_int > IntegerType.max:
+            return IntAboveMax()
+        elif value_int < IntegerType.min:
+            return IntBelowMin()
+        else:
+            return LongLiteral(value_int)
+
+    @to.register(LongType)
+    def _(self, _: LongType) -> Literal[int]:
+        value_int = int(self.value.to_integral_value())
+        if value_int > LongType.max:
+            return IntAboveMax()
+        elif value_int < LongType.min:
+            return IntBelowMin()
+        else:
+            return LongLiteral(value_int)
+
+    @to.register(FloatType)
+    def _(self, _: FloatType):
+        value_float = float(self.value)
+        if value_float > FloatType.max:
+            return FloatAboveMax()
+        elif value_float < FloatType.min:
+            return FloatBelowMin()
+        else:
+            return FloatLiteral(value_float)
+
+    @to.register(DoubleType)
+    def _(self, _: DoubleType):
+        return DoubleLiteral(float(self.value))  # Decimal -> Python float; no range check, unlike the FloatType case
+
 
 class StringLiteral(Literal[str]):
     def __init__(self, value: str):
diff --git a/python/pyiceberg/expressions/parser.py b/python/pyiceberg/expressions/parser.py
new file mode 100644
index 0000000000..0a0173da41
--- /dev/null
+++ b/python/pyiceberg/expressions/parser.py
@@ -0,0 +1,237 @@
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing,
+#  software distributed under the License is distributed on an
+#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#  KIND, either express or implied.  See the License for the
+#  specific language governing permissions and limitations
+#  under the License.
+from decimal import Decimal
+
+from pyparsing import (
+    CaselessKeyword,
+    Group,
+    ParserElement,
+    ParseResults,
+    Suppress,
+    Word,
+    alphanums,
+    alphas,
+    delimited_list,
+    infix_notation,
+    one_of,
+    opAssoc,
+    sgl_quoted_string,
+)
+from pyparsing.common import pyparsing_common as common
+
+from pyiceberg.expressions import (
+    AlwaysFalse,
+    AlwaysTrue,
+    And,
+    BooleanExpression,
+    EqualTo,
+    GreaterThan,
+    GreaterThanOrEqual,
+    In,
+    IsNaN,
+    IsNull,
+    LessThan,
+    LessThanOrEqual,
+    Not,
+    NotEqualTo,
+    NotIn,
+    NotNaN,
+    NotNull,
+    Or,
+    Reference,
+)
+from pyiceberg.expressions.literals import (
+    DecimalLiteral,
+    Literal,
+    LongLiteral,
+    StringLiteral,
+)
+from pyiceberg.typedef import L
+
+ParserElement.enablePackrat()
+
+AND = CaselessKeyword("and")
+OR = CaselessKeyword("or")
+NOT = CaselessKeyword("not")
+IS = CaselessKeyword("is")
+IN = CaselessKeyword("in")
+NULL = CaselessKeyword("null")
+NAN = CaselessKeyword("nan")
+
+identifier = Word(alphas, alphanums + "_$").set_results_name("identifier")
+column = delimited_list(identifier, delim=".", combine=True).set_results_name("column")
+
+
+@column.set_parse_action
+def _(result: ParseResults) -> Reference:
+    return Reference(result.column[0])
+
+
+boolean = one_of(["true", "false"], caseless=True).set_results_name("boolean")
+string = sgl_quoted_string.set_results_name("raw_quoted_string")
+decimal = common.real().set_results_name("decimal")
+integer = common.signed_integer().set_results_name("integer")
+literal = Group(string | decimal | integer).set_results_name("literal")
+literal_set = Group(delimited_list(string) | delimited_list(decimal) | delimited_list(integer)).set_results_name("literal_set")
+
+
+@boolean.set_parse_action
+def _(result: ParseResults) -> BooleanExpression:
+    if "true" == result.boolean.lower():
+        return AlwaysTrue()
+    else:
+        return AlwaysFalse()
+
+
+@string.set_parse_action
+def _(result: ParseResults) -> Literal[str]:
+    return StringLiteral(result.raw_quoted_string[1:-1].replace("''", "'"))
+
+
+@decimal.set_parse_action
+def _(result: ParseResults) -> Literal[Decimal]:
+    return DecimalLiteral(Decimal(result.decimal))
+
+
+@integer.set_parse_action
+def _(result: ParseResults) -> Literal[int]:
+    return LongLiteral(int(result.integer))
+
+
+@literal.set_parse_action
+def _(result: ParseResults) -> Literal[L]:
+    return result[0][0]
+
+
+@literal_set.set_parse_action
+def _(result: ParseResults) -> ParseResults:  # unwraps the Group so callers get the parsed literals of the set
+    return result[0]
+
+
+comparison_op = one_of(["<", "<=", ">", ">=", "=", "==", "!=", "<>"], caseless=True).set_results_name("op")
+left_ref = column + comparison_op + literal
+right_ref = literal + comparison_op + column
+comparison = left_ref | right_ref
+
+
+@left_ref.set_parse_action
+def _(result: ParseResults) -> BooleanExpression:  # "<column> <op> <literal>": the operator maps directly to its predicate
+    if result.op == "<":
+        return LessThan(result.column, result.literal)
+    elif result.op == "<=":
+        return LessThanOrEqual(result.column, result.literal)
+    elif result.op == ">":
+        return GreaterThan(result.column, result.literal)
+    elif result.op == ">=":
+        return GreaterThanOrEqual(result.column, result.literal)
+    elif result.op in ("=", "=="):
+        return EqualTo(result.column, result.literal)
+    elif result.op in ("!=", "<>"):
+        return NotEqualTo(result.column, result.literal)
+    raise ValueError(f"Unsupported operation type: {result.op}")
+
+
+@right_ref.set_parse_action
+def _(result: ParseResults) -> BooleanExpression:  # "<literal> <op> <column>": operands are swapped, so < and > are mirrored
+    if result.op == "<":
+        return GreaterThan(result.column, result.literal)
+    elif result.op == "<=":
+        return GreaterThanOrEqual(result.column, result.literal)
+    elif result.op == ">":
+        return LessThan(result.column, result.literal)
+    elif result.op == ">=":
+        return LessThanOrEqual(result.column, result.literal)
+    elif result.op in ("=", "=="):
+        return EqualTo(result.column, result.literal)
+    elif result.op in ("!=", "<>"):
+        return NotEqualTo(result.column, result.literal)
+    raise ValueError(f"Unsupported operation type: {result.op}")
+
+
+is_null = column + IS + NULL
+not_null = column + IS + NOT + NULL
+null_check = is_null | not_null
+
+
+@is_null.set_parse_action
+def _(result: ParseResults) -> BooleanExpression:
+    return IsNull(result.column)
+
+
+@not_null.set_parse_action
+def _(result: ParseResults) -> BooleanExpression:
+    return NotNull(result.column)
+
+
+is_nan = column + IS + NAN
+not_nan = column + IS + NOT + NAN
+nan_check = is_nan | not_nan
+
+
+@is_nan.set_parse_action
+def _(result: ParseResults) -> BooleanExpression:
+    return IsNaN(result.column)
+
+
+@not_nan.set_parse_action
+def _(result: ParseResults) -> BooleanExpression:
+    return NotNaN(result.column)
+
+
+is_in = column + IN + "(" + literal_set + ")"
+not_in = column + NOT + IN + "(" + literal_set + ")"
+in_check = is_in | not_in
+
+
+@is_in.set_parse_action
+def _(result: ParseResults) -> BooleanExpression:
+    return In(result.column, result.literal_set)
+
+
+@not_in.set_parse_action
+def _(result: ParseResults) -> BooleanExpression:
+    return NotIn(result.column, result.literal_set)
+
+
+predicate = (comparison | in_check | null_check | nan_check | boolean).set_results_name("predicate")
+
+
+def handle_not(result: ParseResults) -> Not:
+    return Not(result[0][0])  # the NOT keyword is suppressed in the grammar, so the group's first item is the operand
+
+
+def handle_and(result: ParseResults) -> And:
+    return And(*result[0])  # a chained "a and b and c" arrives as one group holding every operand; pass them all
+
+
+def handle_or(result: ParseResults) -> Or:
+    return Or(*result[0])  # a chained "a or b or c" arrives as one group holding every operand; pass them all
+
+
+boolean_expression = infix_notation(
+    predicate,
+    [
+        (Suppress(NOT), 1, opAssoc.RIGHT, handle_not),
+        (Suppress(AND), 2, opAssoc.LEFT, handle_and),
+        (Suppress(OR), 2, opAssoc.LEFT, handle_or),
+    ],
+).set_name("expr")
+
+
+def parse(expr: str) -> BooleanExpression:
+    """Parses a boolean expression"""
+    return boolean_expression.parse_string(expr)[0]
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 0312018be8..fd8c3fcb52 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -209,5 +209,9 @@ ignore_missing_imports = true
 module = "duckdb.*"
 ignore_missing_imports = true
 
+[[tool.mypy.overrides]]
+module = "pyparsing.*"
+ignore_missing_imports = true
+
 [tool.coverage.run]
 source = ['pyiceberg/']
diff --git a/python/tests/expressions/test_literals.py b/python/tests/expressions/test_literals.py
index efcacc4574..ff9ae5629e 100644
--- a/python/tests/expressions/test_literals.py
+++ b/python/tests/expressions/test_literals.py
@@ -667,10 +667,6 @@ def test_invalid_decimal_conversions():
         literal(Decimal("34.11")),
         [
             BooleanType(),
-            IntegerType(),
-            LongType(),
-            FloatType(),
-            DoubleType(),
             DateType(),
             TimeType(),
             TimestampType(),
diff --git a/python/tests/expressions/test_parser.py b/python/tests/expressions/test_parser.py
new file mode 100644
index 0000000000..47704be4c3
--- /dev/null
+++ b/python/tests/expressions/test_parser.py
@@ -0,0 +1,151 @@
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing,
+#  software distributed under the License is distributed on an
+#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+#  KIND, either express or implied.  See the License for the
+#  specific language governing permissions and limitations
+#  under the License.
+import pytest
+from pyparsing import ParseException
+
+from pyiceberg.expressions import (
+    AlwaysFalse,
+    AlwaysTrue,
+    And,
+    EqualTo,
+    GreaterThan,
+    GreaterThanOrEqual,
+    In,
+    IsNaN,
+    IsNull,
+    LessThan,
+    LessThanOrEqual,
+    Not,
+    NotEqualTo,
+    NotIn,
+    NotNaN,
+    NotNull,
+    Or,
+    parser,
+)
+
+
+def test_true():
+    assert AlwaysTrue() == parser.parse("true")
+
+
+def test_false():
+    assert AlwaysFalse() == parser.parse("false")
+
+
+def test_is_null():
+    assert IsNull("x") == parser.parse("x is null")
+    assert IsNull("x") == parser.parse("x IS NULL")
+
+
+def test_not_null():
+    assert NotNull("x") == parser.parse("x is not null")
+    assert NotNull("x") == parser.parse("x IS NOT NULL")
+
+
+def test_is_nan():
+    assert IsNaN("x") == parser.parse("x is nan")
+    assert IsNaN("x") == parser.parse("x IS NAN")
+
+
+def test_not_nan():
+    assert NotNaN("x") == parser.parse("x is not nan")
+    assert NotNaN("x") == parser.parse("x IS NOT NaN")
+
+
+def test_less_than():
+    assert LessThan("x", 5) == parser.parse("x < 5")
+    assert LessThan("x", "a") == parser.parse("'a' > x")
+
+
+def test_less_than_or_equal():
+    assert LessThanOrEqual("x", 5) == parser.parse("x <= 5")
+    assert LessThanOrEqual("x", "a") == parser.parse("'a' >= x")
+
+
+def test_greater_than():
+    assert GreaterThan("x", 5) == parser.parse("x > 5")
+    assert GreaterThan("x", "a") == parser.parse("'a' < x")
+
+
+def test_greater_than_or_equal():
+    assert GreaterThanOrEqual("x", 5) == parser.parse("x >= 5")
+    assert GreaterThanOrEqual("x", "a") == parser.parse("'a' <= x")
+
+
+def test_equal_to():
+    assert EqualTo("x", 5) == parser.parse("x = 5")
+    assert EqualTo("x", "a") == parser.parse("'a' = x")
+    assert EqualTo("x", "a") == parser.parse("x == 'a'")
+    assert EqualTo("x", 5) == parser.parse("5 == x")
+
+
+def test_not_equal_to():
+    assert NotEqualTo("x", 5) == parser.parse("x != 5")
+    assert NotEqualTo("x", "a") == parser.parse("'a' != x")
+    assert NotEqualTo("x", "a") == parser.parse("x <> 'a'")
+    assert NotEqualTo("x", 5) == parser.parse("5 <> x")
+
+
+def test_in():
+    assert In("x", {5, 6, 7}) == parser.parse("x in (5, 6, 7)")
+    assert In("x", {"a", "b", "c"}) == parser.parse("x IN ('a', 'b', 'c')")
+
+
+def test_in_different_types():
+    with pytest.raises(ParseException):
+        parser.parse("x in (5, 'a')")
+
+
+def test_not_in():
+    assert NotIn("x", {5, 6, 7}) == parser.parse("x not in (5, 6, 7)")
+    assert NotIn("x", {"a", "b", "c"}) == parser.parse("x NOT IN ('a', 'b', 'c')")
+
+
+def test_not_in_different_types():
+    with pytest.raises(ParseException):
+        parser.parse("x not in (5, 'a')")
+
+
+def test_simple_and():
+    assert And(GreaterThanOrEqual("x", 5), LessThan("x", 10)) == parser.parse("5 <= x and x < 10")
+
+
+def test_and_with_not():
+    assert And(Not(GreaterThanOrEqual("x", 5)), LessThan("x", 10)) == parser.parse("not 5 <= x and x < 10")
+    assert And(GreaterThanOrEqual("x", 5), Not(LessThan("x", 10))) == parser.parse("5 <= x and not x < 10")
+
+
+def test_or_with_not():
+    assert Or(Not(LessThan("x", 5)), GreaterThan("x", 10)) == parser.parse("not x < 5 or 10 < x")
+    assert Or(LessThan("x", 5), Not(GreaterThan("x", 10))) == parser.parse("x < 5 or not 10 < x")
+
+
+def test_simple_or():
+    assert Or(LessThan("x", 5), GreaterThan("x", 10)) == parser.parse("x < 5 or 10 < x")
+
+
+def test_and_or_without_parens():
+    assert Or(And(NotNull("x"), LessThan("x", 5)), GreaterThan("x", 10)) == parser.parse("x is not null and x < 5 or 10 < x")
+    assert Or(IsNull("x"), And(GreaterThanOrEqual("x", 5), LessThan("x", 10))) == parser.parse("x is null or 5 <= x and x < 10")
+
+
+def test_and_or_with_parens():
+    assert And(NotNull("x"), Or(LessThan("x", 5), GreaterThan("x", 10))) == parser.parse("x is not null and (x < 5 or 10 < x)")
+    assert Or(IsNull("x"), And(GreaterThanOrEqual("x", 5), Not(LessThan("x", 10)))) == parser.parse(
+        "(x is null) or (5 <= x) and not(x < 10)"
+    )