You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/07 19:48:40 UTC

[arrow-datafusion] branch master updated: Linearize binary expressions to reduce proto tree complexity (#4115)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 6b7129483 Linearize binary expressions to reduce proto tree complexity (#4115)
6b7129483 is described below

commit 6b712948362b078f1226623ca46ab96d8bd2b768
Author: Batuhan Taskaya <is...@gmail.com>
AuthorDate: Mon Nov 7 22:48:22 2022 +0300

    Linearize binary expressions to reduce proto tree complexity (#4115)
---
 datafusion/proto/proto/datafusion.proto  |  6 +-
 datafusion/proto/src/bytes/mod.rs        | 95 +++++++++++++++++++++++++++++++-
 datafusion/proto/src/from_proto.rs       | 28 ++++++++--
 datafusion/proto/src/generated/pbjson.rs | 41 ++++----------
 datafusion/proto/src/generated/prost.rs  | 11 ++--
 datafusion/proto/src/to_proto.rs         | 33 +++++++++--
 6 files changed, 168 insertions(+), 46 deletions(-)

diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 3cb9763d3..e40734538 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -409,8 +409,10 @@ message AliasNode {
 }
 
 message BinaryExprNode {
-  LogicalExprNode l = 1;
-  LogicalExprNode r = 2;
+  // Operands ordered from the left-innermost to the right-outermost expression,
+  // left-folded with 'op': ((operands[0] op operands[1]) op operands[2]) op ...
+  repeated LogicalExprNode operands = 1;
+  reserved 2; // formerly `LogicalExprNode r = 2`; never reuse this tag
   string op = 3;
 }
 
diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs
index 3677ea8af..1ffe8ab7a 100644
--- a/datafusion/proto/src/bytes/mod.rs
+++ b/datafusion/proto/src/bytes/mod.rs
@@ -321,6 +321,98 @@ mod test {
         Expr::from_bytes(&bytes).unwrap();
     }
 
+    fn roundtrip_expr(expr: &Expr) -> Expr {
+        let bytes = expr.to_bytes().unwrap();
+        Expr::from_bytes(&bytes).unwrap()
+    }
+
+    #[test]
+    fn exact_roundtrip_linearized_binary_expr() {
+        // (((A AND B) AND C) AND D)
+        let expr_ordered = col("A").and(col("B")).and(col("C")).and(col("D"));
+        assert_eq!(expr_ordered, roundtrip_expr(&expr_ordered));
+
+        // Ensure that no other variation becomes equal
+        let other_variants = vec![
+            // (((B AND A) AND C) AND D)
+            col("B").and(col("A")).and(col("C")).and(col("D")),
+            // (((A AND C) AND B) AND D)
+            col("A").and(col("C")).and(col("B")).and(col("D")),
+            // (((A AND B) AND D) AND C)
+            col("A").and(col("B")).and(col("D")).and(col("C")),
+            // (A AND (B AND (C AND D)))
+            col("A").and(col("B").and(col("C").and(col("D")))),
+        ];
+        for case in other_variants {
+            // Each variant is still equal to itself
+            assert_eq!(case, roundtrip_expr(&case));
+
+            // But none of them is equal to the original
+            assert_ne!(expr_ordered, roundtrip_expr(&case));
+            assert_ne!(roundtrip_expr(&expr_ordered), roundtrip_expr(&case));
+        }
+    }
+
+    #[test]
+    fn roundtrip_deeply_nested_binary_expr() {
+        // We need more stack space so this doesn't overflow in dev builds
+        std::thread::Builder::new()
+            .stack_size(10_000_000)
+            .spawn(|| {
+                let n = 100;
+                // a < 5
+                let basic_expr = col("a").lt(lit(5i32));
+                // (a < 5) OR (a < 5) OR (a < 5) OR ...
+                let or_chain = (0..n)
+                    .fold(basic_expr.clone(), |expr, _| expr.or(basic_expr.clone()));
+                // ((or_chain AND or_chain) AND or_chain) AND ...
+                let expr =
+                    (0..n).fold(or_chain.clone(), |expr, _| expr.and(or_chain.clone()));
+
+                // Should work fine.
+                let bytes = expr.to_bytes().unwrap();
+
+                let decoded_expr = Expr::from_bytes(&bytes).expect(
+                    "serialization worked, so deserialization should work as well",
+                );
+                assert_eq!(decoded_expr, expr);
+            })
+            .expect("spawning thread")
+            .join()
+            .expect("joining thread");
+    }
+
+    #[test]
+    fn roundtrip_deeply_nested_binary_expr_reverse_order() {
+        // We need more stack space so this doesn't overflow in dev builds
+        std::thread::Builder::new()
+            .stack_size(10_000_000)
+            .spawn(|| {
+                let n = 100;
+
+                // a < 5
+                let expr_base = col("a").lt(lit(5i32));
+
+                // ((a < 5 AND a < 5) AND a < 5) AND ...
+                let and_chain =
+                    (0..n).fold(expr_base.clone(), |expr, _| expr.and(expr_base.clone()));
+
+                // a < 5 AND (((a < 5 AND a < 5) AND a < 5) AND ...)
+                let expr = expr_base.and(and_chain);
+
+                // Should work fine.
+                let bytes = expr.to_bytes().unwrap();
+
+                let decoded_expr = Expr::from_bytes(&bytes).expect(
+                    "serialization worked, so deserialization should work as well",
+                );
+                assert_eq!(decoded_expr, expr);
+            })
+            .expect("spawning thread")
+            .join()
+            .expect("joining thread");
+    }
+
     #[test]
     fn roundtrip_deeply_nested() {
         // we need more stack space so this doesn't overflow in dev builds
@@ -332,7 +424,8 @@ mod test {
                 println!("testing: {n}");
 
                 let expr_base = col("a").lt(lit(5i32));
-                let expr = (0..n).fold(expr_base.clone(), |expr, _| expr.and(expr_base.clone()));
+                // Alternate AND and OR so adjacent levels never share an operator (nothing to linearize).
+                let expr = (0..n).fold(expr_base.clone(), |expr, n| if n % 2 == 0 { expr.and(expr_base.clone()) } else { expr.or(expr_base.clone()) });
 
                 // Convert it to an opaque form
                 let bytes = match expr.to_bytes() {
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index 775acc3a0..8f6377739 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -690,11 +690,29 @@ pub fn parse_expr(
         .ok_or_else(|| Error::required("expr_type"))?;
 
     match expr_type {
-        ExprType::BinaryExpr(binary_expr) => Ok(Expr::BinaryExpr(BinaryExpr::new(
-            Box::new(parse_required_expr(&binary_expr.l, registry, "l")?),
-            from_proto_binary_op(&binary_expr.op)?,
-            Box::new(parse_required_expr(&binary_expr.r, registry, "r")?),
-        ))),
+        ExprType::BinaryExpr(binary_expr) => {
+            let op = from_proto_binary_op(&binary_expr.op)?;
+            let operands = binary_expr
+                .operands
+                .iter()
+                .map(|expr| parse_expr(expr, registry))
+                .collect::<Result<Vec<_>, _>>()?;
+
+            if operands.len() < 2 {
+                return Err(proto_error(
+                    "A binary expression must always have at least 2 operands",
+                ));
+            }
+
+            // Left-fold the operands (ordered left innermost -> right outermost) into one
+            // tree; the length check above guarantees `reduce` returns `Some`.
+            Ok(operands
+                .into_iter()
+                .reduce(|left, right| {
+                    Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
+                })
+                .expect("Binary expression could not be reduced to a single expression."))
+        }
         ExprType::GetIndexedField(field) => {
             let key = field
                 .key
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 2aee09ab3..590ce61a2 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -1464,21 +1464,15 @@ impl serde::Serialize for BinaryExprNode {
     {
         use serde::ser::SerializeStruct;
         let mut len = 0;
-        if self.l.is_some() {
-            len += 1;
-        }
-        if self.r.is_some() {
+        if !self.operands.is_empty() {
             len += 1;
         }
         if !self.op.is_empty() {
             len += 1;
         }
         let mut struct_ser = serializer.serialize_struct("datafusion.BinaryExprNode", len)?;
-        if let Some(v) = self.l.as_ref() {
-            struct_ser.serialize_field("l", v)?;
-        }
-        if let Some(v) = self.r.as_ref() {
-            struct_ser.serialize_field("r", v)?;
+        if !self.operands.is_empty() {
+            struct_ser.serialize_field("operands", &self.operands)?;
         }
         if !self.op.is_empty() {
             struct_ser.serialize_field("op", &self.op)?;
@@ -1493,15 +1487,13 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode {
         D: serde::Deserializer<'de>,
     {
         const FIELDS: &[&str] = &[
-            "l",
-            "r",
+            "operands",
             "op",
         ];
 
         #[allow(clippy::enum_variant_names)]
         enum GeneratedField {
-            L,
-            R,
+            Operands,
             Op,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
@@ -1524,8 +1516,7 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode {
                         E: serde::de::Error,
                     {
                         match value {
-                            "l" => Ok(GeneratedField::L),
-                            "r" => Ok(GeneratedField::R),
+                            "operands" => Ok(GeneratedField::Operands),
                             "op" => Ok(GeneratedField::Op),
                             _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                         }
@@ -1546,22 +1537,15 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode {
                 where
                     V: serde::de::MapAccess<'de>,
             {
-                let mut l__ = None;
-                let mut r__ = None;
+                let mut operands__ = None;
                 let mut op__ = None;
                 while let Some(k) = map.next_key()? {
                     match k {
-                        GeneratedField::L => {
-                            if l__.is_some() {
-                                return Err(serde::de::Error::duplicate_field("l"));
-                            }
-                            l__ = map.next_value()?;
-                        }
-                        GeneratedField::R => {
-                            if r__.is_some() {
-                                return Err(serde::de::Error::duplicate_field("r"));
+                        GeneratedField::Operands => {
+                            if operands__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("operands"));
                             }
-                            r__ = map.next_value()?;
+                            operands__ = Some(map.next_value()?);
                         }
                         GeneratedField::Op => {
                             if op__.is_some() {
@@ -1572,8 +1556,7 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode {
                     }
                 }
                 Ok(BinaryExprNode {
-                    l: l__,
-                    r: r__,
+                    operands: operands__.unwrap_or_default(),
                     op: op__.unwrap_or_default(),
                 })
             }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 4faa08fec..09177845a 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -407,7 +407,7 @@ pub mod logical_expr_node {
         Literal(super::ScalarValue),
         /// binary expressions
         #[prost(message, tag="4")]
-        BinaryExpr(::prost::alloc::boxed::Box<super::BinaryExprNode>),
+        BinaryExpr(super::BinaryExprNode),
         /// aggregate expressions
         #[prost(message, tag="5")]
         AggregateExpr(::prost::alloc::boxed::Box<super::AggregateExprNode>),
@@ -554,10 +554,11 @@ pub struct AliasNode {
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct BinaryExprNode {
-    #[prost(message, optional, boxed, tag="1")]
-    pub l: ::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
-    #[prost(message, optional, boxed, tag="2")]
-    pub r: ::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
+    /// Represents the operands from the left inner most expression
+    /// to the right outer most expression where each of them are chained
+    /// with the operator 'op'.
+    #[prost(message, repeated, tag="1")]
+    pub operands: ::prost::alloc::vec::Vec<LogicalExprNode>,
     #[prost(string, tag="3")]
     pub op: ::prost::alloc::string::String,
 }
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 96c9d983a..95224bf8e 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -455,11 +455,36 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                 }
             }
             Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-                let binary_expr = Box::new(protobuf::BinaryExprNode {
-                    l: Some(Box::new(left.as_ref().try_into()?)),
-                    r: Some(Box::new(right.as_ref().try_into()?)),
+                // Linearize the left-deep chain of binary expressions sharing `op` into a
+                // flat operand list; right operands are serialized recursively as-is.
+                let mut exprs = vec![right.as_ref()];
+                let mut current_expr = left.as_ref();
+                while let Expr::BinaryExpr(BinaryExpr {
+                    left,
+                    op: current_op,
+                    right,
+                }) = current_expr
+                {
+                    if current_op == op {
+                        exprs.push(right.as_ref());
+                        current_expr = left.as_ref();
+                    } else {
+                        break;
+                    }
+                }
+                exprs.push(current_expr);
+
+                let binary_expr = protobuf::BinaryExprNode {
+                    // We need to reverse exprs since operands are expected to be
+                    // linearized from left innermost to right outermost (but while
+                    // traversing the chain we do the exact opposite).
+                    operands: exprs
+                        .into_iter()
+                        .rev()
+                        .map(|expr| expr.try_into())
+                        .collect::<Result<Vec<_>, Error>>()?,
                     op: format!("{:?}", op),
-                });
+                };
                 Self {
                     expr_type: Some(ExprType::BinaryExpr(binary_expr)),
                 }