This is an automated email from the ASF dual-hosted git repository.

iffyio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-sqlparser-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new a464f8e8 Improve support for cursors for SQL Server (#1831)
a464f8e8 is described below

commit a464f8e8d7a5057e4e5b8046b0f619acdf7fce74
Author: Andrew Harper <andrew.har...@veracross.com>
AuthorDate: Thu May 1 23:25:30 2025 -0400

    Improve support for cursors for SQL Server (#1831)
    
    Co-authored-by: Ifeanyi Ubah <ify1...@yahoo.com>
---
 src/ast/mod.rs            | 90 ++++++++++++++++++++++++++++++++++++++++++++---
 src/ast/spans.rs          | 31 ++++++++++++----
 src/keywords.rs           |  2 ++
 src/parser/mod.rs         | 70 ++++++++++++++++++++++++++++++++----
 src/test_utils.rs         | 20 +++++++++++
 tests/sqlparser_common.rs | 12 +++++++
 tests/sqlparser_mssql.rs  | 84 +++++++++++++++++++++++++++++++++++++++++--
 7 files changed, 289 insertions(+), 20 deletions(-)

diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index b1439266..582922a3 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -2228,7 +2228,33 @@ impl fmt::Display for IfStatement {
     }
 }
 
-/// A block within a [Statement::Case] or [Statement::If]-like statement
+/// A `WHILE` statement.
+///
+/// Example:
+/// ```sql
+/// WHILE @@FETCH_STATUS = 0
+/// BEGIN
+///    FETCH NEXT FROM c1 INTO @var1, @var2;
+/// END
+/// ```
+///
+/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/while-transact-sql)
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct WhileStatement {
+    pub while_block: ConditionalStatementBlock,
+}
+
+impl fmt::Display for WhileStatement {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let WhileStatement { while_block } = self;
+        write!(f, "{while_block}")?;
+        Ok(())
+    }
+}
+
+/// A block within a [Statement::Case], [Statement::If], or [Statement::While]-like statement
 ///
 /// Example 1:
 /// ```sql
@@ -2244,6 +2270,14 @@ impl fmt::Display for IfStatement {
 /// ```sql
 /// ELSE SELECT 1; SELECT 2;
 /// ```
+///
+/// Example 4:
+/// ```sql
+/// WHILE @@FETCH_STATUS = 0
+/// BEGIN
+///    FETCH NEXT FROM c1 INTO @var1, @var2;
+/// END
+/// ```
 #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
 #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
@@ -2983,6 +3017,8 @@ pub enum Statement {
     Case(CaseStatement),
     /// An `IF` statement.
     If(IfStatement),
+    /// A `WHILE` statement.
+    While(WhileStatement),
     /// A `RAISE` statement.
     Raise(RaiseStatement),
     /// ```sql
@@ -3034,6 +3070,11 @@ pub enum Statement {
         partition: Option<Box<Expr>>,
     },
     /// ```sql
+    /// OPEN cursor_name
+    /// ```
+    /// Opens a cursor.
+    Open(OpenStatement),
+    /// ```sql
     /// CLOSE
     /// ```
     /// Closes the portal underlying an open cursor.
@@ -3413,6 +3454,7 @@ pub enum Statement {
         /// Cursor name
         name: Ident,
         direction: FetchDirection,
+        position: FetchPosition,
         /// Optional, It's possible to fetch rows form cursor to the table
         into: Option<ObjectName>,
     },
@@ -4235,11 +4277,10 @@ impl fmt::Display for Statement {
             Statement::Fetch {
                 name,
                 direction,
+                position,
                 into,
             } => {
-                write!(f, "FETCH {direction} ")?;
-
-                write!(f, "IN {name}")?;
+                write!(f, "FETCH {direction} {position} {name}")?;
 
                 if let Some(into) = into {
                     write!(f, " INTO {into}")?;
@@ -4329,6 +4370,9 @@ impl fmt::Display for Statement {
             Statement::If(stmt) => {
                 write!(f, "{stmt}")
             }
+            Statement::While(stmt) => {
+                write!(f, "{stmt}")
+            }
             Statement::Raise(stmt) => {
                 write!(f, "{stmt}")
             }
@@ -4498,6 +4542,7 @@ impl fmt::Display for Statement {
                 Ok(())
             }
             Statement::Delete(delete) => write!(f, "{delete}"),
+            Statement::Open(open) => write!(f, "{open}"),
             Statement::Close { cursor } => {
                 write!(f, "CLOSE {cursor}")?;
 
@@ -6187,6 +6232,28 @@ impl fmt::Display for FetchDirection {
     }
 }
 
+/// The "position" for a FETCH statement.
+///
+/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/fetch-transact-sql)
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub enum FetchPosition {
+    From,
+    In,
+}
+
+impl fmt::Display for FetchPosition {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            FetchPosition::From => f.write_str("FROM")?,
+            FetchPosition::In => f.write_str("IN")?,
+        };
+
+        Ok(())
+    }
+}
+
 /// A privilege on a database object (table, sequence, etc.).
 #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -9354,6 +9421,21 @@ pub enum ReturnStatementValue {
     Expr(Expr),
 }
 
+/// Represents an `OPEN` statement.
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct OpenStatement {
+    /// Cursor name
+    pub cursor_name: Ident,
+}
+
+impl fmt::Display for OpenStatement {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "OPEN {}", self.cursor_name)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/src/ast/spans.rs b/src/ast/spans.rs
index 33bc0739..836f229a 100644
--- a/src/ast/spans.rs
+++ b/src/ast/spans.rs
@@ -31,13 +31,13 @@ use super::{
     FunctionArguments, GroupByExpr, HavingBound, IfStatement, IlikeSelectItem, Insert, Interpolate,
     InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView,
     LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart,
-    Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition,
-    PivotValueSource, ProjectionSelect, Query, RaiseStatement, RaiseStatementValue,
-    ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select,
-    SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias,
-    TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered,
-    TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
-    WildcardAdditionalOptions, With, WithFill,
+    Offset, OnConflict, OnConflictAction, OnInsert, OpenStatement, OrderBy, OrderByExpr,
+    OrderByKind, Partition, PivotValueSource, ProjectionSelect, Query, RaiseStatement,
+    RaiseStatementValue, ReferentialAction, RenameSelectItem, ReplaceSelectElement,
+    ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript,
+    SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint, TableFactor, TableObject,
+    TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
+    WhileStatement, WildcardAdditionalOptions, With, WithFill,
 };
 
 /// Given an iterator of spans, return the [Span::union] of all spans.
@@ -339,6 +339,7 @@ impl Spanned for Statement {
             } => source.span(),
             Statement::Case(stmt) => stmt.span(),
             Statement::If(stmt) => stmt.span(),
+            Statement::While(stmt) => stmt.span(),
             Statement::Raise(stmt) => stmt.span(),
             Statement::Call(function) => function.span(),
             Statement::Copy {
@@ -365,6 +366,7 @@ impl Spanned for Statement {
                 from_query: _,
                 partition: _,
             } => Span::empty(),
+            Statement::Open(open) => open.span(),
             Statement::Close { cursor } => match cursor {
                 CloseCursor::All => Span::empty(),
                 CloseCursor::Specific { name } => name.span,
@@ -776,6 +778,14 @@ impl Spanned for IfStatement {
     }
 }
 
+impl Spanned for WhileStatement {
+    fn span(&self) -> Span {
+        let WhileStatement { while_block } = self;
+
+        while_block.span()
+    }
+}
+
 impl Spanned for ConditionalStatements {
     fn span(&self) -> Span {
         match self {
@@ -2297,6 +2307,13 @@ impl Spanned for BeginEndStatements {
     }
 }
 
+impl Spanned for OpenStatement {
+    fn span(&self) -> Span {
+        let OpenStatement { cursor_name } = self;
+        cursor_name.span
+    }
+}
+
 #[cfg(test)]
 pub mod tests {
     use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect};
diff --git a/src/keywords.rs b/src/keywords.rs
index 32612ccd..bf8a1915 100644
--- a/src/keywords.rs
+++ b/src/keywords.rs
@@ -985,6 +985,7 @@ define_keywords!(
     WHEN,
     WHENEVER,
     WHERE,
+    WHILE,
     WIDTH_BUCKET,
     WINDOW,
     WITH,
@@ -1068,6 +1069,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[
     Keyword::SAMPLE,
     Keyword::TABLESAMPLE,
     Keyword::FROM,
+    Keyword::OPEN,
 ];
 
 /// Can't be used as a column alias, so that `SELECT <expr> alias`
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index 0d74235b..cbd464c3 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -536,6 +536,10 @@ impl<'a> Parser<'a> {
                     self.prev_token();
                     self.parse_if_stmt()
                 }
+                Keyword::WHILE => {
+                    self.prev_token();
+                    self.parse_while()
+                }
                 Keyword::RAISE => {
                     self.prev_token();
                     self.parse_raise_stmt()
@@ -570,6 +574,10 @@ impl<'a> Parser<'a> {
                 Keyword::ALTER => self.parse_alter(),
                 Keyword::CALL => self.parse_call(),
                 Keyword::COPY => self.parse_copy(),
+                Keyword::OPEN => {
+                    self.prev_token();
+                    self.parse_open()
+                }
                 Keyword::CLOSE => self.parse_close(),
                 Keyword::SET => self.parse_set(),
                 Keyword::SHOW => self.parse_show(),
@@ -700,8 +708,18 @@ impl<'a> Parser<'a> {
         }))
     }
 
+    /// Parse a `WHILE` statement.
+    ///
+    /// See [Statement::While]
+    fn parse_while(&mut self) -> Result<Statement, ParserError> {
+        self.expect_keyword_is(Keyword::WHILE)?;
+        let while_block = self.parse_conditional_statement_block(&[Keyword::END])?;
+
+        Ok(Statement::While(WhileStatement { while_block }))
+    }
+
     /// Parses an expression and associated list of statements
-    /// belonging to a conditional statement like `IF` or `WHEN`.
+    /// belonging to a conditional statement like `IF`, `WHEN`, or `WHILE`.
     ///
     /// Example:
     /// ```sql
@@ -716,6 +734,10 @@ impl<'a> Parser<'a> {
 
         let condition = match &start_token.token {
             Token::Word(w) if w.keyword == Keyword::ELSE => None,
+            Token::Word(w) if w.keyword == Keyword::WHILE => {
+                let expr = self.parse_expr()?;
+                Some(expr)
+            }
             _ => {
                 let expr = self.parse_expr()?;
                 then_token = Some(AttachedToken(self.expect_keyword(Keyword::THEN)?));
@@ -723,13 +745,25 @@ impl<'a> Parser<'a> {
             }
         };
 
-        let statements = self.parse_statement_list(terminal_keywords)?;
+        let conditional_statements = if self.peek_keyword(Keyword::BEGIN) {
+            let begin_token = self.expect_keyword(Keyword::BEGIN)?;
+            let statements = self.parse_statement_list(terminal_keywords)?;
+            let end_token = self.expect_keyword(Keyword::END)?;
+            ConditionalStatements::BeginEnd(BeginEndStatements {
+                begin_token: AttachedToken(begin_token),
+                statements,
+                end_token: AttachedToken(end_token),
+            })
+        } else {
+            let statements = self.parse_statement_list(terminal_keywords)?;
+            ConditionalStatements::Sequence { statements }
+        };
 
         Ok(ConditionalStatementBlock {
             start_token: AttachedToken(start_token),
             condition,
             then_token,
-            conditional_statements: ConditionalStatements::Sequence { statements },
+            conditional_statements,
         })
     }
 
@@ -4467,11 +4501,16 @@ impl<'a> Parser<'a> {
     ) -> Result<Vec<Statement>, ParserError> {
         let mut values = vec![];
         loop {
-            if let Token::Word(w) = &self.peek_nth_token_ref(0).token {
-                if w.quote_style.is_none() && terminal_keywords.contains(&w.keyword) {
-                    break;
+            match &self.peek_nth_token_ref(0).token {
+                Token::EOF => break,
+                Token::Word(w) => {
+                    if w.quote_style.is_none() && terminal_keywords.contains(&w.keyword) {
+                        break;
+                    }
                 }
+                _ => {}
             }
+
             values.push(self.parse_statement()?);
             self.expect_token(&Token::SemiColon)?;
         }
@@ -6644,7 +6683,15 @@ impl<'a> Parser<'a> {
             }
         };
 
-        self.expect_one_of_keywords(&[Keyword::FROM, Keyword::IN])?;
+        let position = if self.peek_keyword(Keyword::FROM) {
+            self.expect_keyword(Keyword::FROM)?;
+            FetchPosition::From
+        } else if self.peek_keyword(Keyword::IN) {
+            self.expect_keyword(Keyword::IN)?;
+            FetchPosition::In
+        } else {
+            return parser_err!("Expected FROM or IN", self.peek_token().span.start);
+        };
 
         let name = self.parse_identifier()?;
 
@@ -6657,6 +6704,7 @@ impl<'a> Parser<'a> {
         Ok(Statement::Fetch {
             name,
             direction,
+            position,
             into,
         })
     }
@@ -8770,6 +8818,14 @@ impl<'a> Parser<'a> {
         })
     }
 
+    /// Parse [Statement::Open]
+    fn parse_open(&mut self) -> Result<Statement, ParserError> {
+        self.expect_keyword(Keyword::OPEN)?;
+        Ok(Statement::Open(OpenStatement {
+            cursor_name: self.parse_identifier()?,
+        }))
+    }
+
     pub fn parse_close(&mut self) -> Result<Statement, ParserError> {
         let cursor = if self.parse_keyword(Keyword::ALL) {
             CloseCursor::All
diff --git a/src/test_utils.rs b/src/test_utils.rs
index 6270ac42..3c22fa91 100644
--- a/src/test_utils.rs
+++ b/src/test_utils.rs
@@ -151,6 +151,8 @@ impl TestedDialects {
     ///
     /// 2. re-serializing the result of parsing `sql` produces the same
     ///    `canonical` sql string
+    ///
+    /// For multiple statements, use [`Self::statements_parse_to`].
     pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
         let mut statements = self.parse_sql_statements(sql).expect(sql);
         assert_eq!(statements.len(), 1);
@@ -166,6 +168,24 @@ impl TestedDialects {
         only_statement
     }
 
+    /// Like [`Self::one_statement_parses_to`], but for SQL containing multiple statements.
+    pub fn statements_parse_to(&self, sql: &str, canonical: &str) -> Vec<Statement> {
+        let statements = self.parse_sql_statements(sql).expect(sql);
+        if !canonical.is_empty() && sql != canonical {
+            assert_eq!(self.parse_sql_statements(canonical).unwrap(), statements);
+        } else {
+            assert_eq!(
+                sql,
+                statements
+                    .iter()
+                    .map(|s| s.to_string())
+                    .collect::<Vec<_>>()
+                    .join("; ")
+            );
+        }
+        statements
+    }
+
     /// Ensures that `sql` parses as an [`Expr`], and that
     /// re-serializing the parse result produces canonical
     pub fn expr_parses_to(&self, sql: &str, canonical: &str) -> Expr {
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index 6d99929d..1ddf3f92 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -15187,3 +15187,15 @@ fn parse_return() {
 
     let _ = all_dialects().verified_stmt("RETURN 1");
 }
+
+#[test]
+fn test_open() {
+    let open_cursor = "OPEN Employee_Cursor";
+    let stmt = all_dialects().verified_stmt(open_cursor);
+    assert_eq!(
+        stmt,
+        Statement::Open(OpenStatement {
+            cursor_name: Ident::new("Employee_Cursor"),
+        })
+    );
+}
diff --git a/tests/sqlparser_mssql.rs b/tests/sqlparser_mssql.rs
index b2d5210c..88e7a1f1 100644
--- a/tests/sqlparser_mssql.rs
+++ b/tests/sqlparser_mssql.rs
@@ -23,7 +23,8 @@
 mod test_utils;
 
 use helpers::attached_token::AttachedToken;
-use sqlparser::tokenizer::{Location, Span};
+use sqlparser::keywords::Keyword;
+use sqlparser::tokenizer::{Location, Span, Token, TokenWithSpan, Word};
 use test_utils::*;
 
 use sqlparser::ast::DataType::{Int, Text, Varbinary};
@@ -223,7 +224,7 @@ fn parse_create_function() {
                     value: Some(ReturnStatementValue::Expr(Expr::Value(
                         (number("1")).with_empty_span()
                     ))),
-                }),],
+                })],
                 end_token: AttachedToken::empty(),
             })),
             behavior: None,
@@ -1397,6 +1398,85 @@ fn parse_mssql_declare() {
     let _ = ms().verified_stmt(declare_cursor_for_select);
 }
 
+#[test]
+fn test_mssql_cursor() {
+    let full_cursor_usage = "\
+        DECLARE Employee_Cursor CURSOR FOR \
+        SELECT LastName, FirstName \
+        FROM AdventureWorks2022.HumanResources.vEmployee \
+        WHERE LastName LIKE 'B%'; \
+        \
+        OPEN Employee_Cursor; \
+        \
+        FETCH NEXT FROM Employee_Cursor; \
+        \
+        WHILE @@FETCH_STATUS = 0 \
+        BEGIN \
+            FETCH NEXT FROM Employee_Cursor; \
+        END; \
+        \
+        CLOSE Employee_Cursor; \
+        DEALLOCATE Employee_Cursor\
+    ";
+    let _ = ms().statements_parse_to(full_cursor_usage, "");
+}
+
+#[test]
+fn test_mssql_while_statement() {
+    let while_single_statement = "WHILE 1 = 0 PRINT 'Hello World';";
+    let stmt = ms().verified_stmt(while_single_statement);
+    assert_eq!(
+        stmt,
+        Statement::While(sqlparser::ast::WhileStatement {
+            while_block: ConditionalStatementBlock {
+                start_token: AttachedToken(TokenWithSpan {
+                    token: Token::Word(Word {
+                        value: "WHILE".to_string(),
+                        quote_style: None,
+                        keyword: Keyword::WHILE
+                    }),
+                    span: Span::empty()
+                }),
+                condition: Some(Expr::BinaryOp {
+                    left: Box::new(Expr::Value(
+                        (Value::Number("1".parse().unwrap(), false)).with_empty_span()
+                    )),
+                    op: BinaryOperator::Eq,
+                    right: Box::new(Expr::Value(
+                        (Value::Number("0".parse().unwrap(), false)).with_empty_span()
+                    )),
+                }),
+                then_token: None,
+                conditional_statements: ConditionalStatements::Sequence {
+                    statements: vec![Statement::Print(PrintStatement {
+                        message: Box::new(Expr::Value(
+                            (Value::SingleQuotedString("Hello World".to_string()))
+                                .with_empty_span()
+                        )),
+                    })],
+                }
+            }
+        })
+    );
+
+    let while_begin_end = "\
+        WHILE @@FETCH_STATUS = 0 \
+        BEGIN \
+            FETCH NEXT FROM Employee_Cursor; \
+        END\
+    ";
+    let _ = ms().verified_stmt(while_begin_end);
+
+    let while_begin_end_multiple_statements = "\
+        WHILE @@FETCH_STATUS = 0 \
+        BEGIN \
+            FETCH NEXT FROM Employee_Cursor; \
+            PRINT 'Hello World'; \
+        END\
+    ";
+    let _ = ms().verified_stmt(while_begin_end_multiple_statements);
+}
+
 #[test]
 fn test_parse_raiserror() {
     let sql = r#"RAISERROR('This is a test', 16, 1)"#;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to