jonahgao commented on code in PR #13314:
URL: https://github.com/apache/datafusion/pull/13314#discussion_r1835638570


##########
datafusion/proto/tests/cases/roundtrip_logical_plan.rs:
##########
@@ -2523,3 +2525,75 @@ fn roundtrip_window() {
     roundtrip_expr_test(test_expr6, ctx.clone());
     roundtrip_expr_test(text_expr7, ctx);
 }
+
+#[tokio::test]
+async fn roundtrip_recursive_query() {
+    #[derive(Debug)]
+    pub struct EmptyTableCodec;
+
+    impl LogicalExtensionCodec for EmptyTableCodec {
+        fn try_decode(
+            &self,
+            _buf: &[u8],
+            _inputs: &[LogicalPlan],
+            _ctx: &SessionContext,
+        ) -> Result<Extension, DataFusionError> {
+            not_impl_err!("No extension codec provided")
+        }
+
+        fn try_encode(
+            &self,
+            _node: &Extension,
+            _buf: &mut Vec<u8>,
+        ) -> Result<(), DataFusionError> {
+            not_impl_err!("No extension codec provided")
+        }
+
+        fn try_decode_table_provider(
+            &self,
+            _buf: &[u8],
+            _table_ref: &TableReference,
+            schema: SchemaRef,
+            _ctx: &SessionContext,
+        ) -> Result<Arc<dyn TableProvider>, DataFusionError> {
+            let table = MemTable::try_new(schema, vec![vec![]])?;
+            Ok(Arc::new(table))
+        }
+
+        fn try_encode_table_provider(
+            &self,
+            _table_ref: &TableReference,
+            _node: Arc<dyn TableProvider>,
+            _buf: &mut Vec<u8>,
+        ) -> Result<(), DataFusionError> {
+            Ok(())

Review Comment:
   It seems that we need to
[encode](https://github.com/apache/datafusion/blob/c48f12d27cc60ae5e9c3d0791d553453f0dd7001/datafusion/proto/src/logical_plan/mod.rs#L964)
 the `CteWorkTable` and then recreate a `CteWorkTable` when decoding. 🤔
   



##########
datafusion/proto/tests/cases/roundtrip_logical_plan.rs:
##########
@@ -2523,3 +2525,75 @@ fn roundtrip_window() {
     roundtrip_expr_test(test_expr6, ctx.clone());
     roundtrip_expr_test(text_expr7, ctx);
 }
+
+#[tokio::test]
+async fn roundtrip_recursive_query() {
+    #[derive(Debug)]
+    pub struct EmptyTableCodec;
+
+    impl LogicalExtensionCodec for EmptyTableCodec {
+        fn try_decode(
+            &self,
+            _buf: &[u8],
+            _inputs: &[LogicalPlan],
+            _ctx: &SessionContext,
+        ) -> Result<Extension, DataFusionError> {
+            not_impl_err!("No extension codec provided")
+        }
+
+        fn try_encode(
+            &self,
+            _node: &Extension,
+            _buf: &mut Vec<u8>,
+        ) -> Result<(), DataFusionError> {
+            not_impl_err!("No extension codec provided")
+        }
+
+        fn try_decode_table_provider(
+            &self,
+            _buf: &[u8],
+            _table_ref: &TableReference,
+            schema: SchemaRef,
+            _ctx: &SessionContext,
+        ) -> Result<Arc<dyn TableProvider>, DataFusionError> {
+            let table = MemTable::try_new(schema, vec![vec![]])?;
+            Ok(Arc::new(table))
+        }
+
+        fn try_encode_table_provider(
+            &self,
+            _table_ref: &TableReference,
+            _node: Arc<dyn TableProvider>,
+            _buf: &mut Vec<u8>,
+        ) -> Result<(), DataFusionError> {
+            Ok(())
+        }
+    }
+
+    let query = "WITH RECURSIVE cte AS (
+        SELECT 1 as n
+        UNION ALL
+        SELECT n + 1 FROM cte WHERE n < 5
+        )
+        SELECT * FROM cte;";
+
+    let ctx = SessionContext::new();
+    let dataframe = ctx.sql(query).await.unwrap();
+    let plan = dataframe.logical_plan().clone();
+    let output = dataframe.collect().await.unwrap();
+    let extension_codec = EmptyTableCodec {};
+    let bytes =
+        logical_plan_to_bytes_with_extension_codec(&plan, 
&extension_codec).unwrap();
+
+    let ctx = SessionContext::new();
+    let logical_round_trip =
+        logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, 
&extension_codec)
+            .unwrap();
+    let output_round_trip = 
ctx.sql(query).await.unwrap().collect().await.unwrap();
+
+    assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}"));
+    assert_eq!(
+        format!("{}", pretty_format_batches(&output).unwrap()),
+        format!("{}", pretty_format_batches(&output_round_trip).unwrap())

Review Comment:
   This `assert_eq` looks incorrect. `output_round_trip` and `output` are both
the result of `ctx.sql(query)`, so the comparison never executes the
round-tripped plan.
   I think `output_round_trip` should be:
   ```rust
   let output_round_trip = ctx
           .execute_logical_plan(logical_round_trip)
           .await
           .unwrap()
           .collect()
           .await
           .unwrap()
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to