This is an automated email from the ASF dual-hosted git repository.
iffyio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-sqlparser-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 862e887a Add `CASE` and `IF` statement support (#1741)
862e887a is described below
commit 862e887a66cf8ede2dc3a641db7cdcf52b061b76
Author: Ifeanyi Ubah <[email protected]>
AuthorDate: Fri Mar 14 07:49:25 2025 +0100
Add `CASE` and `IF` statement support (#1741)
---
src/ast/mod.rs | 198 ++++++++++++++++++++++++++++++++++++++++++++--
src/ast/spans.rs | 78 ++++++++++++++----
src/keywords.rs | 1 +
src/parser/mod.rs | 104 ++++++++++++++++++++++++
tests/sqlparser_common.rs | 114 ++++++++++++++++++++++++++
5 files changed, 473 insertions(+), 22 deletions(-)
diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 4c0ffea9..66fd4c6f 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -151,6 +151,15 @@ where
DisplaySeparated { slice, sep: ", " }
}
+/// Writes `statements` to `f` separated by `"; "`, terminating the last
+/// statement with its own `;` (for example `SELECT 1; SELECT 2;`).
+///
+/// The trailing `;` is written unconditionally, even for an empty slice, so
+/// callers whose list may be empty should check `is_empty()` first.
+fn format_statement_list(f: &mut fmt::Formatter, statements: &[Statement]) -> fmt::Result {
+    // `display_separated` only places the separator *between* items, so the
+    // semicolon after the final statement is appended here.
+    let separated = display_separated(statements, "; ");
+    write!(f, "{separated};")
+}
+
/// An identifier, decomposed into its value or character data and the quote
style.
#[derive(Debug, Clone, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -2080,6 +2089,173 @@ pub enum Password {
NullPassword,
}
+/// A `CASE` statement.
+///
+/// Examples:
+/// ```sql
+/// CASE
+/// WHEN EXISTS(SELECT 1)
+/// THEN SELECT 1 FROM T;
+/// WHEN EXISTS(SELECT 2)
+/// THEN SELECT 1 FROM U;
+/// ELSE
+/// SELECT 1 FROM V;
+/// END CASE;
+/// ```
+///
+///
[BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#case_search_expression)
+///
[Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/case)
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct CaseStatement {
+ pub match_expr: Option<Expr>,
+ pub when_blocks: Vec<ConditionalStatements>,
+ pub else_block: Option<Vec<Statement>>,
+ /// TRUE if the statement ends with `END CASE` (vs `END`).
+ pub has_end_case: bool,
+}
+
+impl fmt::Display for CaseStatement {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let CaseStatement {
+ match_expr,
+ when_blocks,
+ else_block,
+ has_end_case,
+ } = self;
+
+ write!(f, "CASE")?;
+
+ if let Some(expr) = match_expr {
+ write!(f, " {expr}")?;
+ }
+
+ if !when_blocks.is_empty() {
+ write!(f, " {}", display_separated(when_blocks, " "))?;
+ }
+
+ if let Some(else_block) = else_block {
+ write!(f, " ELSE ")?;
+ format_statement_list(f, else_block)?;
+ }
+
+ write!(f, " END")?;
+ if *has_end_case {
+ write!(f, " CASE")?;
+ }
+
+ Ok(())
+ }
+}
+
+/// An `IF` statement.
+///
+/// Examples:
+/// ```sql
+/// IF TRUE THEN
+/// SELECT 1;
+/// SELECT 2;
+/// ELSEIF TRUE THEN
+/// SELECT 3;
+/// ELSE
+/// SELECT 4;
+/// END IF
+/// ```
+///
+///
[BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#if)
+///
[Snowflake](https://docs.snowflake.com/en/sql-reference/snowflake-scripting/if)
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct IfStatement {
+ pub if_block: ConditionalStatements,
+ pub elseif_blocks: Vec<ConditionalStatements>,
+ pub else_block: Option<Vec<Statement>>,
+}
+
+impl fmt::Display for IfStatement {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ let IfStatement {
+ if_block,
+ elseif_blocks,
+ else_block,
+ } = self;
+
+ write!(f, "{if_block}")?;
+
+ if !elseif_blocks.is_empty() {
+ write!(f, " {}", display_separated(elseif_blocks, " "))?;
+ }
+
+ if let Some(else_block) = else_block {
+ write!(f, " ELSE ")?;
+ format_statement_list(f, else_block)?;
+ }
+
+ write!(f, " END IF")?;
+
+ Ok(())
+ }
+}
+
+/// The keyword that introduces a [ConditionalStatements] block; it selects
+/// the keyword written when the block is rendered back to SQL.
+///
+/// Note: the derived `PartialOrd`/`Ord` compare variants in declaration
+/// order, so reordering the variants changes ordering behavior.
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub enum ConditionalStatementKind {
+    /// `WHEN <condition> THEN <statements>`
+    When,
+    /// `IF <condition> THEN <statements>`
+    If,
+    /// `ELSEIF <condition> THEN <statements>`
+    ElseIf,
+}
+
+/// One `<keyword> <condition> THEN <statements>` branch of a
+/// [Statement::Case] or [Statement::If]-like statement.
+///
+/// Examples:
+/// ```sql
+/// WHEN EXISTS(SELECT 1) THEN SELECT 1;
+///
+/// IF TRUE THEN SELECT 1; SELECT 2;
+/// ```
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub struct ConditionalStatements {
+    /// The condition expression.
+    pub condition: Expr,
+    /// Statement list of the `THEN` clause; may be empty.
+    pub statements: Vec<Statement>,
+    /// The keyword that introduced this branch (`WHEN`, `IF` or `ELSEIF`).
+    pub kind: ConditionalStatementKind,
+}
+
+impl fmt::Display for ConditionalStatements {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let keyword = match self.kind {
+            ConditionalStatementKind::When => "WHEN",
+            ConditionalStatementKind::If => "IF",
+            ConditionalStatementKind::ElseIf => "ELSEIF",
+        };
+        write!(f, "{keyword} {} THEN", self.condition)?;
+
+        // With no statements the branch ends right after `THEN`; calling
+        // `format_statement_list` would otherwise emit a stray `;`.
+        if self.statements.is_empty() {
+            return Ok(());
+        }
+        write!(f, " ")?;
+        format_statement_list(f, &self.statements)
+    }
+}
+
/// Represents an expression assignment within a variable `DECLARE` statement.
///
/// Examples:
@@ -2647,6 +2823,10 @@ pub enum Statement {
file_format: Option<FileFormat>,
source: Box<Query>,
},
+ /// A `CASE` statement.
+ Case(CaseStatement),
+ /// An `IF` statement.
+ If(IfStatement),
/// ```sql
/// CALL <function>
/// ```
@@ -3940,6 +4120,12 @@ impl fmt::Display for Statement {
}
Ok(())
}
+ Statement::Case(stmt) => {
+ write!(f, "{stmt}")
+ }
+ Statement::If(stmt) => {
+ write!(f, "{stmt}")
+ }
Statement::AttachDatabase {
schema_name,
database_file_name,
@@ -4942,18 +5128,14 @@ impl fmt::Display for Statement {
write!(f, " {}", display_comma_separated(modes))?;
}
if !statements.is_empty() {
- write!(f, " {}", display_separated(statements, "; "))?;
- // We manually insert semicolon for the last statement,
- // since display_separated doesn't handle that case.
- write!(f, ";")?;
+ write!(f, " ")?;
+ format_statement_list(f, statements)?;
}
if let Some(exception_statements) = exception_statements {
write!(f, " EXCEPTION WHEN ERROR THEN")?;
if !exception_statements.is_empty() {
- write!(f, " {}",
display_separated(exception_statements, "; "))?;
- // We manually insert semicolon for the last statement,
- // since display_separated doesn't handle that case.
- write!(f, ";")?;
+ write!(f, " ")?;
+ format_statement_list(f, exception_statements)?;
}
}
if *has_end_keyword {
diff --git a/src/ast/spans.rs b/src/ast/spans.rs
index a4f5eb46..0ee11f23 100644
--- a/src/ast/spans.rs
+++ b/src/ast/spans.rs
@@ -22,20 +22,21 @@ use crate::tokenizer::Span;
use super::{
dcl::SecondaryRoles, value::ValueWithSpan, AccessExpr,
AlterColumnOperation,
- AlterIndexOperation, AlterTableOperation, Array, Assignment,
AssignmentTarget, CloseCursor,
- ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef, ConflictTarget,
ConnectBy,
- ConstraintCharacteristics, CopySource, CreateIndex, CreateTable,
CreateTableOptions, Cte,
- Delete, DoUpdate, ExceptSelectItem, ExcludeSelectItem, Expr,
ExprWithAlias, Fetch, FromTable,
- Function, FunctionArg, FunctionArgExpr, FunctionArgumentClause,
FunctionArgumentList,
- FunctionArguments, GroupByExpr, HavingBound, IlikeSelectItem, Insert,
Interpolate,
- InterpolateExpr, Join, JoinConstraint, JoinOperator, JsonPath,
JsonPathElem, LateralView,
- LimitClause, MatchRecognizePattern, Measure, NamedWindowDefinition,
ObjectName, ObjectNamePart,
- Offset, OnConflict, OnConflictAction, OnInsert, OrderBy, OrderByExpr,
OrderByKind, Partition,
- PivotValueSource, ProjectionSelect, Query, ReferentialAction,
RenameSelectItem,
- ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem,
SetExpr, SqlOption,
- Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef,
TableConstraint,
- TableFactor, TableObject, TableOptionsClustered, TableWithJoins,
UpdateTableFromKind, Use,
- Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill,
+ AlterIndexOperation, AlterTableOperation, Array, Assignment,
AssignmentTarget, CaseStatement,
+ CloseCursor, ClusteredIndex, ColumnDef, ColumnOption, ColumnOptionDef,
ConditionalStatements,
+ ConflictTarget, ConnectBy, ConstraintCharacteristics, CopySource,
CreateIndex, CreateTable,
+ CreateTableOptions, Cte, Delete, DoUpdate, ExceptSelectItem,
ExcludeSelectItem, Expr,
+ ExprWithAlias, Fetch, FromTable, Function, FunctionArg, FunctionArgExpr,
+ FunctionArgumentClause, FunctionArgumentList, FunctionArguments,
GroupByExpr, HavingBound,
+ IfStatement, IlikeSelectItem, Insert, Interpolate, InterpolateExpr, Join,
JoinConstraint,
+ JoinOperator, JsonPath, JsonPathElem, LateralView, LimitClause,
MatchRecognizePattern, Measure,
+ NamedWindowDefinition, ObjectName, ObjectNamePart, Offset, OnConflict,
OnConflictAction,
+ OnInsert, OrderBy, OrderByExpr, OrderByKind, Partition, PivotValueSource,
ProjectionSelect,
+ Query, ReferentialAction, RenameSelectItem, ReplaceSelectElement,
ReplaceSelectItem, Select,
+ SelectInto, SelectItem, SetExpr, SqlOption, Statement, Subscript,
SymbolDefinition, TableAlias,
+ TableAliasColumnDef, TableConstraint, TableFactor, TableObject,
TableOptionsClustered,
+ TableWithJoins, UpdateTableFromKind, Use, Value, Values, ViewColumnDef,
+ WildcardAdditionalOptions, With, WithFill,
};
/// Given an iterator of spans, return the [Span::union] of all spans.
@@ -334,6 +335,8 @@ impl Spanned for Statement {
file_format: _,
source,
} => source.span(),
+ Statement::Case(stmt) => stmt.span(),
+ Statement::If(stmt) => stmt.span(),
Statement::Call(function) => function.span(),
Statement::Copy {
source,
@@ -732,6 +735,53 @@ impl Spanned for CreateIndex {
}
}
+/// # partial span
+///
+/// Missing spans:
+/// - the `CASE`, `ELSE` and `END [CASE]` keywords, which are not stored on
+///   [CaseStatement]
+impl Spanned for CaseStatement {
+    fn span(&self) -> Span {
+        let CaseStatement {
+            match_expr,
+            when_blocks,
+            else_block,
+            has_end_case: _,
+        } = self;
+
+        let operand = match_expr.iter().map(Spanned::span);
+        let branches = when_blocks.iter().map(Spanned::span);
+        let fallback = else_block.iter().flatten().map(Spanned::span);
+
+        union_spans(operand.chain(branches).chain(fallback))
+    }
+}
+
+/// # partial span
+///
+/// Missing spans:
+/// - the `ELSE` and `END IF` keywords, which are not stored on [IfStatement]
+impl Spanned for IfStatement {
+    fn span(&self) -> Span {
+        let IfStatement {
+            if_block,
+            elseif_blocks,
+            else_block,
+        } = self;
+
+        let head = iter::once(if_block.span());
+        let alternatives = elseif_blocks.iter().map(Spanned::span);
+        let fallback = else_block.iter().flatten().map(Spanned::span);
+
+        union_spans(head.chain(alternatives).chain(fallback))
+    }
+}
+
+/// # partial span
+///
+/// Missing spans:
+/// - the introducing keyword (`WHEN`/`IF`/`ELSEIF`) and `THEN`, which are
+///   not stored on [ConditionalStatements]
+impl Spanned for ConditionalStatements {
+    fn span(&self) -> Span {
+        let ConditionalStatements {
+            condition,
+            statements,
+            kind: _,
+        } = self;
+
+        let body = statements.iter().map(Spanned::span);
+        union_spans(iter::once(condition.span()).chain(body))
+    }
+}
+
/// # partial span
///
/// Missing spans:
diff --git a/src/keywords.rs b/src/keywords.rs
index 195bbb17..47da1009 100644
--- a/src/keywords.rs
+++ b/src/keywords.rs
@@ -297,6 +297,7 @@ define_keywords!(
ELEMENT,
ELEMENTS,
ELSE,
+ ELSEIF,
EMPTY,
ENABLE,
ENABLE_SCHEMA_EVOLUTION,
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index 60e1c146..3adfe55e 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -528,6 +528,14 @@ impl<'a> Parser<'a> {
Keyword::DESCRIBE =>
self.parse_explain(DescribeAlias::Describe),
Keyword::EXPLAIN => self.parse_explain(DescribeAlias::Explain),
Keyword::ANALYZE => self.parse_analyze(),
+ Keyword::CASE => {
+ self.prev_token();
+ self.parse_case_stmt()
+ }
+ Keyword::IF => {
+ self.prev_token();
+ self.parse_if_stmt()
+ }
Keyword::SELECT | Keyword::WITH | Keyword::VALUES |
Keyword::FROM => {
self.prev_token();
self.parse_query().map(Statement::Query)
@@ -615,6 +623,102 @@ impl<'a> Parser<'a> {
}
}
+    /// Parse a `CASE` statement, e.g.
+    /// `CASE x WHEN 1 THEN SELECT 1; ELSE SELECT 2; END CASE`.
+    ///
+    /// The expression after `CASE` is optional, at least one `WHEN` branch is
+    /// required, `ELSE` is optional, and the statement may end with either
+    /// `END` or `END CASE`.
+    ///
+    /// See [Statement::Case]
+    pub fn parse_case_stmt(&mut self) -> Result<Statement, ParserError> {
+        self.expect_keyword_is(Keyword::CASE)?;
+
+        // A `WHEN` directly after `CASE` means there is no match expression.
+        let has_match_expr = !self.peek_keyword(Keyword::WHEN);
+        let match_expr = has_match_expr.then(|| self.parse_expr()).transpose()?;
+
+        self.expect_keyword_is(Keyword::WHEN)?;
+        let when_blocks = self.parse_keyword_separated(Keyword::WHEN, |parser| {
+            parser.parse_conditional_statements(
+                ConditionalStatementKind::When,
+                &[Keyword::WHEN, Keyword::ELSE, Keyword::END],
+            )
+        })?;
+
+        let else_block = self
+            .parse_keyword(Keyword::ELSE)
+            .then(|| self.parse_statement_list(&[Keyword::END]))
+            .transpose()?;
+
+        self.expect_keyword_is(Keyword::END)?;
+        // Both `END CASE` and a bare `END` are accepted; record which one was used.
+        let has_end_case = self.parse_keyword(Keyword::CASE);
+
+        let case_stmt = CaseStatement {
+            match_expr,
+            when_blocks,
+            else_block,
+            has_end_case,
+        };
+        Ok(Statement::Case(case_stmt))
+    }
+
+    /// Parse an `IF` statement, e.g.
+    /// `IF a THEN SELECT 1; ELSEIF b THEN SELECT 2; ELSE SELECT 3; END IF`.
+    ///
+    /// Zero or more `ELSEIF` branches and an optional `ELSE` branch may follow
+    /// the leading `IF` branch; the statement must be closed by `END IF`.
+    ///
+    /// See [Statement::If]
+    pub fn parse_if_stmt(&mut self) -> Result<Statement, ParserError> {
+        self.expect_keyword_is(Keyword::IF)?;
+        let if_block = self.parse_conditional_statements(
+            ConditionalStatementKind::If,
+            &[Keyword::ELSE, Keyword::ELSEIF, Keyword::END],
+        )?;
+
+        let elseif_blocks = if !self.parse_keyword(Keyword::ELSEIF) {
+            vec![]
+        } else {
+            self.parse_keyword_separated(Keyword::ELSEIF, |parser| {
+                parser.parse_conditional_statements(
+                    ConditionalStatementKind::ElseIf,
+                    &[Keyword::ELSEIF, Keyword::ELSE, Keyword::END],
+                )
+            })?
+        };
+
+        let else_block = self
+            .parse_keyword(Keyword::ELSE)
+            .then(|| self.parse_statement_list(&[Keyword::END]))
+            .transpose()?;
+
+        // Unlike `CASE`, a bare `END` is not accepted here.
+        self.expect_keywords(&[Keyword::END, Keyword::IF])?;
+
+        let if_stmt = IfStatement {
+            if_block,
+            elseif_blocks,
+            else_block,
+        };
+        Ok(Statement::If(if_stmt))
+    }
+
+    /// Parses the `<condition> THEN <statements>` part of a conditional
+    /// branch. The introducing keyword (`IF`, `ELSEIF` or `WHEN`) must
+    /// already have been consumed by the caller.
+    ///
+    /// `kind` records which keyword introduced the branch, and
+    /// `terminal_keywords` tells the statement-list parser where the branch
+    /// ends (for example at the next `ELSEIF`, `ELSE` or `END`).
+    ///
+    /// Example input (after `IF`):
+    /// ```sql
+    /// condition THEN statement1; statement2;
+    /// ```
+    fn parse_conditional_statements(
+        &mut self,
+        kind: ConditionalStatementKind,
+        terminal_keywords: &[Keyword],
+    ) -> Result<ConditionalStatements, ParserError> {
+        let condition = self.parse_expr()?;
+        self.expect_keyword_is(Keyword::THEN)?;
+        Ok(ConditionalStatements {
+            condition,
+            statements: self.parse_statement_list(terminal_keywords)?,
+            kind,
+        })
+    }
+
pub fn parse_comment(&mut self) -> Result<Statement, ParserError> {
let if_exists = self.parse_keywords(&[Keyword::IF, Keyword::EXISTS]);
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index d0edafb7..8c9cae83 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -14179,6 +14179,120 @@ fn test_visit_order() {
);
}
+#[test]
+fn parse_case_statement() {
+    let sql = "CASE 1 WHEN 2 THEN SELECT 1; SELECT 2; ELSE SELECT 3; END CASE";
+    let Statement::Case(stmt) = verified_stmt(sql) else {
+        unreachable!()
+    };
+
+    assert_eq!(Some(Expr::value(number("1"))), stmt.match_expr);
+    let first_when = &stmt.when_blocks[0];
+    assert_eq!(Expr::value(number("2")), first_when.condition);
+    assert_eq!(2, first_when.statements.len());
+    assert_eq!(1, stmt.else_block.unwrap().len());
+
+    // Round trips: with/without ELSE, several/one WHEN, `END CASE` vs `END`.
+    let round_trips = [
+        concat!(
+            "CASE 1",
+            " WHEN a THEN",
+            " SELECT 1; SELECT 2; SELECT 3;",
+            " WHEN b THEN",
+            " SELECT 4; SELECT 5;",
+            " ELSE",
+            " SELECT 7; SELECT 8;",
+            " END CASE"
+        ),
+        concat!(
+            "CASE 1",
+            " WHEN a THEN",
+            " SELECT 1; SELECT 2; SELECT 3;",
+            " WHEN b THEN",
+            " SELECT 4; SELECT 5;",
+            " END CASE"
+        ),
+        concat!(
+            "CASE 1",
+            " WHEN a THEN",
+            " SELECT 1; SELECT 2; SELECT 3;",
+            " END CASE"
+        ),
+        concat!(
+            "CASE 1",
+            " WHEN a THEN",
+            " SELECT 1; SELECT 2; SELECT 3;",
+            " END"
+        ),
+    ];
+    for sql in round_trips {
+        verified_stmt(sql);
+    }
+
+    // Malformed inputs and the error each must produce.
+    let failures = [
+        ("CASE 1 WHEN a END", "Expected: THEN, found: END"),
+        ("CASE 1 ELSE SELECT 1; END", "Expected: WHEN, found: ELSE"),
+    ];
+    for (sql, expected) in failures {
+        assert_eq!(
+            ParserError::ParserError(expected.to_string()),
+            parse_sql_statements(sql).unwrap_err()
+        );
+    }
+}
+
+#[test]
+fn parse_if_statement() {
+    let sql = "IF 1 THEN SELECT 1; ELSEIF 2 THEN SELECT 2; ELSE SELECT 3; END IF";
+    let Statement::If(stmt) = verified_stmt(sql) else {
+        unreachable!()
+    };
+    assert_eq!(Expr::value(number("1")), stmt.if_block.condition);
+    assert_eq!(Expr::value(number("2")), stmt.elseif_blocks[0].condition);
+    assert_eq!(1, stmt.else_block.unwrap().len());
+
+    // Round trips: several/one/no ELSEIF, with/without ELSE, a parenthesized
+    // condition, and branches with no statements.
+    let round_trips = [
+        concat!(
+            "IF 1 THEN",
+            " SELECT 1;",
+            " SELECT 2;",
+            " SELECT 3;",
+            " ELSEIF 2 THEN",
+            " SELECT 4;",
+            " SELECT 5;",
+            " ELSEIF 3 THEN",
+            " SELECT 6;",
+            " SELECT 7;",
+            " ELSE",
+            " SELECT 8;",
+            " SELECT 9;",
+            " END IF"
+        ),
+        concat!(
+            "IF 1 THEN",
+            " SELECT 1;",
+            " SELECT 2;",
+            " ELSE",
+            " SELECT 3;",
+            " SELECT 4;",
+            " END IF"
+        ),
+        concat!(
+            "IF 1 THEN",
+            " SELECT 1;",
+            " SELECT 2;",
+            " SELECT 3;",
+            " ELSEIF 2 THEN",
+            " SELECT 3;",
+            " SELECT 4;",
+            " END IF"
+        ),
+        concat!("IF 1 THEN", " SELECT 1;", " SELECT 2;", " END IF"),
+        concat!(
+            "IF (1) THEN",
+            " SELECT 1;",
+            " SELECT 2;",
+            " END IF"
+        ),
+        "IF 1 THEN END IF",
+        "IF 1 THEN SELECT 1; ELSEIF 1 THEN END IF",
+    ];
+    for sql in round_trips {
+        verified_stmt(sql);
+    }
+
+    // A bare `END` does not close an `IF` statement; `END IF` is required.
+    assert_eq!(
+        ParserError::ParserError("Expected: IF, found: EOF".to_string()),
+        parse_sql_statements("IF 1 THEN SELECT 1; ELSEIF 1 THEN SELECT 2; END").unwrap_err()
+    );
+}
+
#[test]
fn test_lambdas() {
let dialects = all_dialects_where(|d| d.supports_lambda_functions());
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]