alamb commented on code in PR #2885:
URL: https://github.com/apache/arrow-datafusion/pull/2885#discussion_r925732051


##########
datafusion/common/src/error.rs:
##########
@@ -83,6 +83,30 @@ pub enum DataFusionError {
     #[cfg(feature = "jit")]
     /// Error occurs during code generation
     JITError(ModuleError),
+    /// Error with context
+    Context(String, Box<DataFusionError>),
+}
+
+#[macro_export]
+macro_rules! context {
+    ($desc:expr, $err:expr) => {
+        datafusion_common::DataFusionError::Context(
+            format!("{} at {}:{}", $desc, file!(), line!()),
+            Box::new($err),
+        )
+    };
+}
+
+#[macro_export]
+macro_rules! plan_err {

Review Comment:
   This is a nice improvement



##########
datafusion/core/tests/sql/subqueries.rs:
##########
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use crate::sql::execute_to_batches;
+use datafusion::assert_batches_eq;
+use datafusion::prelude::SessionContext;
+use log::debug;
+
+#[cfg(test)]
+#[ctor::ctor]
+fn init() {

Review Comment:
   TIL `ctor`. That is quite cool 👍 



##########
datafusion/core/tests/sql/mod.rs:
##########
@@ -499,6 +537,77 @@ async fn register_tpch_csv(ctx: &SessionContext, table: 
&str) -> Result<()> {
     Ok(())
 }
 
+async fn register_tpch_csv_data(
+    ctx: &SessionContext,
+    table_name: &str,
+    data: &str,
+) -> Result<()> {
+    let schema = Arc::new(get_tpch_table_schema(table_name));

Review Comment:
   What do you think about using `SessionContext::register_csv` here instead? 
Is there some reason we need to explicitly parse the CSV file and build 
in-memory tables?
   
   
   
https://docs.rs/datafusion/10.0.0/datafusion/execution/context/struct.SessionContext.html#method.register_csv
   
   



##########
datafusion/common/src/error.rs:
##########
@@ -83,6 +83,30 @@ pub enum DataFusionError {
     #[cfg(feature = "jit")]
     /// Error occurs during code generation
     JITError(ModuleError),
+    /// Error with context

Review Comment:
   ```suggestion
       /// Error with additional context
   ```
   
   I am +0 on this change (i.e. I don't oppose it, but I would also be fine 
without it).
   
   I can see the use case (preserving the structure of nested errors). There 
are other, more structured approaches to error handling (e.g. 
https://crates.io/crates/snafu or https://crates.io/crates/thiserror), but so 
far DataFusion just uses string formatting (e.g. 
`format!("Error processing subqueries: {}", inner_error)`).
   
   Other libraries have quite sophisticated `Error` hierarchies, which 
theoretically allow errors to be handled programmatically in different 
ways. However, such a hierarchy adds non-trivial maintenance and contribution 
overhead, and I haven't seen anyone ask for this in DataFusion yet, so I don't 
think it is worth adding at this point.



##########
datafusion/core/tests/sql/subqueries.rs:
##########
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use crate::sql::execute_to_batches;
+use datafusion::assert_batches_eq;
+use datafusion::prelude::SessionContext;
+use log::debug;
+
+#[cfg(test)]
+#[ctor::ctor]
+fn init() {
+    let _ = env_logger::try_init();
+}
+
+#[tokio::test]
+async fn correlated_recursive_scalar_subquery() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "customer").await?;
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+
+    let sql = r#"
+select c_custkey from customer
+where c_acctbal < (
+    select sum(o_totalprice) from orders
+    where o_custkey = c_custkey
+    and o_totalprice < (
+            select sum(l_extendedprice) as price from lineitem where 
l_orderkey = o_orderkey
+    )
+) order by c_custkey;"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    debug!("input:\n{}", plan.display_indent());
+
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
+  Projection: #customer.c_custkey
+    Filter: #customer.c_acctbal < #__sq_2.__value
+      Inner Join: #customer.c_custkey = #__sq_2.o_custkey
+        TableScan: customer projection=[c_custkey, c_acctbal]
+        Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, 
alias=__sq_2
+          Aggregate: groupBy=[[#orders.o_custkey]], 
aggr=[[SUM(#orders.o_totalprice)]]
+            Filter: #orders.o_totalprice < #__sq_1.__value
+              Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey

Review Comment:
   👍 for this plan. Very cool, @avantgardnerio!



##########
datafusion/core/tests/sql/subqueries.rs:
##########
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use crate::sql::execute_to_batches;
+use datafusion::assert_batches_eq;
+use datafusion::prelude::SessionContext;
+use log::debug;
+
+#[cfg(test)]
+#[ctor::ctor]
+fn init() {
+    let _ = env_logger::try_init();
+}
+
+#[tokio::test]
+async fn correlated_recursive_scalar_subquery() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "customer").await?;
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+
+    let sql = r#"
+select c_custkey from customer
+where c_acctbal < (
+    select sum(o_totalprice) from orders
+    where o_custkey = c_custkey
+    and o_totalprice < (
+            select sum(l_extendedprice) as price from lineitem where 
l_orderkey = o_orderkey
+    )
+) order by c_custkey;"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    debug!("input:\n{}", plan.display_indent());
+
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
+  Projection: #customer.c_custkey
+    Filter: #customer.c_acctbal < #__sq_2.__value
+      Inner Join: #customer.c_custkey = #__sq_2.o_custkey
+        TableScan: customer projection=[c_custkey, c_acctbal]
+        Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, 
alias=__sq_2
+          Aggregate: groupBy=[[#orders.o_custkey]], 
aggr=[[SUM(#orders.o_totalprice)]]
+            Filter: #orders.o_totalprice < #__sq_1.__value
+              Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey
+                TableScan: orders projection=[o_orderkey, o_custkey, 
o_totalprice]
+                Projection: #lineitem.l_orderkey, 
#SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1
+                  Aggregate: groupBy=[[#lineitem.l_orderkey]], 
aggr=[[SUM(#lineitem.l_extendedprice)]]
+                    TableScan: lineitem projection=[l_orderkey, 
l_extendedprice]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn correlated_where_in() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+
+    let sql = r#"select o_orderkey from orders
+inner join lineitem on o_orderkey = l_orderkey
+where l_partkey in ( select ps_partkey from partsupp where ps_suppkey = 
l_suppkey );"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Projection: #orders.o_orderkey
+  Semi Join: #lineitem.l_partkey = #__sq_1.ps_partkey, #lineitem.l_suppkey = 
#__sq_1.ps_suppkey
+    Inner Join: #orders.o_orderkey = #lineitem.l_orderkey
+      TableScan: orders projection=[o_orderkey]
+      TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey]
+    Projection: #partsupp.ps_partkey AS ps_partkey, #partsupp.ps_suppkey AS 
ps_suppkey, alias=__sq_1
+      TableScan: partsupp projection=[ps_partkey, ps_suppkey]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q2_correlated() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "part").await?;
+    register_tpch_csv(&ctx, "supplier").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+    register_tpch_csv(&ctx, "nation").await?;
+    register_tpch_csv(&ctx, "region").await?;
+
+    let sql = r#"select s_acctbal, s_name, n_name, p_partkey, p_mfgr, 
s_address, s_phone, s_comment
+from part, supplier, partsupp, nation, region
+where p_partkey = ps_partkey and s_suppkey = ps_suppkey and p_size = 15 and 
p_type like '%BRASS'
+    and s_nationkey = n_nationkey and n_regionkey = r_regionkey and r_name = 
'EUROPE'
+    and ps_supplycost = (
+        select min(ps_supplycost) from partsupp, supplier, nation, region
+        where p_partkey = ps_partkey and s_suppkey = ps_suppkey and 
s_nationkey = n_nationkey
+        and n_regionkey = r_regionkey and r_name = 'EUROPE'
+    )
+order by s_acctbal desc, n_name, s_name, p_partkey;"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #supplier.s_acctbal DESC NULLS FIRST, 
#nation.n_name ASC NULLS LAST, #supplier.s_name ASC NULLS LAST, #part.p_partkey 
ASC NULLS LAST
+  Projection: #supplier.s_acctbal, #supplier.s_name, #nation.n_name, 
#part.p_partkey, #part.p_mfgr, #supplier.s_address, #supplier.s_phone, 
#supplier.s_comment
+    Filter: #partsupp.ps_supplycost = #__sq_1.__value

Review Comment:
   So cool to see this query planned ❤️ 



##########
datafusion/expr/src/expr.rs:
##########
@@ -452,6 +452,13 @@ impl Expr {
             nulls_first,
         }
     }
+
+    pub fn try_into_col(&self) -> Result<Column> {
+        match self {
+            Expr::Column(it) => Ok(it.clone()),
+            _ => plan_err!("Could not coerce into Column!"),

Review Comment:
   It might help to include `self` in this error message so it is clear which 
expression is causing the issue.
   
   For example:
   
   ```suggestion
               _ => plan_err!("Could not coerce '{}' into Column", self),
   ```



##########
datafusion/core/tests/sql/subqueries.rs:
##########
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use crate::sql::execute_to_batches;
+use datafusion::assert_batches_eq;
+use datafusion::prelude::SessionContext;
+use log::debug;
+
+#[cfg(test)]
+#[ctor::ctor]
+fn init() {
+    let _ = env_logger::try_init();
+}
+
+#[tokio::test]
+async fn correlated_recursive_scalar_subquery() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "customer").await?;
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+
+    let sql = r#"
+select c_custkey from customer
+where c_acctbal < (
+    select sum(o_totalprice) from orders
+    where o_custkey = c_custkey
+    and o_totalprice < (
+            select sum(l_extendedprice) as price from lineitem where 
l_orderkey = o_orderkey
+    )
+) order by c_custkey;"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    debug!("input:\n{}", plan.display_indent());
+
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
+  Projection: #customer.c_custkey
+    Filter: #customer.c_acctbal < #__sq_2.__value
+      Inner Join: #customer.c_custkey = #__sq_2.o_custkey
+        TableScan: customer projection=[c_custkey, c_acctbal]
+        Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, 
alias=__sq_2
+          Aggregate: groupBy=[[#orders.o_custkey]], 
aggr=[[SUM(#orders.o_totalprice)]]
+            Filter: #orders.o_totalprice < #__sq_1.__value
+              Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey
+                TableScan: orders projection=[o_orderkey, o_custkey, 
o_totalprice]
+                Projection: #lineitem.l_orderkey, 
#SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1
+                  Aggregate: groupBy=[[#lineitem.l_orderkey]], 
aggr=[[SUM(#lineitem.l_extendedprice)]]
+                    TableScan: lineitem projection=[l_orderkey, 
l_extendedprice]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn correlated_where_in() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+
+    let sql = r#"select o_orderkey from orders
+inner join lineitem on o_orderkey = l_orderkey
+where l_partkey in ( select ps_partkey from partsupp where ps_suppkey = 
l_suppkey );"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Projection: #orders.o_orderkey
+  Semi Join: #lineitem.l_partkey = #__sq_1.ps_partkey, #lineitem.l_suppkey = 
#__sq_1.ps_suppkey
+    Inner Join: #orders.o_orderkey = #lineitem.l_orderkey
+      TableScan: orders projection=[o_orderkey]
+      TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey]
+    Projection: #partsupp.ps_partkey AS ps_partkey, #partsupp.ps_suppkey AS 
ps_suppkey, alias=__sq_1
+      TableScan: partsupp projection=[ps_partkey, ps_suppkey]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q2_correlated() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "part").await?;
+    register_tpch_csv(&ctx, "supplier").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+    register_tpch_csv(&ctx, "nation").await?;
+    register_tpch_csv(&ctx, "region").await?;
+
+    let sql = r#"select s_acctbal, s_name, n_name, p_partkey, p_mfgr, 
s_address, s_phone, s_comment
+from part, supplier, partsupp, nation, region
+where p_partkey = ps_partkey and s_suppkey = ps_suppkey and p_size = 15 and 
p_type like '%BRASS'
+    and s_nationkey = n_nationkey and n_regionkey = r_regionkey and r_name = 
'EUROPE'
+    and ps_supplycost = (
+        select min(ps_supplycost) from partsupp, supplier, nation, region
+        where p_partkey = ps_partkey and s_suppkey = ps_suppkey and 
s_nationkey = n_nationkey
+        and n_regionkey = r_regionkey and r_name = 'EUROPE'
+    )
+order by s_acctbal desc, n_name, s_name, p_partkey;"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #supplier.s_acctbal DESC NULLS FIRST, 
#nation.n_name ASC NULLS LAST, #supplier.s_name ASC NULLS LAST, #part.p_partkey 
ASC NULLS LAST
+  Projection: #supplier.s_acctbal, #supplier.s_name, #nation.n_name, 
#part.p_partkey, #part.p_mfgr, #supplier.s_address, #supplier.s_phone, 
#supplier.s_comment
+    Filter: #partsupp.ps_supplycost = #__sq_1.__value
+      Inner Join: #part.p_partkey = #__sq_1.ps_partkey
+        Inner Join: #nation.n_regionkey = #region.r_regionkey
+          Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+            Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+              Inner Join: #part.p_partkey = #partsupp.ps_partkey
+                Filter: #part.p_size = Int64(15) AND #part.p_type LIKE 
Utf8("%BRASS")
+                  TableScan: part projection=[p_partkey, p_mfgr, p_type, 
p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE 
Utf8("%BRASS")]
+                TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
+              TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey, s_phone, s_acctbal, s_comment]
+            TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+          Filter: #region.r_name = Utf8("EUROPE")
+            TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[#region.r_name = Utf8("EUROPE")]
+        Projection: #partsupp.ps_partkey, #MIN(partsupp.ps_supplycost) AS 
__value, alias=__sq_1
+          Aggregate: groupBy=[[#partsupp.ps_partkey]], 
aggr=[[MIN(#partsupp.ps_supplycost)]]
+            Inner Join: #nation.n_regionkey = #region.r_regionkey
+              Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+                Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+                  TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
+                  TableScan: supplier projection=[s_suppkey, s_name, 
s_address, s_nationkey, s_phone, s_acctbal, s_comment]
+                TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+              Filter: #region.r_name = Utf8("EUROPE")
+                TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[#region.r_name = Utf8("EUROPE")]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q4_correlated() -> Result<()> {
+    let orders = r#"4,13678,O,53829.87,1995-10-11,5-LOW,Clerk#000000124,0,
+35,12760,O,192885.43,1995-10-23,4-NOT SPECIFIED,Clerk#000000259,0,
+65,1627,P,99763.79,1995-03-18,1-URGENT,Clerk#000000632,0,
+"#;
+    let lineitems = 
r#"4,8804,579,1,30,51384,0.03,0.08,N,O,1996-01-10,1995-12-14,1996-01-18,DELIVER 
IN PERSON,REG AIR,
+35,45,296,1,24,22680.96,0.02,0,N,O,1996-02-21,1996-01-03,1996-03-18,TAKE BACK 
RETURN,FOB,
+65,5970,481,1,26,48775.22,0.03,0.03,A,F,1995-04-20,1995-04-25,1995-05-13,NONE,TRUCK,
+"#;
+
+    let ctx = SessionContext::new();
+    register_tpch_csv_data(&ctx, "orders", orders).await?;
+    register_tpch_csv_data(&ctx, "lineitem", lineitems).await?;
+
+    let sql = r#"
+        select o_orderpriority, count(*) as order_count
+        from orders
+        where exists (
+            select * from lineitem where l_orderkey = o_orderkey and 
l_commitdate < l_receiptdate)
+        group by o_orderpriority
+        order by o_orderpriority;
+        "#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #orders.o_orderpriority ASC NULLS LAST
+  Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count
+    Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]]
+      Semi Join: #orders.o_orderkey = #lineitem.l_orderkey
+        TableScan: orders projection=[o_orderkey, o_orderpriority]
+        Filter: #lineitem.l_commitdate < #lineitem.l_receiptdate
+          TableScan: lineitem projection=[l_orderkey, l_commitdate, 
l_receiptdate]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-----------------+-------------+",
+        "| o_orderpriority | order_count |",
+        "+-----------------+-------------+",
+        "| 1-URGENT        | 1           |",
+        "| 4-NOT SPECIFIED | 1           |",
+        "| 5-LOW           | 1           |",
+        "+-----------------+-------------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q17_correlated() -> Result<()> {
+    let parts = r#"63700,goldenrod lavender spring chocolate 
lace,Manufacturer#1,Brand#23,PROMO BURNISHED COPPER,7,MED BOX,901.00,ly. slyly 
ironi
+"#;
+    let lineitems = 
r#"1,63700,7311,2,36.0,45983.16,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE
 BACK RETURN,MAIL,ly final dependencies: slyly bold
+1,63700,3701,3,1.0,13309.6,0.1,0.02,N,O,1996-01-29,1996-03-05,1996-01-31,TAKE 
BACK RETURN,REG AIR,"riously. regular, express dep"
+"#;
+
+    let ctx = SessionContext::new();
+    register_tpch_csv_data(&ctx, "part", parts).await?;
+    register_tpch_csv_data(&ctx, "lineitem", lineitems).await?;
+
+    let sql = r#"select sum(l_extendedprice) / 7.0 as avg_yearly
+        from lineitem, part
+        where p_partkey = l_partkey and p_brand = 'Brand#23' and p_container = 
'MED BOX'
+        and l_quantity < (
+            select 0.2 * avg(l_quantity)
+            from lineitem where l_partkey = p_partkey
+        );"#;
+
+    // assert plan
+    let plan = ctx
+        .create_logical_plan(sql)
+        .map_err(|e| format!("{:?} at {}", e, "error"))
+        .unwrap();
+    println!("before:\n{}", plan.display_indent());
+    let plan = ctx
+        .optimize(&plan)
+        .map_err(|e| format!("{:?} at {}", e, "error"))
+        .unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Projection: #SUM(lineitem.l_extendedprice) / Float64(7) 
AS avg_yearly
+  Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]]
+    Filter: #lineitem.l_quantity < #__sq_1.__value
+      Inner Join: #part.p_partkey = #__sq_1.l_partkey
+        Inner Join: #lineitem.l_partkey = #part.p_partkey
+          TableScan: lineitem projection=[l_partkey, l_quantity, 
l_extendedprice]
+          Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = 
Utf8("MED BOX")
+            TableScan: part projection=[p_partkey, p_brand, p_container]
+        Projection: #lineitem.l_partkey, Float64(0.2) * 
#AVG(lineitem.l_quantity) AS __value, alias=__sq_1
+          Aggregate: groupBy=[[#lineitem.l_partkey]], 
aggr=[[AVG(#lineitem.l_quantity)]]
+            TableScan: lineitem projection=[l_partkey, l_quantity, 
l_extendedprice]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+--------------------+",
+        "| avg_yearly         |",
+        "+--------------------+",
+        "| 1901.3714285714286 |",
+        "+--------------------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q20_correlated() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "supplier").await?;
+    register_tpch_csv(&ctx, "nation").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+    register_tpch_csv(&ctx, "part").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+
+    let sql = r#"select s_name, s_address
+from supplier, nation
+where s_suppkey in (
+    select ps_suppkey from partsupp
+    where ps_partkey in ( select p_partkey from part where p_name like 
'forest%' )
+      and ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem
+        where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate 
>= date '1994-01-01'
+    )
+)
+and s_nationkey = n_nationkey and n_name = 'CANADA'
+order by s_name;
+"#;
+
+    // assert plan
+    let plan = ctx
+        .create_logical_plan(sql)
+        .map_err(|e| format!("{:?} at {}", e, "error"))
+        .unwrap();
+    let plan = ctx
+        .optimize(&plan)
+        .map_err(|e| format!("{:?} at {}", e, "error"))
+        .unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #supplier.s_name ASC NULLS LAST
+  Projection: #supplier.s_name, #supplier.s_address
+    Semi Join: #supplier.s_suppkey = #__sq_2.ps_suppkey
+      Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+        TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey]
+        Filter: #nation.n_name = Utf8("CANADA")
+          TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[#nation.n_name = Utf8("CANADA")]
+      Projection: #partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
+        Filter: #partsupp.ps_availqty > #__sq_3.__value
+          Inner Join: #partsupp.ps_partkey = #__sq_3.l_partkey, 
#partsupp.ps_suppkey = #__sq_3.l_suppkey
+            Semi Join: #partsupp.ps_partkey = #__sq_1.p_partkey
+              TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_availqty]
+              Projection: #part.p_partkey AS p_partkey, alias=__sq_1
+                Filter: #part.p_name LIKE Utf8("forest%")
+                  TableScan: part projection=[p_partkey, p_name], 
partial_filters=[#part.p_name LIKE Utf8("forest%")]
+            Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Float64(0.5) 
* #SUM(lineitem.l_quantity) AS __value, alias=__sq_3
+              Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], 
aggr=[[SUM(#lineitem.l_quantity)]]
+                Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS 
Date32)
+                  TableScan: lineitem projection=[l_partkey, l_suppkey, 
l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= 
CAST(Utf8("1994-01-01") AS Date32)]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q22_correlated() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "customer").await?;
+    register_tpch_csv(&ctx, "orders").await?;
+
+    let sql = r#"select cntrycode, count(*) as numcust, sum(c_acctbal) as 
totacctbal
+from (
+        select substring(c_phone from 1 for 2) as cntrycode, c_acctbal from 
customer
+        where substring(c_phone from 1 for 2) in ('13', '31', '23', '29', 
'30', '18', '17')
+          and c_acctbal > (
+            select avg(c_acctbal) from customer where c_acctbal > 0.00
+              and substring(c_phone from 1 for 2) in ('13', '31', '23', '29', 
'30', '18', '17')
+        )
+          and not exists ( select * from orders where o_custkey = c_custkey )
+    ) as custsale
+group by cntrycode
+order by cntrycode;"#;
+
+    // assert plan
+    let plan = ctx
+        .create_logical_plan(sql)
+        .map_err(|e| format!("{:?} at {}", e, "error"))
+        .unwrap();
+    let plan = ctx
+        .optimize(&plan)
+        .map_err(|e| format!("{:?} at {}", e, "error"))
+        .unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #custsale.cntrycode ASC NULLS LAST
+  Projection: #custsale.cntrycode, #COUNT(UInt8(1)) AS numcust, 
#SUM(custsale.c_acctbal) AS totacctbal
+    Aggregate: groupBy=[[#custsale.cntrycode]], aggr=[[COUNT(UInt8(1)), 
SUM(#custsale.c_acctbal)]]
+      Projection: #custsale.cntrycode, #custsale.c_acctbal, alias=custsale
+        Projection: substr(#customer.c_phone, Int64(1), Int64(2)) AS 
cntrycode, #customer.c_acctbal, alias=custsale
+          Filter: #customer.c_acctbal > #__sq_1.__value
+            CrossJoin:
+              Anti Join: #customer.c_custkey = #orders.o_custkey

Review Comment:
   looks good



##########
datafusion/core/tests/sql/subqueries.rs:
##########
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use crate::sql::execute_to_batches;
+use datafusion::assert_batches_eq;
+use datafusion::prelude::SessionContext;
+use log::debug;
+
+#[cfg(test)]
+#[ctor::ctor]
+fn init() {
+    let _ = env_logger::try_init();
+}
+
+/// Checks that a correlated scalar subquery nested inside another correlated
+/// scalar subquery is decorrelated into two aliased inner joins (`__sq_1`,
+/// `__sq_2`). Only the optimized logical plan is asserted; no data is checked.
+#[tokio::test]
+async fn correlated_recursive_scalar_subquery() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "customer").await?;
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+
+    let sql = r#"
+select c_custkey from customer
+where c_acctbal < (
+    select sum(o_totalprice) from orders
+    where o_custkey = c_custkey
+    and o_totalprice < (
+            select sum(l_extendedprice) as price from lineitem where 
l_orderkey = o_orderkey
+    )
+) order by c_custkey;"#;
+
+    // assert plan (planning/optimisation errors propagate as the test's `Err`)
+    let plan = ctx.create_logical_plan(sql)?;
+    debug!("input:\n{}", plan.display_indent());
+
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    // NOTE(review): the innermost projection is expected to render as
+    // `... AS price AS __value` (a double alias) — confirm this is intended.
+    let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
+  Projection: #customer.c_custkey
+    Filter: #customer.c_acctbal < #__sq_2.__value
+      Inner Join: #customer.c_custkey = #__sq_2.o_custkey
+        TableScan: customer projection=[c_custkey, c_acctbal]
+        Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, 
alias=__sq_2
+          Aggregate: groupBy=[[#orders.o_custkey]], 
aggr=[[SUM(#orders.o_totalprice)]]
+            Filter: #orders.o_totalprice < #__sq_1.__value
+              Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey
+                TableScan: orders projection=[o_orderkey, o_custkey, 
o_totalprice]
+                Projection: #lineitem.l_orderkey, 
#SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1
+                  Aggregate: groupBy=[[#lineitem.l_orderkey]], 
aggr=[[SUM(#lineitem.l_extendedprice)]]
+                    TableScan: lineitem projection=[l_orderkey, 
l_extendedprice]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    Ok(())
+}
+
+/// Checks that a correlated `IN (subquery)` is rewritten into a semi join on
+/// both the `IN` column and the correlated column, then executes the query.
+#[tokio::test]
+async fn correlated_where_in() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+
+    let sql = r#"select o_orderkey from orders
+inner join lineitem on o_orderkey = l_orderkey
+where l_partkey in ( select ps_partkey from partsupp where ps_suppkey = 
l_suppkey );"#;
+
+    // assert plan (planning/optimisation errors propagate as the test's `Err`)
+    let plan = ctx.create_logical_plan(sql)?;
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Projection: #orders.o_orderkey
+  Semi Join: #lineitem.l_partkey = #__sq_1.ps_partkey, #lineitem.l_suppkey = 
#__sq_1.ps_suppkey
+    Inner Join: #orders.o_orderkey = #lineitem.l_orderkey
+      TableScan: orders projection=[o_orderkey]
+      TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey]
+    Projection: #partsupp.ps_partkey AS ps_partkey, #partsupp.ps_suppkey AS 
ps_suppkey, alias=__sq_1
+      TableScan: partsupp projection=[ps_partkey, ps_suppkey]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data: `["++", "++"]` is an empty rendered result, i.e. no rows are
+    // expected from the registered sample data (presumably; TODO confirm)
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+/// TPC-H Q2: the correlated `min(ps_supplycost)` scalar subquery is expected
+/// to become an aggregate grouped by `ps_partkey`, joined back as `__sq_1`.
+/// The query is also executed against the registered TPC-H CSV tables.
+#[tokio::test]
+async fn tpch_q2_correlated() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "part").await?;
+    register_tpch_csv(&ctx, "supplier").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+    register_tpch_csv(&ctx, "nation").await?;
+    register_tpch_csv(&ctx, "region").await?;
+
+    let sql = r#"select s_acctbal, s_name, n_name, p_partkey, p_mfgr, 
s_address, s_phone, s_comment
+from part, supplier, partsupp, nation, region
+where p_partkey = ps_partkey and s_suppkey = ps_suppkey and p_size = 15 and 
p_type like '%BRASS'
+    and s_nationkey = n_nationkey and n_regionkey = r_regionkey and r_name = 
'EUROPE'
+    and ps_supplycost = (
+        select min(ps_supplycost) from partsupp, supplier, nation, region
+        where p_partkey = ps_partkey and s_suppkey = ps_suppkey and 
s_nationkey = n_nationkey
+        and n_regionkey = r_regionkey and r_name = 'EUROPE'
+    )
+order by s_acctbal desc, n_name, s_name, p_partkey;"#;
+
+    // assert plan (planning/optimisation errors propagate as the test's `Err`)
+    let plan = ctx.create_logical_plan(sql)?;
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #supplier.s_acctbal DESC NULLS FIRST, 
#nation.n_name ASC NULLS LAST, #supplier.s_name ASC NULLS LAST, #part.p_partkey 
ASC NULLS LAST
+  Projection: #supplier.s_acctbal, #supplier.s_name, #nation.n_name, 
#part.p_partkey, #part.p_mfgr, #supplier.s_address, #supplier.s_phone, 
#supplier.s_comment
+    Filter: #partsupp.ps_supplycost = #__sq_1.__value
+      Inner Join: #part.p_partkey = #__sq_1.ps_partkey
+        Inner Join: #nation.n_regionkey = #region.r_regionkey
+          Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+            Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+              Inner Join: #part.p_partkey = #partsupp.ps_partkey
+                Filter: #part.p_size = Int64(15) AND #part.p_type LIKE 
Utf8("%BRASS")
+                  TableScan: part projection=[p_partkey, p_mfgr, p_type, 
p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE 
Utf8("%BRASS")]
+                TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
+              TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey, s_phone, s_acctbal, s_comment]
+            TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+          Filter: #region.r_name = Utf8("EUROPE")
+            TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[#region.r_name = Utf8("EUROPE")]
+        Projection: #partsupp.ps_partkey, #MIN(partsupp.ps_supplycost) AS 
__value, alias=__sq_1
+          Aggregate: groupBy=[[#partsupp.ps_partkey]], 
aggr=[[MIN(#partsupp.ps_supplycost)]]
+            Inner Join: #nation.n_regionkey = #region.r_regionkey
+              Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+                Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+                  TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
+                  TableScan: supplier projection=[s_suppkey, s_name, 
s_address, s_nationkey, s_phone, s_acctbal, s_comment]
+                TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+              Filter: #region.r_name = Utf8("EUROPE")
+                TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[#region.r_name = Utf8("EUROPE")]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data: `["++", "++"]` is an empty rendered result (no rows expected)
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+/// TPC-H Q4 over three inline `orders`/`lineitem` rows: the correlated
+/// `EXISTS` is expected to become a semi join on `o_orderkey = l_orderkey`,
+/// and the per-priority counts are checked against the inline data.
+#[tokio::test]
+async fn tpch_q4_correlated() -> Result<()> {
+    let orders = r#"4,13678,O,53829.87,1995-10-11,5-LOW,Clerk#000000124,0,
+35,12760,O,192885.43,1995-10-23,4-NOT SPECIFIED,Clerk#000000259,0,
+65,1627,P,99763.79,1995-03-18,1-URGENT,Clerk#000000632,0,
+"#;
+    let lineitems = 
r#"4,8804,579,1,30,51384,0.03,0.08,N,O,1996-01-10,1995-12-14,1996-01-18,DELIVER 
IN PERSON,REG AIR,
+35,45,296,1,24,22680.96,0.02,0,N,O,1996-02-21,1996-01-03,1996-03-18,TAKE BACK 
RETURN,FOB,
+65,5970,481,1,26,48775.22,0.03,0.03,A,F,1995-04-20,1995-04-25,1995-05-13,NONE,TRUCK,
+"#;
+
+    let ctx = SessionContext::new();
+    register_tpch_csv_data(&ctx, "orders", orders).await?;
+    register_tpch_csv_data(&ctx, "lineitem", lineitems).await?;
+
+    let sql = r#"
+        select o_orderpriority, count(*) as order_count
+        from orders
+        where exists (
+            select * from lineitem where l_orderkey = o_orderkey and 
l_commitdate < l_receiptdate)
+        group by o_orderpriority
+        order by o_orderpriority;
+        "#;
+
+    // assert plan (planning/optimisation errors propagate as the test's `Err`)
+    let plan = ctx.create_logical_plan(sql)?;
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #orders.o_orderpriority ASC NULLS LAST
+  Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count
+    Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]]
+      Semi Join: #orders.o_orderkey = #lineitem.l_orderkey
+        TableScan: orders projection=[o_orderkey, o_orderpriority]
+        Filter: #lineitem.l_commitdate < #lineitem.l_receiptdate
+          TableScan: lineitem projection=[l_orderkey, l_commitdate, 
l_receiptdate]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-----------------+-------------+",
+        "| o_orderpriority | order_count |",
+        "+-----------------+-------------+",
+        "| 1-URGENT        | 1           |",
+        "| 4-NOT SPECIFIED | 1           |",
+        "| 5-LOW           | 1           |",
+        "+-----------------+-------------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+/// TPC-H Q17 over a small inline data set: the correlated
+/// `0.2 * avg(l_quantity)` subquery is expected to become an aggregate grouped
+/// by `l_partkey`, joined back as `__sq_1`; the final average is then checked.
+#[tokio::test]
+async fn tpch_q17_correlated() -> Result<()> {
+    let parts = r#"63700,goldenrod lavender spring chocolate 
lace,Manufacturer#1,Brand#23,PROMO BURNISHED COPPER,7,MED BOX,901.00,ly. slyly 
ironi
+"#;
+    let lineitems = 
r#"1,63700,7311,2,36.0,45983.16,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE
 BACK RETURN,MAIL,ly final dependencies: slyly bold
+1,63700,3701,3,1.0,13309.6,0.1,0.02,N,O,1996-01-29,1996-03-05,1996-01-31,TAKE 
BACK RETURN,REG AIR,"riously. regular, express dep"
+"#;
+
+    let ctx = SessionContext::new();
+    register_tpch_csv_data(&ctx, "part", parts).await?;
+    register_tpch_csv_data(&ctx, "lineitem", lineitems).await?;
+
+    let sql = r#"select sum(l_extendedprice) / 7.0 as avg_yearly
+        from lineitem, part
+        where p_partkey = l_partkey and p_brand = 'Brand#23' and p_container = 
'MED BOX'
+        and l_quantity < (
+            select 0.2 * avg(l_quantity)
+            from lineitem where l_partkey = p_partkey
+        );"#;
+
+    // assert plan. Errors propagate with `?` instead of being stringified to a
+    // meaningless "... at error" message and unwrapped; the unoptimised plan is
+    // logged at debug level rather than printed unconditionally.
+    let plan = ctx.create_logical_plan(sql)?;
+    debug!("before:\n{}", plan.display_indent());
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Projection: #SUM(lineitem.l_extendedprice) / Float64(7) 
AS avg_yearly
+  Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]]
+    Filter: #lineitem.l_quantity < #__sq_1.__value
+      Inner Join: #part.p_partkey = #__sq_1.l_partkey
+        Inner Join: #lineitem.l_partkey = #part.p_partkey
+          TableScan: lineitem projection=[l_partkey, l_quantity, 
l_extendedprice]
+          Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = 
Utf8("MED BOX")
+            TableScan: part projection=[p_partkey, p_brand, p_container]
+        Projection: #lineitem.l_partkey, Float64(0.2) * 
#AVG(lineitem.l_quantity) AS __value, alias=__sq_1
+          Aggregate: groupBy=[[#lineitem.l_partkey]], 
aggr=[[AVG(#lineitem.l_quantity)]]
+            TableScan: lineitem projection=[l_partkey, l_quantity, 
l_extendedprice]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+--------------------+",
+        "| avg_yearly         |",
+        "+--------------------+",
+        "| 1901.3714285714286 |",
+        "+--------------------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+/// TPC-H Q20: the nested `IN` subqueries are expected to become two semi joins
+/// (`__sq_2`, `__sq_1`), and the correlated `0.5 * sum(l_quantity)` subquery an
+/// inner join on `(l_partkey, l_suppkey)` against aggregate `__sq_3`.
+#[tokio::test]
+async fn tpch_q20_correlated() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "supplier").await?;
+    register_tpch_csv(&ctx, "nation").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+    register_tpch_csv(&ctx, "part").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+
+    let sql = r#"select s_name, s_address
+from supplier, nation
+where s_suppkey in (
+    select ps_suppkey from partsupp
+    where ps_partkey in ( select p_partkey from part where p_name like 
'forest%' )
+      and ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem
+        where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate 
>= date '1994-01-01'
+    )
+)
+and s_nationkey = n_nationkey and n_name = 'CANADA'
+order by s_name;
+"#;
+
+    // assert plan (errors propagate with `?` rather than a stringified unwrap)
+    let plan = ctx.create_logical_plan(sql)?;
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #supplier.s_name ASC NULLS LAST
+  Projection: #supplier.s_name, #supplier.s_address
+    Semi Join: #supplier.s_suppkey = #__sq_2.ps_suppkey
+      Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+        TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey]
+        Filter: #nation.n_name = Utf8("CANADA")
+          TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[#nation.n_name = Utf8("CANADA")]
+      Projection: #partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
+        Filter: #partsupp.ps_availqty > #__sq_3.__value
+          Inner Join: #partsupp.ps_partkey = #__sq_3.l_partkey, 
#partsupp.ps_suppkey = #__sq_3.l_suppkey
+            Semi Join: #partsupp.ps_partkey = #__sq_1.p_partkey
+              TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_availqty]
+              Projection: #part.p_partkey AS p_partkey, alias=__sq_1
+                Filter: #part.p_name LIKE Utf8("forest%")
+                  TableScan: part projection=[p_partkey, p_name], 
partial_filters=[#part.p_name LIKE Utf8("forest%")]
+            Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Float64(0.5) 
* #SUM(lineitem.l_quantity) AS __value, alias=__sq_3
+              Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], 
aggr=[[SUM(#lineitem.l_quantity)]]
+                Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS 
Date32)
+                  TableScan: lineitem projection=[l_partkey, l_suppkey, 
l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= 
CAST(Utf8("1994-01-01") AS Date32)]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data: `["++", "++"]` is an empty rendered result (no rows expected)
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+/// TPC-H Q22: the uncorrelated `avg(c_acctbal)` subquery is expected to become
+/// a cross join with an ungrouped aggregate (`__sq_1`), and `NOT EXISTS` an
+/// anti join on `c_custkey = o_custkey`; the grouped result is also checked.
+#[tokio::test]
+async fn tpch_q22_correlated() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "customer").await?;
+    register_tpch_csv(&ctx, "orders").await?;
+
+    let sql = r#"select cntrycode, count(*) as numcust, sum(c_acctbal) as 
totacctbal
+from (
+        select substring(c_phone from 1 for 2) as cntrycode, c_acctbal from 
customer
+        where substring(c_phone from 1 for 2) in ('13', '31', '23', '29', 
'30', '18', '17')
+          and c_acctbal > (
+            select avg(c_acctbal) from customer where c_acctbal > 0.00
+              and substring(c_phone from 1 for 2) in ('13', '31', '23', '29', 
'30', '18', '17')
+        )
+          and not exists ( select * from orders where o_custkey = c_custkey )
+    ) as custsale
+group by cntrycode
+order by cntrycode;"#;
+
+    // assert plan (errors propagate with `?` rather than a stringified unwrap)
+    let plan = ctx.create_logical_plan(sql)?;
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    // NOTE(review): the expected plan keeps two stacked projections that both
+    // carry `alias=custsale`; confirm the redundant projection is intended.
+    let expected = r#"Sort: #custsale.cntrycode ASC NULLS LAST
+  Projection: #custsale.cntrycode, #COUNT(UInt8(1)) AS numcust, 
#SUM(custsale.c_acctbal) AS totacctbal
+    Aggregate: groupBy=[[#custsale.cntrycode]], aggr=[[COUNT(UInt8(1)), 
SUM(#custsale.c_acctbal)]]
+      Projection: #custsale.cntrycode, #custsale.c_acctbal, alias=custsale
+        Projection: substr(#customer.c_phone, Int64(1), Int64(2)) AS 
cntrycode, #customer.c_acctbal, alias=custsale
+          Filter: #customer.c_acctbal > #__sq_1.__value
+            CrossJoin:
+              Anti Join: #customer.c_custkey = #orders.o_custkey
+                Filter: substr(#customer.c_phone, Int64(1), Int64(2)) IN 
([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), 
Utf8("17")])
+                  TableScan: customer projection=[c_custkey, c_phone, 
c_acctbal], partial_filters=[substr(#customer.c_phone, Int64(1), Int64(2)) IN 
([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), 
Utf8("17")])]
+                TableScan: orders projection=[o_custkey]
+              Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1
+                Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]]
+                  Filter: #customer.c_acctbal > Float64(0) AND 
substr(#customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), 
Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
+                    TableScan: customer projection=[c_phone, c_acctbal], 
partial_filters=[#customer.c_acctbal > Float64(0), substr(#customer.c_phone, 
Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), 
Utf8("30"), Utf8("18"), Utf8("17")])]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-----------+---------+------------+",
+        "| cntrycode | numcust | totacctbal |",
+        "+-----------+---------+------------+",
+        "| 18        | 1       | 8324.07    |",
+        "| 30        | 1       | 7638.57    |",
+        "+-----------+---------+------------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q11_correlated() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "partsupp").await?;
+    register_tpch_csv(&ctx, "supplier").await?;
+    register_tpch_csv(&ctx, "nation").await?;
+
+    let sql = r#"select ps_partkey, sum(ps_supplycost * ps_availqty) as value
+from partsupp, supplier, nation
+where ps_suppkey = s_suppkey and s_nationkey = n_nationkey and n_name = 
'GERMANY'
+group by ps_partkey having
+    sum(ps_supplycost * ps_availqty) > (
+        select sum(ps_supplycost * ps_availqty) * 0.0001
+        from partsupp, supplier, nation
+        where ps_suppkey = s_suppkey and s_nationkey = n_nationkey and n_name 
= 'GERMANY'
+    )
+order by value desc;
+"#;
+
+    // assert plan
+    let plan = ctx
+        .create_logical_plan(sql)
+        .map_err(|e| format!("{:?} at {}", e, "error"))
+        .unwrap();
+    println!("before:\n{}", plan.display_indent());
+    let plan = ctx
+        .optimize(&plan)
+        .map_err(|e| format!("{:?} at {}", e, "error"))
+        .unwrap();
+    let actual = format!("{}", plan.display_indent());
+    println!("after:\n{}", actual);
+    let expected = r#"Sort: #value DESC NULLS FIRST
+  Projection: #partsupp.ps_partkey, #SUM(partsupp.ps_supplycost * 
partsupp.ps_availqty) AS value

Review Comment:
   👍 



##########
datafusion/core/tests/sql/subqueries.rs:
##########
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use crate::sql::execute_to_batches;
+use datafusion::assert_batches_eq;
+use datafusion::prelude::SessionContext;
+use log::debug;
+
+#[cfg(test)]
+#[ctor::ctor]
+fn init() {
+    // Executed once at test-binary load time through `ctor`, before any test.
+    // `env_logger::try_init` reports an error when a global logger is already
+    // installed; that is harmless here, so the outcome is simply discarded.
+    let _logger_init_outcome = env_logger::try_init();
+}
+
+/// Checks that a correlated scalar subquery nested inside another correlated
+/// scalar subquery is decorrelated into two aliased inner joins (`__sq_1`,
+/// `__sq_2`). Only the optimized logical plan is asserted; no data is checked.
+#[tokio::test]
+async fn correlated_recursive_scalar_subquery() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "customer").await?;
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+
+    let sql = r#"
+select c_custkey from customer
+where c_acctbal < (
+    select sum(o_totalprice) from orders
+    where o_custkey = c_custkey
+    and o_totalprice < (
+            select sum(l_extendedprice) as price from lineitem where 
l_orderkey = o_orderkey
+    )
+) order by c_custkey;"#;
+
+    // assert plan (planning/optimisation errors propagate as the test's `Err`)
+    let plan = ctx.create_logical_plan(sql)?;
+    debug!("input:\n{}", plan.display_indent());
+
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    // NOTE(review): the innermost projection is expected to render as
+    // `... AS price AS __value` (a double alias) — confirm this is intended.
+    let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
+  Projection: #customer.c_custkey
+    Filter: #customer.c_acctbal < #__sq_2.__value
+      Inner Join: #customer.c_custkey = #__sq_2.o_custkey
+        TableScan: customer projection=[c_custkey, c_acctbal]
+        Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, 
alias=__sq_2
+          Aggregate: groupBy=[[#orders.o_custkey]], 
aggr=[[SUM(#orders.o_totalprice)]]
+            Filter: #orders.o_totalprice < #__sq_1.__value
+              Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey
+                TableScan: orders projection=[o_orderkey, o_custkey, 
o_totalprice]
+                Projection: #lineitem.l_orderkey, 
#SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1
+                  Aggregate: groupBy=[[#lineitem.l_orderkey]], 
aggr=[[SUM(#lineitem.l_extendedprice)]]
+                    TableScan: lineitem projection=[l_orderkey, 
l_extendedprice]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    Ok(())
+}
+
+/// Checks that a correlated `IN (subquery)` is rewritten into a semi join on
+/// both the `IN` column and the correlated column, then executes the query.
+#[tokio::test]
+async fn correlated_where_in() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+
+    let sql = r#"select o_orderkey from orders
+inner join lineitem on o_orderkey = l_orderkey
+where l_partkey in ( select ps_partkey from partsupp where ps_suppkey = 
l_suppkey );"#;
+
+    // assert plan (planning/optimisation errors propagate as the test's `Err`)
+    let plan = ctx.create_logical_plan(sql)?;
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Projection: #orders.o_orderkey
+  Semi Join: #lineitem.l_partkey = #__sq_1.ps_partkey, #lineitem.l_suppkey = 
#__sq_1.ps_suppkey
+    Inner Join: #orders.o_orderkey = #lineitem.l_orderkey
+      TableScan: orders projection=[o_orderkey]
+      TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey]
+    Projection: #partsupp.ps_partkey AS ps_partkey, #partsupp.ps_suppkey AS 
ps_suppkey, alias=__sq_1
+      TableScan: partsupp projection=[ps_partkey, ps_suppkey]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data: `["++", "++"]` is an empty rendered result, i.e. no rows are
+    // expected from the registered sample data (presumably; TODO confirm)
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+/// TPC-H Q2: the correlated `min(ps_supplycost)` scalar subquery is expected
+/// to become an aggregate grouped by `ps_partkey`, joined back as `__sq_1`.
+/// The query is also executed against the registered TPC-H CSV tables.
+#[tokio::test]
+async fn tpch_q2_correlated() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "part").await?;
+    register_tpch_csv(&ctx, "supplier").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+    register_tpch_csv(&ctx, "nation").await?;
+    register_tpch_csv(&ctx, "region").await?;
+
+    let sql = r#"select s_acctbal, s_name, n_name, p_partkey, p_mfgr, 
s_address, s_phone, s_comment
+from part, supplier, partsupp, nation, region
+where p_partkey = ps_partkey and s_suppkey = ps_suppkey and p_size = 15 and 
p_type like '%BRASS'
+    and s_nationkey = n_nationkey and n_regionkey = r_regionkey and r_name = 
'EUROPE'
+    and ps_supplycost = (
+        select min(ps_supplycost) from partsupp, supplier, nation, region
+        where p_partkey = ps_partkey and s_suppkey = ps_suppkey and 
s_nationkey = n_nationkey
+        and n_regionkey = r_regionkey and r_name = 'EUROPE'
+    )
+order by s_acctbal desc, n_name, s_name, p_partkey;"#;
+
+    // assert plan (planning/optimisation errors propagate as the test's `Err`)
+    let plan = ctx.create_logical_plan(sql)?;
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #supplier.s_acctbal DESC NULLS FIRST, 
#nation.n_name ASC NULLS LAST, #supplier.s_name ASC NULLS LAST, #part.p_partkey 
ASC NULLS LAST
+  Projection: #supplier.s_acctbal, #supplier.s_name, #nation.n_name, 
#part.p_partkey, #part.p_mfgr, #supplier.s_address, #supplier.s_phone, 
#supplier.s_comment
+    Filter: #partsupp.ps_supplycost = #__sq_1.__value
+      Inner Join: #part.p_partkey = #__sq_1.ps_partkey
+        Inner Join: #nation.n_regionkey = #region.r_regionkey
+          Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+            Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+              Inner Join: #part.p_partkey = #partsupp.ps_partkey
+                Filter: #part.p_size = Int64(15) AND #part.p_type LIKE 
Utf8("%BRASS")
+                  TableScan: part projection=[p_partkey, p_mfgr, p_type, 
p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE 
Utf8("%BRASS")]
+                TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
+              TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey, s_phone, s_acctbal, s_comment]
+            TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+          Filter: #region.r_name = Utf8("EUROPE")
+            TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[#region.r_name = Utf8("EUROPE")]
+        Projection: #partsupp.ps_partkey, #MIN(partsupp.ps_supplycost) AS 
__value, alias=__sq_1
+          Aggregate: groupBy=[[#partsupp.ps_partkey]], 
aggr=[[MIN(#partsupp.ps_supplycost)]]
+            Inner Join: #nation.n_regionkey = #region.r_regionkey
+              Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+                Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+                  TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
+                  TableScan: supplier projection=[s_suppkey, s_name, 
s_address, s_nationkey, s_phone, s_acctbal, s_comment]
+                TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+              Filter: #region.r_name = Utf8("EUROPE")
+                TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[#region.r_name = Utf8("EUROPE")]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data: `["++", "++"]` is an empty rendered result (no rows expected)
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+/// TPC-H Q4 over three inline `orders`/`lineitem` rows: the correlated
+/// `EXISTS` is expected to become a semi join on `o_orderkey = l_orderkey`,
+/// and the per-priority counts are checked against the inline data.
+#[tokio::test]
+async fn tpch_q4_correlated() -> Result<()> {
+    let orders = r#"4,13678,O,53829.87,1995-10-11,5-LOW,Clerk#000000124,0,
+35,12760,O,192885.43,1995-10-23,4-NOT SPECIFIED,Clerk#000000259,0,
+65,1627,P,99763.79,1995-03-18,1-URGENT,Clerk#000000632,0,
+"#;
+    let lineitems = 
r#"4,8804,579,1,30,51384,0.03,0.08,N,O,1996-01-10,1995-12-14,1996-01-18,DELIVER 
IN PERSON,REG AIR,
+35,45,296,1,24,22680.96,0.02,0,N,O,1996-02-21,1996-01-03,1996-03-18,TAKE BACK 
RETURN,FOB,
+65,5970,481,1,26,48775.22,0.03,0.03,A,F,1995-04-20,1995-04-25,1995-05-13,NONE,TRUCK,
+"#;
+
+    let ctx = SessionContext::new();
+    register_tpch_csv_data(&ctx, "orders", orders).await?;
+    register_tpch_csv_data(&ctx, "lineitem", lineitems).await?;
+
+    let sql = r#"
+        select o_orderpriority, count(*) as order_count
+        from orders
+        where exists (
+            select * from lineitem where l_orderkey = o_orderkey and 
l_commitdate < l_receiptdate)
+        group by o_orderpriority
+        order by o_orderpriority;
+        "#;
+
+    // assert plan (planning/optimisation errors propagate as the test's `Err`)
+    let plan = ctx.create_logical_plan(sql)?;
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #orders.o_orderpriority ASC NULLS LAST
+  Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count
+    Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]]
+      Semi Join: #orders.o_orderkey = #lineitem.l_orderkey
+        TableScan: orders projection=[o_orderkey, o_orderpriority]
+        Filter: #lineitem.l_commitdate < #lineitem.l_receiptdate
+          TableScan: lineitem projection=[l_orderkey, l_commitdate, 
l_receiptdate]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-----------------+-------------+",
+        "| o_orderpriority | order_count |",
+        "+-----------------+-------------+",
+        "| 1-URGENT        | 1           |",
+        "| 4-NOT SPECIFIED | 1           |",
+        "| 5-LOW           | 1           |",
+        "+-----------------+-------------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+/// TPC-H Q17 over a small inline data set: the correlated
+/// `0.2 * avg(l_quantity)` subquery is expected to become an aggregate grouped
+/// by `l_partkey`, joined back as `__sq_1`; the final average is then checked.
+#[tokio::test]
+async fn tpch_q17_correlated() -> Result<()> {
+    let parts = r#"63700,goldenrod lavender spring chocolate 
lace,Manufacturer#1,Brand#23,PROMO BURNISHED COPPER,7,MED BOX,901.00,ly. slyly 
ironi
+"#;
+    let lineitems = 
r#"1,63700,7311,2,36.0,45983.16,0.09,0.06,N,O,1996-04-12,1996-02-28,1996-04-20,TAKE
 BACK RETURN,MAIL,ly final dependencies: slyly bold
+1,63700,3701,3,1.0,13309.6,0.1,0.02,N,O,1996-01-29,1996-03-05,1996-01-31,TAKE 
BACK RETURN,REG AIR,"riously. regular, express dep"
+"#;
+
+    let ctx = SessionContext::new();
+    register_tpch_csv_data(&ctx, "part", parts).await?;
+    register_tpch_csv_data(&ctx, "lineitem", lineitems).await?;
+
+    let sql = r#"select sum(l_extendedprice) / 7.0 as avg_yearly
+        from lineitem, part
+        where p_partkey = l_partkey and p_brand = 'Brand#23' and p_container = 
'MED BOX'
+        and l_quantity < (
+            select 0.2 * avg(l_quantity)
+            from lineitem where l_partkey = p_partkey
+        );"#;
+
+    // assert plan. Errors propagate with `?` instead of being stringified to a
+    // meaningless "... at error" message and unwrapped; the unoptimised plan is
+    // logged at debug level rather than printed unconditionally.
+    let plan = ctx.create_logical_plan(sql)?;
+    debug!("before:\n{}", plan.display_indent());
+    let plan = ctx.optimize(&plan)?;
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Projection: #SUM(lineitem.l_extendedprice) / Float64(7) 
AS avg_yearly
+  Aggregate: groupBy=[[]], aggr=[[SUM(#lineitem.l_extendedprice)]]
+    Filter: #lineitem.l_quantity < #__sq_1.__value
+      Inner Join: #part.p_partkey = #__sq_1.l_partkey
+        Inner Join: #lineitem.l_partkey = #part.p_partkey
+          TableScan: lineitem projection=[l_partkey, l_quantity, 
l_extendedprice]
+          Filter: #part.p_brand = Utf8("Brand#23") AND #part.p_container = 
Utf8("MED BOX")
+            TableScan: part projection=[p_partkey, p_brand, p_container]
+        Projection: #lineitem.l_partkey, Float64(0.2) * 
#AVG(lineitem.l_quantity) AS __value, alias=__sq_1
+          Aggregate: groupBy=[[#lineitem.l_partkey]], 
aggr=[[AVG(#lineitem.l_quantity)]]
+            TableScan: lineitem projection=[l_partkey, l_quantity, 
l_extendedprice]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+--------------------+",
+        "| avg_yearly         |",
+        "+--------------------+",
+        "| 1901.3714285714286 |",
+        "+--------------------+",
+    ];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q20_correlated() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "supplier").await?;
+    register_tpch_csv(&ctx, "nation").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+    register_tpch_csv(&ctx, "part").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+
+    let sql = r#"select s_name, s_address
+from supplier, nation
+where s_suppkey in (
+    select ps_suppkey from partsupp
+    where ps_partkey in ( select p_partkey from part where p_name like 
'forest%' )
+      and ps_availqty > ( select 0.5 * sum(l_quantity) from lineitem
+        where l_partkey = ps_partkey and l_suppkey = ps_suppkey and l_shipdate 
>= date '1994-01-01'
+    )
+)
+and s_nationkey = n_nationkey and n_name = 'CANADA'
+order by s_name;
+"#;
+
+    // assert plan
+    let plan = ctx
+        .create_logical_plan(sql)
+        .map_err(|e| format!("{:?} at {}", e, "error"))
+        .unwrap();
+    let plan = ctx
+        .optimize(&plan)
+        .map_err(|e| format!("{:?} at {}", e, "error"))
+        .unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #supplier.s_name ASC NULLS LAST
+  Projection: #supplier.s_name, #supplier.s_address

Review Comment:
   👍 



##########
datafusion/expr/src/logical_plan/plan.rs:
##########
@@ -1074,6 +1074,13 @@ impl Projection {
             alias,
         })
     }
+
+    pub fn try_from_plan(plan: &LogicalPlan) -> 
datafusion_common::Result<&Projection> {

Review Comment:
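   For what it is worth, I think a more idiomatic Rust approach would be to implement `std::convert::TryFrom<&LogicalPlan>` for `&Projection`; the standard library's blanket impl then provides the corresponding `std::convert::TryInto` automatically.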
   
   https://doc.rust-lang.org/std/convert/trait.TryInto.html
   
   However, I think this way is also fine.



##########
datafusion/core/tests/sql/subqueries.rs:
##########
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use crate::sql::execute_to_batches;
+use datafusion::assert_batches_eq;
+use datafusion::prelude::SessionContext;
+use log::debug;
+
+#[cfg(test)]
+#[ctor::ctor]
+fn init() {
+    let _ = env_logger::try_init();
+}
+
+#[tokio::test]
+async fn correlated_recursive_scalar_subquery() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "customer").await?;
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+
+    let sql = r#"
+select c_custkey from customer
+where c_acctbal < (
+    select sum(o_totalprice) from orders
+    where o_custkey = c_custkey
+    and o_totalprice < (
+            select sum(l_extendedprice) as price from lineitem where 
l_orderkey = o_orderkey
+    )
+) order by c_custkey;"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    debug!("input:\n{}", plan.display_indent());
+
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
+  Projection: #customer.c_custkey
+    Filter: #customer.c_acctbal < #__sq_2.__value
+      Inner Join: #customer.c_custkey = #__sq_2.o_custkey
+        TableScan: customer projection=[c_custkey, c_acctbal]
+        Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, 
alias=__sq_2
+          Aggregate: groupBy=[[#orders.o_custkey]], 
aggr=[[SUM(#orders.o_totalprice)]]
+            Filter: #orders.o_totalprice < #__sq_1.__value
+              Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey
+                TableScan: orders projection=[o_orderkey, o_custkey, 
o_totalprice]
+                Projection: #lineitem.l_orderkey, 
#SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1
+                  Aggregate: groupBy=[[#lineitem.l_orderkey]], 
aggr=[[SUM(#lineitem.l_extendedprice)]]
+                    TableScan: lineitem projection=[l_orderkey, 
l_extendedprice]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn correlated_where_in() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+
+    let sql = r#"select o_orderkey from orders
+inner join lineitem on o_orderkey = l_orderkey
+where l_partkey in ( select ps_partkey from partsupp where ps_suppkey = 
l_suppkey );"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Projection: #orders.o_orderkey
+  Semi Join: #lineitem.l_partkey = #__sq_1.ps_partkey, #lineitem.l_suppkey = 
#__sq_1.ps_suppkey
+    Inner Join: #orders.o_orderkey = #lineitem.l_orderkey
+      TableScan: orders projection=[o_orderkey]
+      TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey]
+    Projection: #partsupp.ps_partkey AS ps_partkey, #partsupp.ps_suppkey AS 
ps_suppkey, alias=__sq_1
+      TableScan: partsupp projection=[ps_partkey, ps_suppkey]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];

Review Comment:
   Is it expected that this query returns no rows (the expected output here is an empty batch)?



##########
datafusion/core/tests/sql/subqueries.rs:
##########
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use crate::sql::execute_to_batches;
+use datafusion::assert_batches_eq;
+use datafusion::prelude::SessionContext;
+use log::debug;
+
+#[cfg(test)]
+#[ctor::ctor]
+fn init() {
+    let _ = env_logger::try_init();
+}
+
+#[tokio::test]
+async fn correlated_recursive_scalar_subquery() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "customer").await?;
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+
+    let sql = r#"
+select c_custkey from customer
+where c_acctbal < (
+    select sum(o_totalprice) from orders
+    where o_custkey = c_custkey
+    and o_totalprice < (
+            select sum(l_extendedprice) as price from lineitem where 
l_orderkey = o_orderkey
+    )
+) order by c_custkey;"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    debug!("input:\n{}", plan.display_indent());
+
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
+  Projection: #customer.c_custkey
+    Filter: #customer.c_acctbal < #__sq_2.__value
+      Inner Join: #customer.c_custkey = #__sq_2.o_custkey
+        TableScan: customer projection=[c_custkey, c_acctbal]
+        Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, 
alias=__sq_2
+          Aggregate: groupBy=[[#orders.o_custkey]], 
aggr=[[SUM(#orders.o_totalprice)]]
+            Filter: #orders.o_totalprice < #__sq_1.__value
+              Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey
+                TableScan: orders projection=[o_orderkey, o_custkey, 
o_totalprice]
+                Projection: #lineitem.l_orderkey, 
#SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1
+                  Aggregate: groupBy=[[#lineitem.l_orderkey]], 
aggr=[[SUM(#lineitem.l_extendedprice)]]
+                    TableScan: lineitem projection=[l_orderkey, 
l_extendedprice]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn correlated_where_in() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+
+    let sql = r#"select o_orderkey from orders
+inner join lineitem on o_orderkey = l_orderkey
+where l_partkey in ( select ps_partkey from partsupp where ps_suppkey = 
l_suppkey );"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Projection: #orders.o_orderkey
+  Semi Join: #lineitem.l_partkey = #__sq_1.ps_partkey, #lineitem.l_suppkey = 
#__sq_1.ps_suppkey
+    Inner Join: #orders.o_orderkey = #lineitem.l_orderkey
+      TableScan: orders projection=[o_orderkey]
+      TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey]
+    Projection: #partsupp.ps_partkey AS ps_partkey, #partsupp.ps_suppkey AS 
ps_suppkey, alias=__sq_1
+      TableScan: partsupp projection=[ps_partkey, ps_suppkey]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q2_correlated() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "part").await?;
+    register_tpch_csv(&ctx, "supplier").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+    register_tpch_csv(&ctx, "nation").await?;
+    register_tpch_csv(&ctx, "region").await?;
+
+    let sql = r#"select s_acctbal, s_name, n_name, p_partkey, p_mfgr, 
s_address, s_phone, s_comment
+from part, supplier, partsupp, nation, region
+where p_partkey = ps_partkey and s_suppkey = ps_suppkey and p_size = 15 and 
p_type like '%BRASS'
+    and s_nationkey = n_nationkey and n_regionkey = r_regionkey and r_name = 
'EUROPE'
+    and ps_supplycost = (
+        select min(ps_supplycost) from partsupp, supplier, nation, region
+        where p_partkey = ps_partkey and s_suppkey = ps_suppkey and 
s_nationkey = n_nationkey
+        and n_regionkey = r_regionkey and r_name = 'EUROPE'
+    )
+order by s_acctbal desc, n_name, s_name, p_partkey;"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #supplier.s_acctbal DESC NULLS FIRST, 
#nation.n_name ASC NULLS LAST, #supplier.s_name ASC NULLS LAST, #part.p_partkey 
ASC NULLS LAST
+  Projection: #supplier.s_acctbal, #supplier.s_name, #nation.n_name, 
#part.p_partkey, #part.p_mfgr, #supplier.s_address, #supplier.s_phone, 
#supplier.s_comment
+    Filter: #partsupp.ps_supplycost = #__sq_1.__value
+      Inner Join: #part.p_partkey = #__sq_1.ps_partkey
+        Inner Join: #nation.n_regionkey = #region.r_regionkey
+          Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+            Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+              Inner Join: #part.p_partkey = #partsupp.ps_partkey
+                Filter: #part.p_size = Int64(15) AND #part.p_type LIKE 
Utf8("%BRASS")
+                  TableScan: part projection=[p_partkey, p_mfgr, p_type, 
p_size], partial_filters=[#part.p_size = Int64(15), #part.p_type LIKE 
Utf8("%BRASS")]
+                TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
+              TableScan: supplier projection=[s_suppkey, s_name, s_address, 
s_nationkey, s_phone, s_acctbal, s_comment]
+            TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+          Filter: #region.r_name = Utf8("EUROPE")
+            TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[#region.r_name = Utf8("EUROPE")]
+        Projection: #partsupp.ps_partkey, #MIN(partsupp.ps_supplycost) AS 
__value, alias=__sq_1
+          Aggregate: groupBy=[[#partsupp.ps_partkey]], 
aggr=[[MIN(#partsupp.ps_supplycost)]]
+            Inner Join: #nation.n_regionkey = #region.r_regionkey
+              Inner Join: #supplier.s_nationkey = #nation.n_nationkey
+                Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
+                  TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_supplycost]
+                  TableScan: supplier projection=[s_suppkey, s_name, 
s_address, s_nationkey, s_phone, s_acctbal, s_comment]
+                TableScan: nation projection=[n_nationkey, n_name, n_regionkey]
+              Filter: #region.r_name = Utf8("EUROPE")
+                TableScan: region projection=[r_regionkey, r_name], 
partial_filters=[#region.r_name = Utf8("EUROPE")]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    // assert data
+    let results = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &results);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn tpch_q4_correlated() -> Result<()> {
+    let orders = r#"4,13678,O,53829.87,1995-10-11,5-LOW,Clerk#000000124,0,
+35,12760,O,192885.43,1995-10-23,4-NOT SPECIFIED,Clerk#000000259,0,
+65,1627,P,99763.79,1995-03-18,1-URGENT,Clerk#000000632,0,
+"#;
+    let lineitems = 
r#"4,8804,579,1,30,51384,0.03,0.08,N,O,1996-01-10,1995-12-14,1996-01-18,DELIVER 
IN PERSON,REG AIR,
+35,45,296,1,24,22680.96,0.02,0,N,O,1996-02-21,1996-01-03,1996-03-18,TAKE BACK 
RETURN,FOB,
+65,5970,481,1,26,48775.22,0.03,0.03,A,F,1995-04-20,1995-04-25,1995-05-13,NONE,TRUCK,
+"#;
+
+    let ctx = SessionContext::new();
+    register_tpch_csv_data(&ctx, "orders", orders).await?;
+    register_tpch_csv_data(&ctx, "lineitem", lineitems).await?;
+
+    let sql = r#"
+        select o_orderpriority, count(*) as order_count
+        from orders
+        where exists (
+            select * from lineitem where l_orderkey = o_orderkey and 
l_commitdate < l_receiptdate)
+        group by o_orderpriority
+        order by o_orderpriority;
+        "#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #orders.o_orderpriority ASC NULLS LAST
+  Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count

Review Comment:
   👍 



##########
datafusion/core/tests/sql/subqueries.rs:
##########
@@ -0,0 +1,461 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::*;
+use crate::sql::execute_to_batches;
+use datafusion::assert_batches_eq;
+use datafusion::prelude::SessionContext;
+use log::debug;
+
+#[cfg(test)]
+#[ctor::ctor]
+fn init() {
+    let _ = env_logger::try_init();
+}
+
+#[tokio::test]
+async fn correlated_recursive_scalar_subquery() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "customer").await?;
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+
+    let sql = r#"
+select c_custkey from customer
+where c_acctbal < (
+    select sum(o_totalprice) from orders
+    where o_custkey = c_custkey
+    and o_totalprice < (
+            select sum(l_extendedprice) as price from lineitem where 
l_orderkey = o_orderkey
+    )
+) order by c_custkey;"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    debug!("input:\n{}", plan.display_indent());
+
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Sort: #customer.c_custkey ASC NULLS LAST
+  Projection: #customer.c_custkey
+    Filter: #customer.c_acctbal < #__sq_2.__value
+      Inner Join: #customer.c_custkey = #__sq_2.o_custkey
+        TableScan: customer projection=[c_custkey, c_acctbal]
+        Projection: #orders.o_custkey, #SUM(orders.o_totalprice) AS __value, 
alias=__sq_2
+          Aggregate: groupBy=[[#orders.o_custkey]], 
aggr=[[SUM(#orders.o_totalprice)]]
+            Filter: #orders.o_totalprice < #__sq_1.__value
+              Inner Join: #orders.o_orderkey = #__sq_1.l_orderkey
+                TableScan: orders projection=[o_orderkey, o_custkey, 
o_totalprice]
+                Projection: #lineitem.l_orderkey, 
#SUM(lineitem.l_extendedprice) AS price AS __value, alias=__sq_1
+                  Aggregate: groupBy=[[#lineitem.l_orderkey]], 
aggr=[[SUM(#lineitem.l_extendedprice)]]
+                    TableScan: lineitem projection=[l_orderkey, 
l_extendedprice]"#
+        .to_string();
+    assert_eq!(actual, expected);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn correlated_where_in() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "orders").await?;
+    register_tpch_csv(&ctx, "lineitem").await?;
+    register_tpch_csv(&ctx, "partsupp").await?;
+
+    let sql = r#"select o_orderkey from orders
+inner join lineitem on o_orderkey = l_orderkey
+where l_partkey in ( select ps_partkey from partsupp where ps_suppkey = 
l_suppkey );"#;
+
+    // assert plan
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let actual = format!("{}", plan.display_indent());
+    let expected = r#"Projection: #orders.o_orderkey
+  Semi Join: #lineitem.l_partkey = #__sq_1.ps_partkey, #lineitem.l_suppkey = 
#__sq_1.ps_suppkey

Review Comment:
   👍 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to