You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/05/01 02:17:38 UTC

[arrow-datafusion] branch master updated: Allow CTEs to be referenced from subquery expressions (#2384)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 235ef95da Allow CTEs to be referenced from subquery expressions (#2384)
235ef95da is described below

commit 235ef95da13c772ac866ff6c0a64341b88e2fdf5
Author: Andy Grove <ag...@apache.org>
AuthorDate: Sat Apr 30 20:17:33 2022 -0600

    Allow CTEs to be referenced from subquery expressions (#2384)
---
 datafusion/common/src/dfschema.rs  |   5 +-
 datafusion/core/src/sql/planner.rs | 236 +++++++++++++++++++++++--------------
 2 files changed, 153 insertions(+), 88 deletions(-)

diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index 209c5e3ed..ec186899f 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -133,11 +133,14 @@ impl DFSchema {
     /// Modify this schema by appending the fields from the supplied schema, ignoring any
     /// duplicate fields.
     pub fn merge(&mut self, other_schema: &DFSchema) {
+        if other_schema.fields.is_empty() {
+            return;
+        }
         for field in other_schema.fields() {
             // skip duplicate columns
             let duplicated_field = match field.qualifier() {
                 Some(q) => self.field_with_name(Some(q.as_str()), field.name()).is_ok(),
-                // for unqualifed columns, check as unqualified name
+                // for unqualified columns, check as unqualified name
                 None => self.field_with_unqualified_name(field.name()).is_ok(),
             };
             if !duplicated_field {
diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs
index 8325297c5..a04839f02 100644
--- a/datafusion/core/src/sql/planner.rs
+++ b/datafusion/core/src/sql/planner.rs
@@ -148,7 +148,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 analyze,
                 describe_alias: _,
             } => self.explain_statement_to_plan(verbose, analyze, *statement),
-            Statement::Query(query) => self.query_to_plan(*query),
+            Statement::Query(query) => self.query_to_plan(*query, &mut HashMap::new()),
             Statement::ShowVariable { variable } => self.show_variable_to_plan(&variable),
             Statement::CreateTable {
                 query: Some(query),
@@ -164,7 +164,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 && table_properties.is_empty()
                 && with_options.is_empty() =>
             {
-                let plan = self.query_to_plan(*query)?;
+                let plan = self.query_to_plan(*query, &mut HashMap::new())?;
 
                 Ok(LogicalPlan::CreateMemoryTable(CreateMemoryTable {
                     name: name.to_string(),
@@ -223,22 +223,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     }
 
     /// Generate a logical plan from an SQL query
-    pub fn query_to_plan(&self, query: Query) -> Result<LogicalPlan> {
-        self.query_to_plan_with_alias(query, None, &mut HashMap::new(), None)
+    pub fn query_to_plan(
+        &self,
+        query: Query,
+        ctes: &mut HashMap<String, LogicalPlan>,
+    ) -> Result<LogicalPlan> {
+        self.query_to_plan_with_alias(query, None, ctes, None)
     }
 
     /// Generate a logical plan from a SQL subquery
     pub fn subquery_to_plan(
         &self,
         query: Query,
+        ctes: &mut HashMap<String, LogicalPlan>,
         outer_query_schema: &DFSchema,
     ) -> Result<LogicalPlan> {
-        self.query_to_plan_with_alias(
-            query,
-            None,
-            &mut HashMap::new(),
-            Some(outer_query_schema),
-        )
+        self.query_to_plan_with_alias(query, None, ctes, Some(outer_query_schema))
     }
 
     /// Generate a logic plan from an SQL query with optional alias
@@ -502,16 +502,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         let right = self.create_relation(join.relation, ctes, outer_query_schema)?;
         match join.join_operator {
             JoinOperator::LeftOuter(constraint) => {
-                self.parse_join(left, right, constraint, JoinType::Left)
+                self.parse_join(left, right, constraint, JoinType::Left, ctes)
             }
             JoinOperator::RightOuter(constraint) => {
-                self.parse_join(left, right, constraint, JoinType::Right)
+                self.parse_join(left, right, constraint, JoinType::Right, ctes)
             }
             JoinOperator::Inner(constraint) => {
-                self.parse_join(left, right, constraint, JoinType::Inner)
+                self.parse_join(left, right, constraint, JoinType::Inner, ctes)
             }
             JoinOperator::FullOuter(constraint) => {
-                self.parse_join(left, right, constraint, JoinType::Full)
+                self.parse_join(left, right, constraint, JoinType::Full, ctes)
             }
             JoinOperator::CrossJoin => self.parse_cross_join(left, &right),
             other => Err(DataFusionError::NotImplemented(format!(
@@ -535,6 +535,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         right: LogicalPlan,
         constraint: JoinConstraint,
         join_type: JoinType,
+        ctes: &mut HashMap<String, LogicalPlan>,
     ) -> Result<LogicalPlan> {
         match constraint {
             JoinConstraint::On(sql_expr) => {
@@ -542,7 +543,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 let join_schema = left.schema().join(right.schema())?;
 
                 // parse ON expression
-                let expr = self.sql_to_rex(sql_expr, &join_schema)?;
+                let expr = self.sql_to_rex(sql_expr, &join_schema, ctes)?;
 
                 // expression that didn't match equi-join pattern
                 let mut filter = vec![];
@@ -782,6 +783,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         selection: Option<SQLExpr>,
         plans: Vec<LogicalPlan>,
         outer_query_schema: Option<&DFSchema>,
+        ctes: &mut HashMap<String, LogicalPlan>,
     ) -> Result<LogicalPlan> {
         match selection {
             Some(predicate_expr) => {
@@ -797,7 +799,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     join_schema.merge(outer);
                 }
 
-                let filter_expr = self.sql_to_rex(predicate_expr, &join_schema)?;
+                let filter_expr = self.sql_to_rex(predicate_expr, &join_schema, ctes)?;
 
                 // look for expressions of the form `<column> = <column>`
                 let mut possible_join_keys = vec![];
@@ -920,7 +922,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         let empty_from = matches!(plans.first(), Some(LogicalPlan::EmptyRelation(_)));
 
         // process `where` clause
-        let plan = self.plan_selection(select.selection, plans, outer_query_schema)?;
+        let plan =
+            self.plan_selection(select.selection, plans, outer_query_schema, ctes)?;
 
         // process the SELECT expressions, with wildcards expanded.
         let select_exprs = self.prepare_select_exprs(
@@ -928,6 +931,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             select.projection,
             empty_from,
             outer_query_schema,
+            ctes,
         )?;
 
         // having and group by clause may reference aliases defined in select projection
@@ -943,7 +947,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             .having
             .map::<Result<Expr>, _>(|having_expr| {
                 let having_expr =
-                    self.sql_expr_to_logical_expr(having_expr, &combined_schema)?;
+                    self.sql_expr_to_logical_expr(having_expr, &combined_schema, ctes)?;
                 // This step "dereferences" any aliases in the HAVING clause.
                 //
                 // This is how we support queries with HAVING expressions that
@@ -980,7 +984,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             .group_by
             .into_iter()
             .map(|e| {
-                let group_by_expr = self.sql_expr_to_logical_expr(e, &combined_schema)?;
+                let group_by_expr =
+                    self.sql_expr_to_logical_expr(e, &combined_schema, ctes)?;
                 let group_by_expr = resolve_aliases_to_exprs(&group_by_expr, &alias_map)?;
                 let group_by_expr =
                     resolve_positions_to_exprs(&group_by_expr, &select_exprs)
@@ -1066,11 +1071,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         projection: Vec<SelectItem>,
         empty_from: bool,
         outer_query_schema: Option<&DFSchema>,
+        ctes: &mut HashMap<String, LogicalPlan>,
     ) -> Result<Vec<Expr>> {
         projection
             .into_iter()
             .map(|expr| {
-                self.sql_select_to_rex(expr, plan, empty_from, outer_query_schema)
+                self.sql_select_to_rex(expr, plan, empty_from, outer_query_schema, ctes)
             })
             .flat_map(|result| match result {
                 Ok(vec) => vec.into_iter().map(Ok).collect(),
@@ -1151,7 +1157,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     fn limit(&self, input: LogicalPlan, limit: Option<SQLExpr>) -> Result<LogicalPlan> {
         match limit {
             Some(limit_expr) => {
-                let n = match self.sql_to_rex(limit_expr, input.schema())? {
+                let n = match self.sql_to_rex(
+                    limit_expr,
+                    input.schema(),
+                    &mut HashMap::new(),
+                )? {
                     Expr::Literal(ScalarValue::Int64(Some(n))) => Ok(n as usize),
                     _ => Err(DataFusionError::Plan(
                         "Unexpected expression for LIMIT clause".to_string(),
@@ -1211,7 +1221,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 let field = schema.field(field_index - 1);
                 Expr::Column(field.qualified_column())
             }
-            e => self.sql_expr_to_logical_expr(e, schema)?,
+            e => self.sql_expr_to_logical_expr(e, schema, &mut HashMap::new())?,
         };
         Ok({
             let asc = asc.unwrap_or(true);
@@ -1267,6 +1277,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         plan: &LogicalPlan,
         empty_from: bool,
         outer_query_schema: Option<&DFSchema>,
+        ctes: &mut HashMap<String, LogicalPlan>,
     ) -> Result<Vec<Expr>> {
         let input_schema = match outer_query_schema {
             Some(x) => {
@@ -1279,12 +1290,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
 
         match sql {
             SelectItem::UnnamedExpr(expr) => {
-                let expr = self.sql_to_rex(expr, &input_schema)?;
+                let expr = self.sql_to_rex(expr, &input_schema, ctes)?;
                 Ok(vec![normalize_col(expr, plan)?])
             }
             SelectItem::ExprWithAlias { expr, alias } => {
                 let expr = Alias(
-                    Box::new(self.sql_to_rex(expr, &input_schema)?),
+                    Box::new(self.sql_to_rex(expr, &input_schema, ctes)?),
                     normalize_ident(&alias),
                 );
                 Ok(vec![normalize_col(expr, plan)?])
@@ -1307,8 +1318,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
     }
 
     /// Generate a relational expression from a SQL expression
-    pub fn sql_to_rex(&self, sql: SQLExpr, schema: &DFSchema) -> Result<Expr> {
-        let mut expr = self.sql_expr_to_logical_expr(sql, schema)?;
+    pub fn sql_to_rex(
+        &self,
+        sql: SQLExpr,
+        schema: &DFSchema,
+        ctes: &mut HashMap<String, LogicalPlan>,
+    ) -> Result<Expr> {
+        let mut expr = self.sql_expr_to_logical_expr(sql, schema, ctes)?;
         expr = self.rewrite_partial_qualifier(expr, schema);
         self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?;
         Ok(expr)
@@ -1346,18 +1362,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         &self,
         sql: FunctionArg,
         schema: &DFSchema,
+        ctes: &mut HashMap<String, LogicalPlan>,
     ) -> Result<Expr> {
         match sql {
             FunctionArg::Named {
                 name: _,
                 arg: FunctionArgExpr::Expr(arg),
-            } => self.sql_expr_to_logical_expr(arg, schema),
+            } => self.sql_expr_to_logical_expr(arg, schema, ctes),
             FunctionArg::Named {
                 name: _,
                 arg: FunctionArgExpr::Wildcard,
             } => Ok(Expr::Wildcard),
             FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => {
-                self.sql_expr_to_logical_expr(arg, schema)
+                self.sql_expr_to_logical_expr(arg, schema, ctes)
             }
             FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(Expr::Wildcard),
             _ => Err(DataFusionError::NotImplemented(format!(
@@ -1373,6 +1390,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         op: BinaryOperator,
         right: SQLExpr,
         schema: &DFSchema,
+        ctes: &mut HashMap<String, LogicalPlan>,
     ) -> Result<Expr> {
         let operator = match op {
             BinaryOperator::Gt => Ok(Operator::Gt),
@@ -1404,9 +1422,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         }?;
 
         Ok(Expr::BinaryExpr {
-            left: Box::new(self.sql_expr_to_logical_expr(left, schema)?),
+            left: Box::new(self.sql_expr_to_logical_expr(left, schema, ctes)?),
             op: operator,
-            right: Box::new(self.sql_expr_to_logical_expr(right, schema)?),
+            right: Box::new(self.sql_expr_to_logical_expr(right, schema, ctes)?),
         })
     }
 
@@ -1415,12 +1433,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         op: UnaryOperator,
         expr: SQLExpr,
         schema: &DFSchema,
+        ctes: &mut HashMap<String, LogicalPlan>,
     ) -> Result<Expr> {
         match op {
             UnaryOperator::Not => Ok(Expr::Not(Box::new(
-                self.sql_expr_to_logical_expr(expr, schema)?,
+                self.sql_expr_to_logical_expr(expr, schema, ctes)?,
             ))),
-            UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr, schema)?),
+            UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr, schema, ctes)?),
             UnaryOperator::Minus => {
                 match expr {
                     // optimization: if it's a number literal, we apply the negative operator
@@ -1436,7 +1455,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                             })?)),
                     },
                     // not a literal, apply negative operator on expression
-                    _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr, schema)?))),
+                    _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr, schema, ctes)?))),
                 }
             }
             _ => Err(DataFusionError::NotImplemented(format!(
@@ -1461,12 +1480,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                             Ok(Expr::Literal(ScalarValue::Utf8(None)))
                         }
                         SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
-                        SQLExpr::UnaryOp { op, expr } => {
-                            self.parse_sql_unary_op(op, *expr, &schema)
-                        }
-                        SQLExpr::BinaryOp { left, op, right } => {
-                            self.parse_sql_binary_op(*left, op, *right, &schema)
-                        }
+                        SQLExpr::UnaryOp { op, expr } => self.parse_sql_unary_op(
+                            op,
+                            *expr,
+                            &schema,
+                            &mut HashMap::new(),
+                        ),
+                        SQLExpr::BinaryOp { left, op, right } => self
+                            .parse_sql_binary_op(
+                                *left,
+                                op,
+                                *right,
+                                &schema,
+                                &mut HashMap::new(),
+                            ),
                         other => Err(DataFusionError::NotImplemented(format!(
                             "Unsupported value {:?} in a values list expression",
                             other
@@ -1478,7 +1505,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         LogicalPlanBuilder::values(values)?.build()
     }
 
-    fn sql_expr_to_logical_expr(&self, sql: SQLExpr, schema: &DFSchema) -> Result<Expr> {
+    fn sql_expr_to_logical_expr(
+        &self,
+        sql: SQLExpr,
+        schema: &DFSchema,
+        ctes: &mut HashMap<String, LogicalPlan>,
+    ) -> Result<Expr> {
         match sql {
             SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
             SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())),
@@ -1488,7 +1520,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 fun: BuiltinScalarFunction::DatePart,
                 args: vec![
                     Expr::Literal(ScalarValue::Utf8(Some(format!("{}", field)))),
-                    self.sql_expr_to_logical_expr(*expr, schema)?,
+                    self.sql_expr_to_logical_expr(*expr, schema, ctes)?,
                 ],
             }),
 
@@ -1583,20 +1615,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 else_result,
             } => {
                 let expr = if let Some(e) = operand {
-                    Some(Box::new(self.sql_expr_to_logical_expr(*e, schema)?))
+                    Some(Box::new(self.sql_expr_to_logical_expr(*e, schema, ctes)?))
                 } else {
                     None
                 };
                 let when_expr = conditions
                     .into_iter()
-                    .map(|e| self.sql_expr_to_logical_expr(e, schema))
+                    .map(|e| self.sql_expr_to_logical_expr(e, schema, ctes))
                     .collect::<Result<Vec<_>>>()?;
                 let then_expr = results
                     .into_iter()
-                    .map(|e| self.sql_expr_to_logical_expr(e, schema))
+                    .map(|e| self.sql_expr_to_logical_expr(e, schema, ctes))
                     .collect::<Result<Vec<_>>>()?;
                 let else_expr = if let Some(e) = else_result {
-                    Some(Box::new(self.sql_expr_to_logical_expr(*e, schema)?))
+                    Some(Box::new(self.sql_expr_to_logical_expr(*e, schema, ctes)?))
                 } else {
                     None
                 };
@@ -1616,7 +1648,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 expr,
                 data_type,
             } => Ok(Expr::Cast {
-                expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema)?),
+                expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?),
                 data_type: convert_data_type(&data_type)?,
             }),
 
@@ -1624,7 +1656,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 expr,
                 data_type,
             } => Ok(Expr::TryCast {
-                expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema)?),
+                expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?),
                 data_type: convert_data_type(&data_type)?,
             }),
 
@@ -1637,23 +1669,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             }),
 
             SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new(
-                self.sql_expr_to_logical_expr(*expr, schema)?,
+                self.sql_expr_to_logical_expr(*expr, schema, ctes)?,
             ))),
 
             SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new(
-                self.sql_expr_to_logical_expr(*expr, schema)?,
+                self.sql_expr_to_logical_expr(*expr, schema, ctes)?,
             ))),
 
             SQLExpr::IsDistinctFrom(left, right) => Ok(Expr::BinaryExpr {
-                left: Box::new(self.sql_expr_to_logical_expr(*left, schema)?),
+                left: Box::new(self.sql_expr_to_logical_expr(*left, schema, ctes)?),
                 op: Operator::IsDistinctFrom,
-                right: Box::new(self.sql_expr_to_logical_expr(*right, schema)?),
+                right: Box::new(self.sql_expr_to_logical_expr(*right, schema, ctes)?),
             }),
 
             SQLExpr::IsNotDistinctFrom(left, right) => Ok(Expr::BinaryExpr {
-                left: Box::new(self.sql_expr_to_logical_expr(*left, schema)?),
+                left: Box::new(self.sql_expr_to_logical_expr(*left, schema, ctes)?),
                 op: Operator::IsNotDistinctFrom,
-                right: Box::new(self.sql_expr_to_logical_expr(*right, schema)?),
+                right: Box::new(self.sql_expr_to_logical_expr(*right, schema, ctes)?),
             }),
 
 
@@ -1661,8 +1693,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 // The AST for Exists does not support the NOT EXISTS case so it gets
                 // wrapped in a unary NOT
                 // https://github.com/sqlparser-rs/sqlparser-rs/issues/472
-                (&UnaryOperator::Not, &SQLExpr::Exists(ref subquery)) => self.parse_exists_subquery(subquery, true, schema),
-                _ => self.parse_sql_unary_op(op, *expr, schema)
+                (&UnaryOperator::Not, &SQLExpr::Exists(ref subquery)) => self.parse_exists_subquery(subquery, true, schema, ctes),
+                _ => self.parse_sql_unary_op(op, *expr, schema, ctes)
             }
 
             SQLExpr::Between {
@@ -1671,10 +1703,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 low,
                 high,
             } => Ok(Expr::Between {
-                expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema)?),
+                expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?),
                 negated,
-                low: Box::new(self.sql_expr_to_logical_expr(*low, schema)?),
-                high: Box::new(self.sql_expr_to_logical_expr(*high, schema)?),
+                low: Box::new(self.sql_expr_to_logical_expr(*low, schema, ctes)?),
+                high: Box::new(self.sql_expr_to_logical_expr(*high, schema, ctes)?),
             }),
 
             SQLExpr::InList {
@@ -1684,11 +1716,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             } => {
                 let list_expr = list
                     .into_iter()
-                    .map(|e| self.sql_expr_to_logical_expr(e, schema))
+                    .map(|e| self.sql_expr_to_logical_expr(e, schema, ctes))
                     .collect::<Result<Vec<_>>>()?;
 
                 Ok(Expr::InList {
-                    expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema)?),
+                    expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?),
                     list: list_expr,
                     negated,
                 })
@@ -1698,7 +1730,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 left,
                 op,
                 right,
-            } => self.parse_sql_binary_op(*left, op, *right, schema),
+            } => self.parse_sql_binary_op(*left, op, *right, schema, ctes),
 
             #[cfg(feature = "unicode_expressions")]
             SQLExpr::Substring {
@@ -1708,24 +1740,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             } => {
                 let args = match (substring_from, substring_for) {
                     (Some(from_expr), Some(for_expr)) => {
-                        let arg = self.sql_expr_to_logical_expr(*expr, schema)?;
+                        let arg = self.sql_expr_to_logical_expr(*expr, schema, ctes)?;
                         let from_logic =
-                            self.sql_expr_to_logical_expr(*from_expr, schema)?;
+                            self.sql_expr_to_logical_expr(*from_expr, schema, ctes)?;
                         let for_logic =
-                            self.sql_expr_to_logical_expr(*for_expr, schema)?;
+                            self.sql_expr_to_logical_expr(*for_expr, schema, ctes)?;
                         vec![arg, from_logic, for_logic]
                     }
                     (Some(from_expr), None) => {
-                        let arg = self.sql_expr_to_logical_expr(*expr, schema)?;
+                        let arg = self.sql_expr_to_logical_expr(*expr, schema, ctes)?;
                         let from_logic =
-                            self.sql_expr_to_logical_expr(*from_expr, schema)?;
+                            self.sql_expr_to_logical_expr(*from_expr, schema, ctes)?;
                         vec![arg, from_logic]
                     }
                     (None, Some(for_expr)) => {
-                        let arg = self.sql_expr_to_logical_expr(*expr, schema)?;
+                        let arg = self.sql_expr_to_logical_expr(*expr, schema, ctes)?;
                         let from_logic = Expr::Literal(ScalarValue::Int64(Some(1)));
                         let for_logic =
-                            self.sql_expr_to_logical_expr(*for_expr, schema)?;
+                            self.sql_expr_to_logical_expr(*for_expr, schema, ctes)?;
                         vec![arg, from_logic, for_logic]
                     }
                     (None, None) => {
@@ -1771,10 +1803,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     }
                     None => (BuiltinScalarFunction::Trim, None),
                 };
-                let arg = self.sql_expr_to_logical_expr(*expr, schema)?;
+                let arg = self.sql_expr_to_logical_expr(*expr, schema, ctes)?;
                 let args = match where_expr {
                     Some(to_trim) => {
-                        let to_trim = self.sql_expr_to_logical_expr(*to_trim, schema)?;
+                        let to_trim = self.sql_expr_to_logical_expr(*to_trim, schema, ctes)?;
                         vec![arg, to_trim]
                     }
                     None => vec![arg],
@@ -1803,7 +1835,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                     let partition_by = window
                         .partition_by
                         .into_iter()
-                        .map(|e| self.sql_expr_to_logical_expr(e, schema))
+                        .map(|e| self.sql_expr_to_logical_expr(e, schema, ctes))
                         .collect::<Result<Vec<_>>>()?;
                     let order_by = window
                         .order_by
@@ -1893,13 +1925,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 }
             }
 
-            SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(*e, schema),
+            SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(*e, schema, ctes),
 
-            SQLExpr::Exists(subquery) => self.parse_exists_subquery(&subquery, false, schema),
+            SQLExpr::Exists(subquery) => self.parse_exists_subquery(&subquery, false, schema, ctes),
 
-            SQLExpr::InSubquery {  expr, subquery, negated } => self.parse_in_subquery(&expr, &subquery, negated, schema),
+            SQLExpr::InSubquery {  expr, subquery, negated } => self.parse_in_subquery(&expr, &subquery, negated, schema, ctes),
 
-            SQLExpr::Subquery(subquery) => self.parse_scalar_subquery(&subquery, schema),
+            SQLExpr::Subquery(subquery) => self.parse_scalar_subquery(&subquery, schema, ctes),
 
             _ => Err(DataFusionError::NotImplemented(format!(
                 "Unsupported ast node {:?} in sqltorel",
@@ -1913,12 +1945,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         subquery: &Query,
         negated: bool,
         input_schema: &DFSchema,
+        ctes: &mut HashMap<String, LogicalPlan>,
     ) -> Result<Expr> {
         Ok(Expr::Exists {
             subquery: Subquery {
-                subquery: Arc::new(
-                    self.subquery_to_plan(subquery.clone(), input_schema)?,
-                ),
+                subquery: Arc::new(self.subquery_to_plan(
+                    subquery.clone(),
+                    ctes,
+                    input_schema,
+                )?),
             },
             negated,
         })
@@ -1930,13 +1965,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         subquery: &Query,
         negated: bool,
         input_schema: &DFSchema,
+        ctes: &mut HashMap<String, LogicalPlan>,
     ) -> Result<Expr> {
         Ok(Expr::InSubquery {
-            expr: Box::new(self.sql_to_rex(expr.clone(), input_schema)?),
+            expr: Box::new(self.sql_to_rex(expr.clone(), input_schema, ctes)?),
             subquery: Subquery {
-                subquery: Arc::new(
-                    self.subquery_to_plan(subquery.clone(), input_schema)?,
-                ),
+                subquery: Arc::new(self.subquery_to_plan(
+                    subquery.clone(),
+                    ctes,
+                    input_schema,
+                )?),
             },
             negated,
         })
@@ -1946,9 +1984,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         &self,
         subquery: &Query,
         input_schema: &DFSchema,
+        ctes: &mut HashMap<String, LogicalPlan>,
     ) -> Result<Expr> {
         Ok(Expr::ScalarSubquery(Subquery {
-            subquery: Arc::new(self.subquery_to_plan(subquery.clone(), input_schema)?),
+            subquery: Arc::new(self.subquery_to_plan(
+                subquery.clone(),
+                ctes,
+                input_schema,
+            )?),
         }))
     }
 
@@ -1958,7 +2001,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         schema: &DFSchema,
     ) -> Result<Vec<Expr>> {
         args.into_iter()
-            .map(|a| self.sql_fn_arg_to_logical_expr(a, schema))
+            .map(|a| self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new()))
             .collect::<Result<Vec<Expr>>>()
     }
 
@@ -1977,13 +2020,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                         Value::Number(_, _),
                     ))) => Ok(lit(1_u8)),
                     FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(lit(1_u8)),
-                    _ => self.sql_fn_arg_to_logical_expr(a, schema),
+                    _ => self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new()),
                 })
                 .collect::<Result<Vec<Expr>>>()?,
             aggregates::AggregateFunction::ApproxMedian => function
                 .args
                 .into_iter()
-                .map(|a| self.sql_fn_arg_to_logical_expr(a, schema))
+                .map(|a| self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new()))
                 .chain(iter::once(Ok(lit(0.5_f64))))
                 .collect::<Result<Vec<Expr>>>()?,
             _ => self.function_args_to_expr(function.args, schema)?,
@@ -2268,7 +2311,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         let mut values = Vec::with_capacity(elements.len());
 
         for element in elements {
-            let value = self.sql_expr_to_logical_expr(element, schema)?;
+            let value =
+                self.sql_expr_to_logical_expr(element, schema, &mut HashMap::new())?;
             match value {
                 Expr::Literal(scalar) => {
                     values.push(scalar);
@@ -4495,4 +4539,22 @@ mod tests {
         );
         quick_test(sql, &expected);
     }
+
+    #[tokio::test]
+    async fn subquery_references_cte() {
+        let sql = "WITH \
+        cte AS (SELECT * FROM person) \
+        SELECT * FROM person WHERE EXISTS (SELECT * FROM cte WHERE id = person.id)";
+
+        let subquery = "Subquery: Projection: #cte.id, #cte.first_name, #cte.last_name, #cte.age, #cte.state, #cte.salary, #cte.birth_date, #cte.😀\
+        \n  Filter: #cte.id = #person.id\
+        \n    Projection: #person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date, #person.😀, alias=cte\
+        \n      TableScan: person projection=None";
+
+        let expected = format!("Projection: #person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date, #person.😀\
+        \n  Filter: EXISTS ({})\
+        \n    TableScan: person projection=None", subquery);
+
+        quick_test(sql, &expected)
+    }
 }