This is an automated email from the ASF dual-hosted git repository.

iffyio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-sqlparser-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 862e887a Add `CASE` and `IF` statement support (#1741)
862e887a is described below

commit 862e887a66cf8ede2dc3a641db7cdcf52b061b76
Author: Ifeanyi Ubah <[email protected]>
AuthorDate: Fri Mar 14 07:49:25 2025 +0100

    Add `CASE` and `IF` statement support (#1741)
---
 src/ast/mod.rs            | 198 ++++++++++++++++++++++++++++++++++++++++++++--
 src/ast/spans.rs          |  78 ++++++++++++++----
 src/keywords.rs           |   1 +
 src/parser/mod.rs         | 104 ++++++++++++++++++++++++
 tests/sqlparser_common.rs | 114 ++++++++++++++++++++++++++
 5 files changed, 473 insertions(+), 22 deletions(-)

diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 4c0ffea9..66fd4c6f 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -151,6 +151,15 @@ where
     DisplaySeparated { slice, sep: ", " }
 }
 
+/// Writes the given statements to the formatter, each ending with
+/// a semicolon and space separated.
+fn format_statement_list(f: &mut fmt::Formatter, statements: &[Statement]) -> fmt::Result {
+    write!(f, "{}", display_separated(statements, "; "))?;
+    // We manually insert semicolon for the last statement,
+    // since display_separated doesn't handle that case.
+    write!(f, ";")
+}
+
 /// An identifier, decomposed into its value or character data and the quote style.
 #[derive(Debug, Clone, PartialOrd, Ord)]
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -2080,6 +2089,173 @@ pub enum Password {
     NullPassword,
 }
 
+/// A `CASE` statement.
+///
+/// Examples:
+/// ```sql
+/// CASE
+///     WHEN EXISTS(SELECT 1)
+///         THEN SELECT 1 FROM T;
+///     WHEN EXISTS(SELECT 2)
+///         THEN SELECT 1 FROM U;
+///     ELSE
+///         SELECT 1 FROM V;
+/// END CASE;
+/// ```
+///
+/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#case_search_expression)
+/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/case)
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct CaseStatement {
+    pub match_expr: Option<Expr>,
+    pub when_blocks: Vec<ConditionalStatements>,
+    pub else_block: Option<Vec<Statement>>,
+    /// TRUE if the statement ends with `END CASE` (vs `END`).
+    pub has_end_case: bool,
+}
+
+impl fmt::Display for CaseStatement {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let CaseStatement {
+            match_expr,
+            when_blocks,
+            else_block,
+            has_end_case,
+        } = self;
+
+        write!(f, "CASE")?;
+
+        if let Some(expr) = match_expr {
+            write!(f, " {expr}")?;
+        }
+
+        if !when_blocks.is_empty() {
+            write!(f, " {}", display_separated(when_blocks, " "))?;
+        }
+
+        if let Some(else_block) = else_block {
+            write!(f, " ELSE ")?;
+            format_statement_list(f, else_block)?;
+        }
+
+        write!(f, " END")?;
+        if *has_end_case {
+            write!(f, " CASE")?;
+        }
+
+        Ok(())
+    }
+}
+
+/// An `IF` statement.
+///
+/// Examples:
+/// ```sql
+/// IF TRUE THEN
+///     SELECT 1;
+///     SELECT 2;
+/// ELSEIF TRUE THEN
+///     SELECT 3;
+/// ELSE
+///     SELECT 4;
+/// END IF
+/// ```
+///
+/// [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if)
+/// [Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct IfStatement {
+    pub if_block: ConditionalStatements,
+    pub elseif_blocks: Vec<ConditionalStatements>,
+    pub else_block: Option<Vec<Statement>>,
+}
+
+impl fmt::Display for IfStatement {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let IfStatement {
+            if_block,
+            elseif_blocks,
+            else_block,
+        } = self;
+
+        write!(f, "{if_block}")?;
+
+        if !elseif_blocks.is_empty() {
+            write!(f, " {}", display_separated(elseif_blocks, " "))?;
+        }
+
+        if let Some(else_block) = else_block {
+            write!(f, " ELSE ")?;
+            format_statement_list(f, else_block)?;
+        }
+
+        write!(f, " END IF")?;
+
+        Ok(())
+    }
+}
+
+/// Represents a type of [ConditionalStatements]
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub enum ConditionalStatementKind {
+    /// `WHEN <condition> THEN <statements>`
+    When,
+    /// `IF <condition> THEN <statements>`
+    If,
+    /// `ELSEIF <condition> THEN <statements>`
+    ElseIf,
+}
+
+/// A block within a [Statement::Case] or [Statement::If]-like statement
+///
+/// Examples:
+/// ```sql
+/// WHEN EXISTS(SELECT 1) THEN SELECT 1;
+///
+/// IF TRUE THEN SELECT 1; SELECT 2;
+/// ```
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct ConditionalStatements {
+    /// The condition expression.
+    pub condition: Expr,
+    /// Statement list of the `THEN` clause.
+    pub statements: Vec<Statement>,
+    pub kind: ConditionalStatementKind,
+}
+
+impl fmt::Display for ConditionalStatements {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let ConditionalStatements {
+            condition: expr,
+            statements,
+            kind,
+        } = self;
+
+        let kind = match kind {
+            ConditionalStatementKind::When => "WHEN",
+            ConditionalStatementKind::If => "IF",
+            ConditionalStatementKind::ElseIf => "ELSEIF",
+        };
+
+        write!(f, "{kind} {expr} THEN")?;
+
+        if !statements.is_empty() {
+            write!(f, " ")?;
+            format_statement_list(f, statements)?;
+        }
+
+        Ok(())
+    }
+}
+
 /// Represents an expression assignment within a variable `DECLARE` statement.
 ///
 /// Examples:
@@ -2647,6 +2823,10 @@ pub enum Statement {
         file_format: Option<FileFormat>,
         source: Box<Query>,
     },
+    /// A `CASE` statement.
+    Case(CaseStatement),
+    /// An `IF` statement.
+    If(IfStatement),
     /// ```sql
     /// CALL <function>
     /// ```
@@ -3940,6 +4120,12 @@ impl fmt::Display for Statement {
                 }
                 Ok(())
             }
+            Statement::Case(stmt) => {
+                write!(f, "{stmt}")
+            }
+            Statement::If(stmt) => {
+                write!(f, "{stmt}")
+            }
             Statement::AttachDatabase {
                 schema_name,
                 database_file_name,
@@ -4942,18 +5128,14 @@ impl fmt::Display for Statement {
                     write!(f, " {}", display_comma_separated(modes))?;
                 }
                 if !statements.is_empty() {
-                    write!(f, " {}", display_separated(statements, "; "))?;
-                    // We manually insert semicolon for the last statement,
-                    // since display_separated doesn't handle that case.
-                    write!(f, ";")?;
+                    write!(f, " ")?;
+                    format_statement_list(f, statements)?;
                 }
                 if let Some(exception_statements) = exception_statements {
                     write!(f, " EXCEPTION WHEN ERROR THEN")?;
                     if !exception_statements.is_empty() {
-                        write!(f, " {}", display_separated(exception_statements, "; "))?;
-                        // We manually insert semicolon for the last statement,
-                        // since display_separated doesn't handle that case.
-                        write!(f, ";")?;
+                        write!(f, " ")?;
+                        format_statement_list(f, exception_statements)?;
                     }
                 }
                 if *has_end_keyword {
diff --git a/src/ast/spans.rs b/src/ast/spans.rs
index a4f5eb46..0ee11f23 100644
--- a/src/ast/spans.rs
+++ b/src/ast/spans.rs
@@ -22,20 +22,21 @@ use crate::tokenizer::Span;
 
 use super::{
     dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr, AlterColumnOperation,
-    AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CloseCursor,
-    ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConflictTarget, ConnectBy,
-    ConstraintCharacteristics, CopySource, CreateIndex, CreateTable, CreateTableOptions, Cte,
-    Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr, ExprWithAlias, Fetch, FromTable,
-    Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList,
-    FunctionArguments, GroupByExpr, HavingBound, IlikeSelectItem, Insert, Interpolate,
-    InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath, JsonPathElem, LateralView,
-    LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition, ObjectName, ObjectNamePart,
-    Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition,
-    PivotValueSource, ProjectionSelect, Query, ReferentialAction, RenameSelectItem,
-    ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem, SetExpr, SqlOption,
-    Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef, TableConstraint,
-    TableFactor, TableObject, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use,
-    Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill,
+    AlterIndexOperation, AlterTableOperation, Array, Assignment, AssignmentTarget, CaseStatement,
+    CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConditionalStatements,
+    ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource, CreateIndex, CreateTable,
+    CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr,
+    ExprWithAlias, Fetch, FromTable, Function, FunctionArg, FunctionArgExpr,
+    FunctionArgumentClause, FunctionArgumentList, FunctionArguments, GroupByExpr, HavingBound,
+    IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join, JoinConstraint,
+    JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause, MatchRecognizePattern, Measure,
+    NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict, OnConflictAction,
+    OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource, ProjectionSelect,
+    Query, ReferentialAction, RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select,
+    SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias,
+    TableAliasColumnDef, TableConstraint, TableFactor, TableObject, TableOptionsClustered,
+    TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
+    WildcardAdditionalOptions, With, WithFill,
 };
 
 /// Given an iterator of spans, return the [Span::union] of all spans.
@@ -334,6 +335,8 @@ impl Spanned for Statement {
                 file_format: _,
                 source,
             } => source.span(),
+            Statement::Case(stmt) => stmt.span(),
+            Statement::If(stmt) => stmt.span(),
             Statement::Call(function) => function.span(),
             Statement::Copy {
                 source,
@@ -732,6 +735,53 @@ impl Spanned for CreateIndex {
     }
 }
 
+impl Spanned for CaseStatement {
+    fn span(&self) -> Span {
+        let CaseStatement {
+            match_expr,
+            when_blocks,
+            else_block,
+            has_end_case: _,
+        } = self;
+
+        union_spans(
+            match_expr
+                .iter()
+                .map(|e| e.span())
+                .chain(when_blocks.iter().map(|b| b.span()))
+                .chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
+        )
+    }
+}
+
+impl Spanned for IfStatement {
+    fn span(&self) -> Span {
+        let IfStatement {
+            if_block,
+            elseif_blocks,
+            else_block,
+        } = self;
+
+        union_spans(
+            iter::once(if_block.span())
+                .chain(elseif_blocks.iter().map(|b| b.span()))
+                .chain(else_block.iter().flat_map(|e| e.iter().map(|s| s.span()))),
+        )
+    }
+}
+
+impl Spanned for ConditionalStatements {
+    fn span(&self) -> Span {
+        let ConditionalStatements {
+            condition,
+            statements,
+            kind: _,
+        } = self;
+
+        union_spans(iter::once(condition.span()).chain(statements.iter().map(|s| s.span())))
+    }
+}
+
 /// # partial span
 ///
 /// Missing spans:
diff --git a/src/keywords.rs b/src/keywords.rs
index 195bbb17..47da1009 100644
--- a/src/keywords.rs
+++ b/src/keywords.rs
@@ -297,6 +297,7 @@ define_keywords!(
     ELEMENT,
     ELEMENTS,
     ELSE,
+    ELSEIF,
     EMPTY,
     ENABLE,
     ENABLE_SCHEMA_EVOLUTION,
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index 60e1c146..3adfe55e 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -528,6 +528,14 @@ impl<'a> Parser<'a> {
                 Keyword::DESCRIBE => self.parse_explain(DescribeAlias::Describe),
                 Keyword::EXPLAIN => self.parse_explain(DescribeAlias::Explain),
                 Keyword::ANALYZE => self.parse_analyze(),
+                Keyword::CASE => {
+                    self.prev_token();
+                    self.parse_case_stmt()
+                }
+                Keyword::IF => {
+                    self.prev_token();
+                    self.parse_if_stmt()
+                }
                 Keyword::SELECT | Keyword::WITH | Keyword::VALUES | Keyword::FROM => {
                     self.prev_token();
                     self.parse_query().map(Statement::Query)
@@ -615,6 +623,102 @@ impl<'a> Parser<'a> {
         }
     }
 
+    /// Parse a `CASE` statement.
+    ///
+    /// See [Statement::Case]
+    pub fn parse_case_stmt(&mut self) -> Result<Statement, ParserError> {
+        self.expect_keyword_is(Keyword::CASE)?;
+
+        let match_expr = if self.peek_keyword(Keyword::WHEN) {
+            None
+        } else {
+            Some(self.parse_expr()?)
+        };
+
+        self.expect_keyword_is(Keyword::WHEN)?;
+        let when_blocks = self.parse_keyword_separated(Keyword::WHEN, |parser| {
+            parser.parse_conditional_statements(
+                ConditionalStatementKind::When,
+                &[Keyword::WHEN, Keyword::ELSE, Keyword::END],
+            )
+        })?;
+
+        let else_block = if self.parse_keyword(Keyword::ELSE) {
+            Some(self.parse_statement_list(&[Keyword::END])?)
+        } else {
+            None
+        };
+
+        self.expect_keyword_is(Keyword::END)?;
+        let has_end_case = self.parse_keyword(Keyword::CASE);
+
+        Ok(Statement::Case(CaseStatement {
+            match_expr,
+            when_blocks,
+            else_block,
+            has_end_case,
+        }))
+    }
+
+    /// Parse an `IF` statement.
+    ///
+    /// See [Statement::If]
+    pub fn parse_if_stmt(&mut self) -> Result<Statement, ParserError> {
+        self.expect_keyword_is(Keyword::IF)?;
+        let if_block = self.parse_conditional_statements(
+            ConditionalStatementKind::If,
+            &[Keyword::ELSE, Keyword::ELSEIF, Keyword::END],
+        )?;
+
+        let elseif_blocks = if self.parse_keyword(Keyword::ELSEIF) {
+            self.parse_keyword_separated(Keyword::ELSEIF, |parser| {
+                parser.parse_conditional_statements(
+                    ConditionalStatementKind::ElseIf,
+                    &[Keyword::ELSEIF, Keyword::ELSE, Keyword::END],
+                )
+            })?
+        } else {
+            vec![]
+        };
+
+        let else_block = if self.parse_keyword(Keyword::ELSE) {
+            Some(self.parse_statement_list(&[Keyword::END])?)
+        } else {
+            None
+        };
+
+        self.expect_keywords(&[Keyword::END, Keyword::IF])?;
+
+        Ok(Statement::If(IfStatement {
+            if_block,
+            elseif_blocks,
+            else_block,
+        }))
+    }
+
+    /// Parses an expression and associated list of statements
+    /// belonging to a conditional statement like `IF` or `WHEN`.
+    ///
+    /// Example:
+    /// ```sql
+    /// IF condition THEN statement1; statement2;
+    /// ```
+    fn parse_conditional_statements(
+        &mut self,
+        kind: ConditionalStatementKind,
+        terminal_keywords: &[Keyword],
+    ) -> Result<ConditionalStatements, ParserError> {
+        let condition = self.parse_expr()?;
+        self.expect_keyword_is(Keyword::THEN)?;
+        let statements = self.parse_statement_list(terminal_keywords)?;
+
+        Ok(ConditionalStatements {
+            condition,
+            statements,
+            kind,
+        })
+    }
+
     pub fn parse_comment(&mut self) -> Result<Statement, ParserError> {
         let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]);
 
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index d0edafb7..8c9cae83 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -14179,6 +14179,120 @@ fn test_visit_order() {
     );
 }
 
+#[test]
+fn parse_case_statement() {
+    let sql = "CASE 1 WHEN 2 THEN SELECT 1; SELECT 2; ELSE SELECT 3; END CASE";
+    let Statement::Case(stmt) = verified_stmt(sql) else {
+        unreachable!()
+    };
+
+    assert_eq!(Some(Expr::value(number("1"))), stmt.match_expr);
+    assert_eq!(Expr::value(number("2")), stmt.when_blocks[0].condition);
+    assert_eq!(2, stmt.when_blocks[0].statements.len());
+    assert_eq!(1, stmt.else_block.unwrap().len());
+
+    verified_stmt(concat!(
+        "CASE 1",
+        " WHEN a THEN",
+        " SELECT 1; SELECT 2; SELECT 3;",
+        " WHEN b THEN",
+        " SELECT 4; SELECT 5;",
+        " ELSE",
+        " SELECT 7; SELECT 8;",
+        " END CASE"
+    ));
+    verified_stmt(concat!(
+        "CASE 1",
+        " WHEN a THEN",
+        " SELECT 1; SELECT 2; SELECT 3;",
+        " WHEN b THEN",
+        " SELECT 4; SELECT 5;",
+        " END CASE"
+    ));
+    verified_stmt(concat!(
+        "CASE 1",
+        " WHEN a THEN",
+        " SELECT 1; SELECT 2; SELECT 3;",
+        " END CASE"
+    ));
+    verified_stmt(concat!(
+        "CASE 1",
+        " WHEN a THEN",
+        " SELECT 1; SELECT 2; SELECT 3;",
+        " END"
+    ));
+
+    assert_eq!(
+        ParserError::ParserError("Expected: THEN, found: END".to_string()),
+        parse_sql_statements("CASE 1 WHEN a END").unwrap_err()
+    );
+    assert_eq!(
+        ParserError::ParserError("Expected: WHEN, found: ELSE".to_string()),
+        parse_sql_statements("CASE 1 ELSE SELECT 1; END").unwrap_err()
+    );
+}
+
+#[test]
+fn parse_if_statement() {
+    let sql = "IF 1 THEN SELECT 1; ELSEIF 2 THEN SELECT 2; ELSE SELECT 3; END IF";
+    let Statement::If(stmt) = verified_stmt(sql) else {
+        unreachable!()
+    };
+    assert_eq!(Expr::value(number("1")), stmt.if_block.condition);
+    assert_eq!(Expr::value(number("2")), stmt.elseif_blocks[0].condition);
+    assert_eq!(1, stmt.else_block.unwrap().len());
+
+    verified_stmt(concat!(
+        "IF 1 THEN",
+        " SELECT 1;",
+        " SELECT 2;",
+        " SELECT 3;",
+        " ELSEIF 2 THEN",
+        " SELECT 4;",
+        " SELECT 5;",
+        " ELSEIF 3 THEN",
+        " SELECT 6;",
+        " SELECT 7;",
+        " ELSE",
+        " SELECT 8;",
+        " SELECT 9;",
+        " END IF"
+    ));
+    verified_stmt(concat!(
+        "IF 1 THEN",
+        " SELECT 1;",
+        " SELECT 2;",
+        " ELSE",
+        " SELECT 3;",
+        " SELECT 4;",
+        " END IF"
+    ));
+    verified_stmt(concat!(
+        "IF 1 THEN",
+        " SELECT 1;",
+        " SELECT 2;",
+        " SELECT 3;",
+        " ELSEIF 2 THEN",
+        " SELECT 3;",
+        " SELECT 4;",
+        " END IF"
+    ));
+    verified_stmt(concat!("IF 1 THEN", " SELECT 1;", " SELECT 2;", " END IF"));
+    verified_stmt(concat!(
+        "IF (1) THEN",
+        " SELECT 1;",
+        " SELECT 2;",
+        " END IF"
+    ));
+    verified_stmt("IF 1 THEN END IF");
+    verified_stmt("IF 1 THEN SELECT 1; ELSEIF 1 THEN END IF");
+
+    assert_eq!(
+        ParserError::ParserError("Expected: IF, found: EOF".to_string()),
+        parse_sql_statements("IF 1 THEN SELECT 1; ELSEIF 1 THEN SELECT 2; END").unwrap_err()
+    );
+}
+
 #[test]
 fn test_lambdas() {
     let dialects = all_dialects_where(|d| d.supports_lambda_functions());


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to