This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-sqlparser-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 024a878e Support Snowflake Update-From-Select (#1604)
024a878e is described below

commit 024a878ee7f027ca2f9c635c9398ba59653f1a4e
Author: Yuval Shkolar <[email protected]>
AuthorDate: Tue Dec 24 17:00:59 2024 +0200

    Support Snowflake Update-From-Select (#1604)
    
    Co-authored-by: Ifeanyi Ubah <[email protected]>
---
 src/ast/mod.rs            | 11 +++++++----
 src/ast/query.rs          | 13 +++++++++++++
 src/ast/spans.rs          | 13 +++++++++++--
 src/keywords.rs           |  1 +
 src/parser/mod.rs         | 15 ++++++++++-----
 tests/sqlparser_common.rs | 20 ++++++++++++++++----
 6 files changed, 58 insertions(+), 15 deletions(-)

diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 9fb2bb9c..5bdce21e 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -72,8 +72,8 @@ pub use self::query::{
     TableAlias, TableAliasColumnDef, TableFactor, TableFunctionArgs, TableSample,
     TableSampleBucket, TableSampleKind, TableSampleMethod, TableSampleModifier,
     TableSampleQuantity, TableSampleSeed, TableSampleSeedModifier, TableSampleUnit, TableVersion,
-    TableWithJoins, Top, TopQuantity, ValueTableMode, Values, WildcardAdditionalOptions, With,
-    WithFill,
+    TableWithJoins, Top, TopQuantity, UpdateTableFromKind, ValueTableMode, Values,
+    WildcardAdditionalOptions, With, WithFill,
 };
 
 pub use self::trigger::{
@@ -2473,7 +2473,7 @@ pub enum Statement {
         /// Column assignments
         assignments: Vec<Assignment>,
         /// Table which provide value to be set
-        from: Option<TableWithJoins>,
+        from: Option<UpdateTableFromKind>,
         /// WHERE
         selection: Option<Expr>,
         /// RETURNING
@@ -3745,10 +3745,13 @@ impl fmt::Display for Statement {
                     write!(f, "{or} ")?;
                 }
                 write!(f, "{table}")?;
+                if let Some(UpdateTableFromKind::BeforeSet(from)) = from {
+                    write!(f, " FROM {from}")?;
+                }
                 if !assignments.is_empty() {
                     write!(f, " SET {}", display_comma_separated(assignments))?;
                 }
-                if let Some(from) = from {
+                if let Some(UpdateTableFromKind::AfterSet(from)) = from {
                     write!(f, " FROM {from}")?;
                 }
                 if let Some(selection) = selection {
diff --git a/src/ast/query.rs b/src/ast/query.rs
index 69b7ea1c..9e4e9e2e 100644
--- a/src/ast/query.rs
+++ b/src/ast/query.rs
@@ -2790,3 +2790,16 @@ impl fmt::Display for ValueTableMode {
         }
     }
 }
+
+/// The `FROM` clause of an `UPDATE` statement, tagged with whether it appears before or after `SET`.
+#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
+pub enum UpdateTableFromKind {
+    /// Update statement where the `FROM` clause is before the `SET` keyword (supported by Snowflake).
+    /// For example: `UPDATE t1 FROM t2 SET t1.name = t2.name WHERE t1.id = t2.id`
+    BeforeSet(TableWithJoins),
+    /// Update statement where the `FROM` clause is after the `SET` keyword (the standard form).
+    /// For example: `UPDATE t1 SET t1.name = t2.name FROM t2 WHERE t1.id = t2.id`
+    AfterSet(TableWithJoins),
+}
diff --git a/src/ast/spans.rs b/src/ast/spans.rs
index 9ba3bdd9..521b5399 100644
--- a/src/ast/spans.rs
+++ b/src/ast/spans.rs
@@ -32,8 +32,8 @@ use super::{
     OrderBy, OrderByExpr, Partition, PivotValueSource, ProjectionSelect, Query, ReferentialAction,
     RenameSelectItem, ReplaceSelectElement, ReplaceSelectItem, Select, SelectInto, SelectItem,
     SetExpr, SqlOption, Statement, Subscript, SymbolDefinition, TableAlias, TableAliasColumnDef,
-    TableConstraint, TableFactor, TableOptionsClustered, TableWithJoins, Use, Value, Values,
-    ViewColumnDef, WildcardAdditionalOptions, With, WithFill,
+    TableConstraint, TableFactor, TableOptionsClustered, TableWithJoins, UpdateTableFromKind, Use,
+    Value, Values, ViewColumnDef, WildcardAdditionalOptions, With, WithFill,
 };
 
 /// Given an iterator of spans, return the [Span::union] of all spans.
@@ -2106,6 +2106,15 @@ impl Spanned for SelectInto {
     }
 }
 
+impl Spanned for UpdateTableFromKind {
+    fn span(&self) -> Span {
+        match self {
+            UpdateTableFromKind::BeforeSet(from) => from.span(),
+            UpdateTableFromKind::AfterSet(from) => from.span(),
+        }
+    }
+}
+
 #[cfg(test)]
 pub mod tests {
     use crate::dialect::{Dialect, GenericDialect, SnowflakeDialect};
diff --git a/src/keywords.rs b/src/keywords.rs
index bbfd00ca..43abc2b0 100644
--- a/src/keywords.rs
+++ b/src/keywords.rs
@@ -941,6 +941,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[
     // Reserved for Snowflake table sample
     Keyword::SAMPLE,
     Keyword::TABLESAMPLE,
+    Keyword::FROM,
 ];
 
 /// Can't be used as a column alias, so that `SELECT <expr> alias`
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index cc0a57e4..57c4dc6e 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -11791,14 +11791,19 @@ impl<'a> Parser<'a> {
     pub fn parse_update(&mut self) -> Result<Statement, ParserError> {
         let or = self.parse_conflict_clause();
         let table = self.parse_table_and_joins()?;
+        let from_before_set = if self.parse_keyword(Keyword::FROM) {
+            Some(UpdateTableFromKind::BeforeSet(
+                self.parse_table_and_joins()?,
+            ))
+        } else {
+            None
+        };
         self.expect_keyword(Keyword::SET)?;
         let assignments = self.parse_comma_separated(Parser::parse_assignment)?;
-        let from = if self.parse_keyword(Keyword::FROM)
-            && dialect_of!(self is GenericDialect | PostgreSqlDialect | DuckDbDialect | BigQueryDialect | SnowflakeDialect | RedshiftSqlDialect | MsSqlDialect | SQLiteDialect )
-        {
-            Some(self.parse_table_and_joins()?)
+        let from = if from_before_set.is_none() && self.parse_keyword(Keyword::FROM) {
+            Some(UpdateTableFromKind::AfterSet(self.parse_table_and_joins()?))
         } else {
-            None
+            from_before_set
         };
         let selection = if self.parse_keyword(Keyword::WHERE) {
             Some(self.parse_expr()?)
diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs
index 79f5c8d3..cbbbb45f 100644
--- a/tests/sqlparser_common.rs
+++ b/tests/sqlparser_common.rs
@@ -366,7 +366,7 @@ fn parse_update_set_from() {
                 target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("name")])),
                 value: Expr::CompoundIdentifier(vec![Ident::new("t2"), Ident::new("name")])
             }],
-            from: Some(TableWithJoins {
+            from: Some(UpdateTableFromKind::AfterSet(TableWithJoins {
                 relation: TableFactor::Derived {
                     lateral: false,
                     subquery: Box::new(Query {
@@ -417,8 +417,8 @@ fn parse_update_set_from() {
                         columns: vec![],
                     })
                 },
-                joins: vec![],
-            }),
+                joins: vec![]
+            })),
             selection: Some(Expr::BinaryOp {
                 left: Box::new(Expr::CompoundIdentifier(vec![
                     Ident::new("t1"),
@@ -12577,9 +12577,21 @@ fn overflow() {
     let statement = statements.pop().unwrap();
     assert_eq!(statement.to_string(), sql);
 }
-
 #[test]
 fn parse_select_without_projection() {
     let dialects = all_dialects_where(|d| d.supports_empty_projections());
     dialects.verified_stmt("SELECT FROM users");
 }
+
+#[test]
+fn parse_update_from_before_select() {
+    all_dialects()
+    .verified_stmt("UPDATE t1 FROM (SELECT name, id FROM t1 GROUP BY id) AS t2 SET name = t2.name WHERE t1.id = t2.id");
+
+    let query =
+    "UPDATE t1 FROM (SELECT name, id FROM t1 GROUP BY id) AS t2 SET name = t2.name FROM (SELECT name from t2) AS t2";
+    assert_eq!(
+        ParserError::ParserError("Expected: end of statement, found: FROM".to_string()),
+        parse_sql_statements(query).unwrap_err()
+    );
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to