You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/07 19:48:40 UTC
[arrow-datafusion] branch master updated: Linearize binary expressions to reduce proto tree complexity (#4115)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 6b7129483 Linearize binary expressions to reduce proto tree complexity (#4115)
6b7129483 is described below
commit 6b712948362b078f1226623ca46ab96d8bd2b768
Author: Batuhan Taskaya <is...@gmail.com>
AuthorDate: Mon Nov 7 22:48:22 2022 +0300
Linearize binary expressions to reduce proto tree complexity (#4115)
---
datafusion/proto/proto/datafusion.proto | 6 +-
datafusion/proto/src/bytes/mod.rs | 95 +++++++++++++++++++++++++++++++-
datafusion/proto/src/from_proto.rs | 28 ++++++++--
datafusion/proto/src/generated/pbjson.rs | 41 ++++----------
datafusion/proto/src/generated/prost.rs | 11 ++--
datafusion/proto/src/to_proto.rs | 33 +++++++++--
6 files changed, 168 insertions(+), 46 deletions(-)
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 3cb9763d3..e40734538 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -409,8 +409,10 @@ message AliasNode {
}
message BinaryExprNode {
- LogicalExprNode l = 1;
- LogicalExprNode r = 2;
+ // Represents the operands, ordered from the left innermost expression
+ // to the right outermost expression, where each of them is chained
+ // with the operator 'op'.
+ repeated LogicalExprNode operands = 1;
string op = 3;
}
diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs
index 3677ea8af..1ffe8ab7a 100644
--- a/datafusion/proto/src/bytes/mod.rs
+++ b/datafusion/proto/src/bytes/mod.rs
@@ -321,6 +321,98 @@ mod test {
Expr::from_bytes(&bytes).unwrap();
}
+ fn roundtrip_expr(expr: &Expr) -> Expr {
+ let bytes = expr.to_bytes().unwrap();
+ Expr::from_bytes(&bytes).unwrap()
+ }
+
+ #[test]
+ fn exact_roundtrip_linearized_binary_expr() {
+ // (((A AND B) AND C) AND D)
+ let expr_ordered = col("A").and(col("B")).and(col("C")).and(col("D"));
+ assert_eq!(expr_ordered, roundtrip_expr(&expr_ordered));
+
+ // Ensure that no other variation becomes equal
+ let other_variants = vec![
+ // (((B AND A) AND C) AND D)
+ col("B").and(col("A")).and(col("C")).and(col("D")),
+ // (((A AND C) AND B) AND D)
+ col("A").and(col("C")).and(col("B")).and(col("D")),
+ // (((A AND B) AND D) AND C)
+ col("A").and(col("B")).and(col("D")).and(col("C")),
+ // (A AND (B AND (C AND D)))
+ col("A").and(col("B").and(col("C").and(col("D")))),
+ ];
+ for case in other_variants {
+ // Each variant is still equal to itself
+ assert_eq!(case, roundtrip_expr(&case));
+
+ // But none of them is equal to the original
+ assert_ne!(expr_ordered, roundtrip_expr(&case));
+ assert_ne!(roundtrip_expr(&expr_ordered), roundtrip_expr(&case));
+ }
+ }
+
+ #[test]
+ fn roundtrip_deeply_nested_binary_expr() {
+ // We need more stack space so this doesn't overflow in dev builds
+ std::thread::Builder::new()
+ .stack_size(10_000_000)
+ .spawn(|| {
+ let n = 100;
+ // a < 5
+ let basic_expr = col("a").lt(lit(5i32));
+ // (a < 5) OR (a < 5) OR (a < 5) OR ...
+ let or_chain = (0..n)
+ .fold(basic_expr.clone(), |expr, _| expr.or(basic_expr.clone()));
+ // ((or_chain AND or_chain) AND or_chain) AND ... (a left-deep AND chain whose operands are OR chains)
+ let expr =
+ (0..n).fold(or_chain.clone(), |expr, _| expr.and(or_chain.clone()));
+
+ // Should work fine.
+ let bytes = expr.to_bytes().unwrap();
+
+ let decoded_expr = Expr::from_bytes(&bytes).expect(
+ "serialization worked, so deserialization should work as well",
+ );
+ assert_eq!(decoded_expr, expr);
+ })
+ .expect("spawning thread")
+ .join()
+ .expect("joining thread");
+ }
+
+ #[test]
+ fn roundtrip_deeply_nested_binary_expr_reverse_order() {
+ // We need more stack space so this doesn't overflow in dev builds
+ std::thread::Builder::new()
+ .stack_size(10_000_000)
+ .spawn(|| {
+ let n = 100;
+
+ // a < 5
+ let expr_base = col("a").lt(lit(5i32));
+
+ // ((a < 5 AND a < 5) AND a < 5) AND ...
+ let and_chain =
+ (0..n).fold(expr_base.clone(), |expr, _| expr.and(expr_base.clone()));
+
+ // a < 5 AND (((a < 5 AND a < 5) AND a < 5) AND ...), i.e. the nested chain is the right operand
+ let expr = expr_base.and(and_chain);
+
+ // Should work fine.
+ let bytes = expr.to_bytes().unwrap();
+
+ let decoded_expr = Expr::from_bytes(&bytes).expect(
+ "serialization worked, so deserialization should work as well",
+ );
+ assert_eq!(decoded_expr, expr);
+ })
+ .expect("spawning thread")
+ .join()
+ .expect("joining thread");
+ }
+
#[test]
fn roundtrip_deeply_nested() {
// we need more stack space so this doesn't overflow in dev builds
@@ -332,7 +424,8 @@ mod test {
println!("testing: {n}");
let expr_base = col("a").lt(lit(5i32));
- let expr = (0..n).fold(expr_base.clone(), |expr, _| expr.and(expr_base.clone()));
+ // Generate a tree that alternates AND and OR (no two consecutive ANDs or ORs).
+ let expr = (0..n).fold(expr_base.clone(), |expr, n| if n % 2 == 0 { expr.and(expr_base.clone()) } else { expr.or(expr_base.clone()) });
// Convert it to an opaque form
let bytes = match expr.to_bytes() {
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index 775acc3a0..8f6377739 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -690,11 +690,29 @@ pub fn parse_expr(
.ok_or_else(|| Error::required("expr_type"))?;
match expr_type {
- ExprType::BinaryExpr(binary_expr) => Ok(Expr::BinaryExpr(BinaryExpr::new(
- Box::new(parse_required_expr(&binary_expr.l, registry, "l")?),
- from_proto_binary_op(&binary_expr.op)?,
- Box::new(parse_required_expr(&binary_expr.r, registry, "r")?),
- ))),
+ ExprType::BinaryExpr(binary_expr) => {
+ let op = from_proto_binary_op(&binary_expr.op)?;
+ let operands = binary_expr
+ .operands
+ .iter()
+ .map(|expr| parse_expr(expr, registry))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ if operands.len() < 2 {
+ return Err(proto_error(
+ "A binary expression must always have at least 2 operands",
+ ));
+ }
+
+ // Reduce the linearized operands (ordered by left innermost to right
+ // outermost) into a single expression tree.
+ Ok(operands
+ .into_iter()
+ .reduce(|left, right| {
+ Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right)))
+ })
+ .expect("Binary expression could not be reduced to a single expression."))
+ }
ExprType::GetIndexedField(field) => {
let key = field
.key
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 2aee09ab3..590ce61a2 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -1464,21 +1464,15 @@ impl serde::Serialize for BinaryExprNode {
{
use serde::ser::SerializeStruct;
let mut len = 0;
- if self.l.is_some() {
- len += 1;
- }
- if self.r.is_some() {
+ if !self.operands.is_empty() {
len += 1;
}
if !self.op.is_empty() {
len += 1;
}
let mut struct_ser = serializer.serialize_struct("datafusion.BinaryExprNode", len)?;
- if let Some(v) = self.l.as_ref() {
- struct_ser.serialize_field("l", v)?;
- }
- if let Some(v) = self.r.as_ref() {
- struct_ser.serialize_field("r", v)?;
+ if !self.operands.is_empty() {
+ struct_ser.serialize_field("operands", &self.operands)?;
}
if !self.op.is_empty() {
struct_ser.serialize_field("op", &self.op)?;
@@ -1493,15 +1487,13 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode {
D: serde::Deserializer<'de>,
{
const FIELDS: &[&str] = &[
- "l",
- "r",
+ "operands",
"op",
];
#[allow(clippy::enum_variant_names)]
enum GeneratedField {
- L,
- R,
+ Operands,
Op,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
@@ -1524,8 +1516,7 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode {
E: serde::de::Error,
{
match value {
- "l" => Ok(GeneratedField::L),
- "r" => Ok(GeneratedField::R),
+ "operands" => Ok(GeneratedField::Operands),
"op" => Ok(GeneratedField::Op),
_ => Err(serde::de::Error::unknown_field(value, FIELDS)),
}
@@ -1546,22 +1537,15 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode {
where
V: serde::de::MapAccess<'de>,
{
- let mut l__ = None;
- let mut r__ = None;
+ let mut operands__ = None;
let mut op__ = None;
while let Some(k) = map.next_key()? {
match k {
- GeneratedField::L => {
- if l__.is_some() {
- return Err(serde::de::Error::duplicate_field("l"));
- }
- l__ = map.next_value()?;
- }
- GeneratedField::R => {
- if r__.is_some() {
- return Err(serde::de::Error::duplicate_field("r"));
+ GeneratedField::Operands => {
+ if operands__.is_some() {
+ return Err(serde::de::Error::duplicate_field("operands"));
}
- r__ = map.next_value()?;
+ operands__ = Some(map.next_value()?);
}
GeneratedField::Op => {
if op__.is_some() {
@@ -1572,8 +1556,7 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode {
}
}
Ok(BinaryExprNode {
- l: l__,
- r: r__,
+ operands: operands__.unwrap_or_default(),
op: op__.unwrap_or_default(),
})
}
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 4faa08fec..09177845a 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -407,7 +407,7 @@ pub mod logical_expr_node {
Literal(super::ScalarValue),
/// binary expressions
#[prost(message, tag="4")]
- BinaryExpr(::prost::alloc::boxed::Box<super::BinaryExprNode>),
+ BinaryExpr(super::BinaryExprNode),
/// aggregate expressions
#[prost(message, tag="5")]
AggregateExpr(::prost::alloc::boxed::Box<super::AggregateExprNode>),
@@ -554,10 +554,11 @@ pub struct AliasNode {
}
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BinaryExprNode {
- #[prost(message, optional, boxed, tag="1")]
- pub l: ::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
- #[prost(message, optional, boxed, tag="2")]
- pub r: ::core::option::Option<::prost::alloc::boxed::Box<LogicalExprNode>>,
+ /// Represents the operands, ordered from the left innermost expression
+ /// to the right outermost expression, where each of them is chained
+ /// with the operator 'op'.
+ #[prost(message, repeated, tag="1")]
+ pub operands: ::prost::alloc::vec::Vec<LogicalExprNode>,
#[prost(string, tag="3")]
pub op: ::prost::alloc::string::String,
}
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 96c9d983a..95224bf8e 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -455,11 +455,36 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
}
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
- let binary_expr = Box::new(protobuf::BinaryExprNode {
- l: Some(Box::new(left.as_ref().try_into()?)),
- r: Some(Box::new(right.as_ref().try_into()?)),
+ // Try to linearize a nested binary expression tree of the same operator
+ // into a flat vector of expressions.
+ let mut exprs = vec![right.as_ref()];
+ let mut current_expr = left.as_ref();
+ while let Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: current_op,
+ right,
+ }) = current_expr
+ {
+ if current_op == op {
+ exprs.push(right.as_ref());
+ current_expr = left.as_ref();
+ } else {
+ break;
+ }
+ }
+ exprs.push(current_expr);
+
+ let binary_expr = protobuf::BinaryExprNode {
+ // We need to reverse exprs since operands are expected to be
+ // linearized from left innermost to right outermost (but while
+ // traversing the chain we do the exact opposite).
+ operands: exprs
+ .into_iter()
+ .rev()
+ .map(|expr| expr.try_into())
+ .collect::<Result<Vec<_>, Error>>()?,
op: format!("{:?}", op),
- });
+ };
Self {
expr_type: Some(ExprType::BinaryExpr(binary_expr)),
}