This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-sqlparser-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 84e82e6e Add `#[recursive]` (#1522)
84e82e6e is described below

commit 84e82e6e2ebb95f573ed258865752217f7ae5a6f
Author: Dmitrii Blaginin <[email protected]>
AuthorDate: Thu Dec 19 22:17:20 2024 +0300

    Add `#[recursive]` (#1522)
    
    Co-authored-by: Ifeanyi Ubah <[email protected]>
---
 Cargo.toml                                 |  5 +++-
 README.md                                  |  2 +-
 derive/src/lib.rs                          |  3 +++
 sqlparser_bench/benches/sqlparser_bench.rs | 40 ++++++++++++++++++++++++++++++
 src/ast/mod.rs                             |  1 +
 src/ast/visitor.rs                         | 25 +++++++++++++++++++
 src/parser/mod.rs                          |  6 +++++
 tests/sqlparser_common.rs                  | 13 ++++++++++
 8 files changed, 93 insertions(+), 2 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 301a59c5..8ff0ceb5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -37,8 +37,9 @@ name = "sqlparser"
 path = "src/lib.rs"
 
 [features]
-default = ["std"]
+default = ["std", "recursive-protection"]
 std = []
+recursive-protection = ["std", "recursive"]
 # Enable JSON output in the `cli` example:
 json_example = ["serde_json", "serde"]
 visitor = ["sqlparser_derive"]
@@ -46,6 +47,8 @@ visitor = ["sqlparser_derive"]
 [dependencies]
 bigdecimal = { version = "0.4.1", features = ["serde"], optional = true }
 log = "0.4"
+# Stack-overflow protection for recursive code; pulled in by the `recursive-protection` feature.
+recursive = { version = "0.1.1", optional = true }
 serde = { version = "1.0", features = ["derive"], optional = true }
 # serde_json is only used in examples/cli, but we have to put it outside
 # of dev-dependencies because of
diff --git a/README.md b/README.md
index fd676d11..997aec58 100644
--- a/README.md
+++ b/README.md
@@ -63,7 +63,7 @@ The following optional [crate  features](https://doc.rust-lang.org/cargo/referen
 
 * `serde`: Adds [Serde](https://serde.rs/) support by implementing  `Serialize` and `Deserialize` for all AST nodes.
 * `visitor`: Adds a `Visitor` capable of recursively walking the AST tree.
-
+* `recursive-protection` (enabled by default): Uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection.
 
 ## Syntax vs Semantics
 
diff --git a/derive/src/lib.rs b/derive/src/lib.rs
index b8162331..08c5c5db 100644
--- a/derive/src/lib.rs
+++ b/derive/src/lib.rs
@@ -78,7 +78,10 @@ fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_
 
     let expanded = quote! {
         // The generated impl.
+        // If the crate expanding this derive enables `recursive-protection`, it uses [`recursive::recursive`] to protect from stack overflow.
+        // See tests in https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info.
         impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
+            #[cfg_attr(feature = "recursive-protection", recursive::recursive)]
             fn visit<V: sqlparser::ast::#visitor_trait>(
                 &#modifier self,
                 visitor: &mut V
diff --git a/sqlparser_bench/benches/sqlparser_bench.rs b/sqlparser_bench/benches/sqlparser_bench.rs
index 32a6da1b..74cac5c9 100644
--- a/sqlparser_bench/benches/sqlparser_bench.rs
+++ b/sqlparser_bench/benches/sqlparser_bench.rs
@@ -42,6 +42,46 @@ fn basic_queries(c: &mut Criterion) {
     group.bench_function("sqlparser::with_select", |b| {
         b.iter(|| Parser::parse_sql(&dialect, with_query).unwrap());
     });
+
+    let large_statement = {
+        let expressions = (0..1000)
+            .map(|n| format!("FN_{}(COL_{})", n, n))
+            .collect::<Vec<_>>()
+            .join(", ");
+        let tables = (0..1000)
+            .map(|n| format!("TABLE_{}", n))
+            .collect::<Vec<_>>()
+            .join(" JOIN ");
+        let where_condition = (0..1000)
+            .map(|n| format!("COL_{} = {}", n, n))
+            .collect::<Vec<_>>()
+            .join(" OR ");
+        let order_condition = (0..1000)
+            .map(|n| format!("COL_{} DESC", n))
+            .collect::<Vec<_>>()
+            .join(", ");
+
+        format!(
+            "SELECT {} FROM {} WHERE {} ORDER BY {}",
+            expressions, tables, where_condition, order_condition
+        )
+    };
+
+    group.bench_function("parse_large_statement", |b| {
+        b.iter(|| Parser::parse_sql(&dialect, criterion::black_box(large_statement.as_str())));
+    });
+
+    let large_statement = Parser::parse_sql(&dialect, large_statement.as_str())
+        .unwrap()
+        .pop()
+        .unwrap();
+
+    group.bench_function("format_large_statement", |b| {
+        b.iter(|| {
+            let formatted_query = large_statement.to_string();
+            criterion::black_box(formatted_query)
+        });
+    });
 }
 
 criterion_group!(benches, basic_queries);
diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 39b97463..3157a060 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -1291,6 +1291,7 @@ impl fmt::Display for CastFormat {
 }
 
 impl fmt::Display for Expr {
+    #[cfg_attr(feature = "recursive-protection", recursive::recursive)] // guards deeply nested exprs, e.g. `1 + 1 + ...`
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
             Expr::Identifier(s) => write!(f, "{s}"),
diff --git a/src/ast/visitor.rs b/src/ast/visitor.rs
index f7562b66..c824ad2f 100644
--- a/src/ast/visitor.rs
+++ b/src/ast/visitor.rs
@@ -894,4 +894,29 @@ mod tests {
             assert_eq!(actual, expected)
         }
     }
+
+    struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes
+
+    impl Visitor for QuickVisitor {
+        type Break = ();
+    }
+
+    #[test]
+    fn overflow() {
+        let cond = (0..1000)
+            .map(|n| format!("X = {}", n))
+            .collect::<Vec<_>>()
+            .join(" OR ");
+        let sql = format!("SELECT x where {0}", cond);
+
+        let dialect = GenericDialect {};
+        let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
+        let s = Parser::new(&dialect)
+            .with_tokens(tokens)
+            .parse_statement()
+            .unwrap();
+
+        let mut visitor = QuickVisitor {};
+        assert!(s.visit(&mut visitor).is_continue());
+    }
 }
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index df4af538..e809ffba 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -73,6 +73,9 @@ mod recursion {
     /// Note: Uses an [`std::rc::Rc`] and [`std::cell::Cell`] in order to satisfy the Rust
     /// borrow checker so the automatic [`DepthGuard`] decrement a
     /// reference to the counter.
+    ///
+    /// Note: when the `recursive-protection` feature is enabled, this crate uses additional stack overflow protection
+    /// for some of its recursive methods. See [`recursive::recursive`] for more information.
     pub(crate) struct RecursionCounter {
         remaining_depth: Rc<Cell<usize>>,
     }
@@ -326,6 +329,9 @@ impl<'a> Parser<'a> {
     /// # Ok(())
     /// # }
     /// ```
+    ///
+    /// Note: when the `recursive-protection` feature is enabled, this crate uses additional stack overflow protection
+    /// for some of its recursive methods. See [`recursive::recursive`] for more information.
     pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
         self.recursion_counter = RecursionCounter::new(recursion_limit);
         self
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index 507c9c77..e7e2e3bc 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -12433,3 +12433,16 @@ fn test_table_sample() {
     dialects.verified_stmt("SELECT * FROM tbl AS t TABLESAMPLE SYSTEM (50)");
     dialects.verified_stmt("SELECT * FROM tbl AS t TABLESAMPLE SYSTEM (50) REPEATABLE (10)");
 }
+
+#[test]
+fn overflow() {
+    let expr = std::iter::repeat("1")
+        .take(1000)
+        .collect::<Vec<_>>()
+        .join(" + ");
+    let sql = format!("SELECT {}", expr);
+
+    let mut statements = Parser::parse_sql(&GenericDialect {}, sql.as_str()).unwrap();
+    let statement = statements.pop().unwrap();
+    assert_eq!(statement.to_string(), sql);
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to