This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-sqlparser-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ac82e94 Streamlined derivation of new `Dialect` objects (#2174)
2ac82e94 is described below

commit 2ac82e946e5f6513b51b747f1783c1cf5f4a733d
Author: Alexander Beedie <[email protected]>
AuthorDate: Tue Feb 3 20:09:11 2026 +0700

    Streamlined derivation of new `Dialect` objects (#2174)
---
 Cargo.toml                        |   7 +-
 derive/Cargo.toml                 |   2 +-
 derive/src/dialect.rs             | 305 ++++++++++++++++++++++++++++++++++++++
 derive/src/lib.rs                 | 276 +++-------------------------------
 derive/src/{lib.rs => visit.rs}   |  47 ++----
 src/dialect/ansi.rs               |   2 +-
 src/dialect/clickhouse.rs         |   2 +-
 src/dialect/hive.rs               |   2 +-
 src/dialect/mod.rs                | 108 +++++++++++++-
 src/dialect/mssql.rs              |   2 +-
 src/dialect/mysql.rs              |   2 +-
 src/dialect/oracle.rs             |   2 +-
 src/dialect/postgresql.rs         |   2 +-
 src/dialect/redshift.rs           |   2 +-
 src/dialect/sqlite.rs             |   2 +-
 src/lib.rs                        |   3 +
 tests/sqlparser_derive_dialect.rs | 123 +++++++++++++++
 17 files changed, 585 insertions(+), 304 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 177ab3db..8945adef 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -42,6 +42,7 @@ std = []
 recursive-protection = ["std", "recursive"]
 # Enable JSON output in the `cli` example:
 json_example = ["serde_json", "serde"]
+derive-dialect = ["sqlparser_derive"]
 visitor = ["sqlparser_derive"]
 
 [dependencies]
@@ -61,6 +62,10 @@ simple_logger = "5.0"
 matches = "0.1"
 pretty_assertions = "1"
 
+[[test]]
+name = "sqlparser_derive_dialect"
+required-features = ["derive-dialect"]
+
 [package.metadata.docs.rs]
 # Document these features on docs.rs
-features = ["serde", "visitor"]
+features = ["serde", "visitor", "derive-dialect"]
diff --git a/derive/Cargo.toml b/derive/Cargo.toml
index 54947704..f2f54926 100644
--- a/derive/Cargo.toml
+++ b/derive/Cargo.toml
@@ -36,6 +36,6 @@ edition = "2021"
 proc-macro = true
 
 [dependencies]
-syn = { version = "2.0", default-features = false, features = ["printing", "parsing", "derive", "proc-macro"] }
+syn = { version = "2.0", default-features = false, features = ["full", "printing", "parsing", "derive", "proc-macro", "clone-impls"] }
 proc-macro2 = "1.0"
 quote = "1.0"
diff --git a/derive/src/dialect.rs b/derive/src/dialect.rs
new file mode 100644
index 00000000..9873e4f7
--- /dev/null
+++ b/derive/src/dialect.rs
@@ -0,0 +1,305 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Implementation of the `derive_dialect!` macro for creating custom SQL dialects.
+
+use proc_macro2::TokenStream;
+use quote::{quote, quote_spanned};
+use std::collections::HashSet;
+use syn::{
+    braced,
+    parse::{Parse, ParseStream},
+    Error, File, FnArg, Ident, Item, LitBool, LitChar, Pat, ReturnType, Signature, Token,
+    TraitItem, Type,
+};
+
+/// Override value types supported by the macro
+pub(crate) enum Override {
+    Bool(LitBool),
+    Char(LitChar),
+    None,
+}
+
+/// Parsed input for the `derive_dialect!` macro
+pub(crate) struct DeriveDialectInput {
+    pub name: Ident,
+    pub base: Type,
+    pub preserve_type_id: bool,
+    pub overrides: Vec<(Ident, Override)>,
+}
+
+/// `Dialect` trait method attrs
+struct DialectMethod {
+    name: Ident,
+    signature: Signature,
+}
+
+impl Parse for DeriveDialectInput {
+    fn parse(input: ParseStream) -> syn::Result<Self> {
+        let name: Ident = input.parse()?;
+        input.parse::<Token![,]>()?;
+        let base: Type = input.parse()?;
+
+        let mut preserve_type_id = false;
+        let mut overrides = Vec::new();
+
+        while input.peek(Token![,]) {
+            input.parse::<Token![,]>()?;
+            if input.is_empty() {
+                break;
+            }
+            if input.peek(Ident) {
+                let ident: Ident = input.parse()?;
+                match ident.to_string().as_str() {
+                    "preserve_type_id" => {
+                        input.parse::<Token![=]>()?;
+                        preserve_type_id = input.parse::<LitBool>()?.value();
+                    }
+                    "overrides" => {
+                        input.parse::<Token![=]>()?;
+                        let content;
+                        braced!(content in input);
+                        while !content.is_empty() {
+                            let key: Ident = content.parse()?;
+                            content.parse::<Token![=]>()?;
+                            let value = if content.peek(LitBool) {
+                                Override::Bool(content.parse()?)
+                            } else if content.peek(LitChar) {
+                                Override::Char(content.parse()?)
+                            } else if content.peek(Ident) {
+                                let ident: Ident = content.parse()?;
+                                if ident == "None" {
+                                    Override::None
+                                } else {
+                                    return Err(Error::new(
+                                        ident.span(),
+                                        format!("Expected `true`, `false`, a char, or `None`, found `{ident}`"),
+                                    ));
+                                }
+                            } else {
+                                return Err(
+                                    content.error("Expected `true`, `false`, a char, or `None`")
+                                );
+                            };
+                            overrides.push((key, value));
+                            if content.peek(Token![,]) {
+                                content.parse::<Token![,]>()?;
+                            }
+                        }
+                    }
+                    other => {
+                        return Err(Error::new(ident.span(), format!(
+                            "Unknown argument `{other}`. Expected `preserve_type_id` or `overrides`."
+                        )));
+                    }
+                }
+            }
+        }
+        Ok(DeriveDialectInput {
+            name,
+            base,
+            preserve_type_id,
+            overrides,
+        })
+    }
+}
+
+/// Entry point for the `derive_dialect!` macro
+pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStream {
+    let err = |msg: String| {
+        Error::new(proc_macro2::Span::call_site(), msg)
+            .to_compile_error()
+            .into()
+    };
+
+    let source = match read_dialect_mod_file() {
+        Ok(s) => s,
+        Err(e) => return err(format!("Failed to read dialect/mod.rs: {e}")),
+    };
+    let file: File = match syn::parse_str(&source) {
+        Ok(f) => f,
+        Err(e) => return err(format!("Failed to parse source: {e}")),
+    };
+    let methods = match extract_dialect_methods(&file) {
+        Ok(m) => m,
+        Err(e) => return e.to_compile_error().into(),
+    };
+
+    // Validate overrides
+    let bool_names: HashSet<_> = methods
+        .iter()
+        .filter(|m| is_bool_method(&m.signature))
+        .map(|m| m.name.to_string())
+        .collect();
+    for (key, value) in &input.overrides {
+        let key_str = key.to_string();
+        let err = |msg| Error::new(key.span(), msg).to_compile_error().into();
+        match value {
+            Override::Bool(_) if !bool_names.contains(&key_str) => {
+                return err(format!("Unknown boolean method `{key_str}`"));
+            }
+            Override::Char(_) | Override::None if key_str != "identifier_quote_style" => {
+                return err(format!(
+                    "Char/None only valid for `identifier_quote_style`, not `{key_str}`"
+                ));
+            }
+            _ => {}
+        }
+    }
+    generate_derived_dialect(&input, &methods).into()
+}
+
+/// Generate the complete derived `Dialect` implementation
+fn generate_derived_dialect(input: &DeriveDialectInput, methods: &[DialectMethod]) -> TokenStream {
+    let name = &input.name;
+    let base = &input.base;
+
+    // Helper to find an override by method name
+    let find_override = |method_name: &str| {
+        input
+            .overrides
+            .iter()
+            .find(|(k, _)| k == method_name)
+            .map(|(_, v)| v)
+    };
+
+    // Helper to generate delegation to base dialect
+    let delegate = |method: &DialectMethod| {
+        let sig = &method.signature;
+        let method_name = &method.name;
+        let params = extract_param_names(sig);
+        quote_spanned! { method_name.span() => #sig { self.dialect.#method_name(#(#params),*) } }
+    };
+
+    // Generate the struct
+    let struct_def = quote_spanned! { name.span() =>
+        #[derive(Debug, Default)]
+        pub struct #name {
+            dialect: #base,
+        }
+        impl #name {
+            pub fn new() -> Self { Self::default() }
+        }
+    };
+
+    // Generate TypeId method body
+    let type_id_body = if input.preserve_type_id {
+        quote! { Dialect::dialect(&self.dialect) }
+    } else {
+        quote! { ::core::any::TypeId::of::<#name>() }
+    };
+
+    // Generate method implementations
+    let method_impls = methods.iter().map(|method| {
+        let method_name = &method.name;
+        match find_override(&method_name.to_string()) {
+            Some(Override::Bool(value)) => {
+                quote_spanned! { method_name.span() => fn #method_name(&self) -> bool { #value } }
+            }
+            Some(Override::Char(c)) => {
+                quote_spanned! { method_name.span() =>
+                    fn identifier_quote_style(&self, _: &str) -> Option<char> { Some(#c) }
+                }
+            }
+            Some(Override::None) => {
+                quote_spanned! { method_name.span() =>
+                    fn identifier_quote_style(&self, _: &str) -> Option<char> { None }
+                }
+            }
+            None => delegate(method),
+        }
+    });
+
+    // Wrap impl in a const block with scoped imports so types resolve without qualification
+    quote! {
+        #struct_def
+        const _: () = {
+            use ::core::iter::Peekable;
+            use ::core::str::Chars;
+            use sqlparser::ast::{ColumnOption, Expr, GranteesType, Ident, ObjectNamePart, Statement};
+            use sqlparser::dialect::{Dialect, Precedence};
+            use sqlparser::keywords::Keyword;
+            use sqlparser::parser::{Parser, ParserError};
+
+            impl Dialect for #name {
+                fn dialect(&self) -> ::core::any::TypeId { #type_id_body }
+                #(#method_impls)*
+            }
+        };
+    }
+}
+
+/// Extract parameter names from a method signature (excluding self)
+fn extract_param_names(sig: &Signature) -> Vec<&Ident> {
+    sig.inputs
+        .iter()
+        .filter_map(|arg| match arg {
+            FnArg::Typed(pt) => match pt.pat.as_ref() {
+                Pat::Ident(pi) => Some(&pi.ident),
+                _ => None,
+            },
+            _ => None,
+        })
+        .collect()
+}
+
+/// Read the `dialect/mod.rs` file that contains the Dialect trait.
+fn read_dialect_mod_file() -> Result<String, String> {
+    let manifest_dir =
+        std::env::var("CARGO_MANIFEST_DIR").map_err(|_| "CARGO_MANIFEST_DIR not set")?;
+    let path = std::path::Path::new(&manifest_dir).join("src/dialect/mod.rs");
+    std::fs::read_to_string(&path).map_err(|e| format!("Failed to read {}: {e}", path.display()))
+}
+
+/// Extract all methods from the `Dialect` trait (excluding `dialect` for TypeId)
+fn extract_dialect_methods(file: &File) -> Result<Vec<DialectMethod>, Error> {
+    let dialect_trait = file
+        .items
+        .iter()
+        .find_map(|item| match item {
+            Item::Trait(t) if t.ident == "Dialect" => Some(t),
+            _ => None,
+        })
+        .ok_or_else(|| Error::new(proc_macro2::Span::call_site(), "Dialect trait not found"))?;
+
+    let mut methods: Vec<_> = dialect_trait
+        .items
+        .iter()
+        .filter_map(|item| match item {
+            TraitItem::Fn(m) if m.sig.ident != "dialect" => Some(DialectMethod {
+                name: m.sig.ident.clone(),
+                signature: m.sig.clone(),
+            }),
+            _ => None,
+        })
+        .collect();
+    methods.sort_by_key(|m| m.name.to_string());
+    Ok(methods)
+}
+
+/// Check if a method signature is `fn name(&self) -> bool`
+fn is_bool_method(sig: &Signature) -> bool {
+    sig.inputs.len() == 1
+        && matches!(
+            sig.inputs.first(),
+            Some(FnArg::Receiver(r)) if r.reference.is_some() && r.mutability.is_none()
+        )
+        && matches!(
+            &sig.output,
+            ReturnType::Type(_, ty) if matches!(ty.as_ref(), Type::Path(p) if p.path.is_ident("bool"))
+        )
+}
diff --git a/derive/src/lib.rs b/derive/src/lib.rs
index 08c5c5db..e3eaeea6 100644
--- a/derive/src/lib.rs
+++ b/derive/src/lib.rs
@@ -15,22 +15,25 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use proc_macro2::TokenStream;
-use quote::{format_ident, quote, quote_spanned, ToTokens};
-use syn::spanned::Spanned;
-use syn::{
-    parse::{Parse, ParseStream},
-    parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
-    Ident, Index, LitStr, Meta, Token, Type, TypePath,
-};
-use syn::{Path, PathArguments};
+//! Procedural macros for sqlparser.
+//!
+//! This crate provides:
+//! - [`Visit`] and [`VisitMut`] derive macros for AST traversal.
+//! - [`derive_dialect!`] macro for creating custom SQL dialects.
 
-/// Implementation of `[#derive(Visit)]`
+use quote::quote;
+use syn::parse_macro_input;
+
+mod dialect;
+mod visit;
+
+/// Implementation of `#[derive(VisitMut)]`
 #[proc_macro_derive(VisitMut, attributes(visit))]
 pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
-    derive_visit(
+    let input = parse_macro_input!(input as syn::DeriveInput);
+    visit::derive_visit(
         input,
-        &VisitType {
+        &visit::VisitType {
             visit_trait: quote!(VisitMut),
             visitor_trait: quote!(VisitorMut),
             modifier: Some(quote!(mut)),
@@ -38,12 +41,13 @@ pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStre
     )
 }
 
-/// Implementation of `[#derive(Visit)]`
+/// Implementation of `#[derive(Visit)]`
 #[proc_macro_derive(Visit, attributes(visit))]
 pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
-    derive_visit(
+    let input = parse_macro_input!(input as syn::DeriveInput);
+    visit::derive_visit(
         input,
-        &VisitType {
+        &visit::VisitType {
             visit_trait: quote!(Visit),
             visitor_trait: quote!(Visitor),
             modifier: None,
@@ -51,241 +55,9 @@ pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::Tok
     )
 }
 
-struct VisitType {
-    visit_trait: TokenStream,
-    visitor_trait: TokenStream,
-    modifier: Option<TokenStream>,
-}
-
-fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_macro::TokenStream {
-    // Parse the input tokens into a syntax tree.
-    let input = parse_macro_input!(input as DeriveInput);
-    let name = input.ident;
-
-    let VisitType {
-        visit_trait,
-        visitor_trait,
-        modifier,
-    } = visit_type;
-
-    let attributes = Attributes::parse(&input.attrs);
-    // Add a bound `T: Visit` to every type parameter T.
-    let generics = add_trait_bounds(input.generics, visit_type);
-    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
-
-    let (pre_visit, post_visit) = attributes.visit(quote!(self));
-    let children = visit_children(&input.data, visit_type);
-
-    let expanded = quote! {
-        // The generated impl.
-        // Note that it uses [`recursive::recursive`] to protect from stack overflow.
-        // See tests in https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info.
-        impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
-             #[cfg_attr(feature = "recursive-protection", recursive::recursive)]
-            fn visit<V: sqlparser::ast::#visitor_trait>(
-                &#modifier self,
-                visitor: &mut V
-            ) -> ::std::ops::ControlFlow<V::Break> {
-                #pre_visit
-                #children
-                #post_visit
-                ::std::ops::ControlFlow::Continue(())
-            }
-        }
-    };
-
-    proc_macro::TokenStream::from(expanded)
-}
-
-/// Parses attributes that can be provided to this macro
-///
-/// `#[visit(leaf, with = "visit_expr")]`
-#[derive(Default)]
-struct Attributes {
-    /// Content for the `with` attribute
-    with: Option<Ident>,
-}
-
-struct WithIdent {
-    with: Option<Ident>,
-}
-impl Parse for WithIdent {
-    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
-        let mut result = WithIdent { with: None };
-        let ident = input.parse::<Ident>()?;
-        if ident != "with" {
-            return Err(syn::Error::new(
-                ident.span(),
-                "Expected identifier to be `with`",
-            ));
-        }
-        input.parse::<Token!(=)>()?;
-        let s = input.parse::<LitStr>()?;
-        result.with = Some(format_ident!("{}", s.value(), span = s.span()));
-        Ok(result)
-    }
-}
-
-impl Attributes {
-    fn parse(attrs: &[Attribute]) -> Self {
-        let mut out = Self::default();
-        for attr in attrs {
-            if let Meta::List(ref metalist) = attr.meta {
-                if metalist.path.is_ident("visit") {
-                    match syn::parse2::<WithIdent>(metalist.tokens.clone()) {
-                        Ok(with_ident) => {
-                            out.with = with_ident.with;
-                        }
-                        Err(e) => {
-                            panic!("{}", e);
-                        }
-                    }
-                }
-            }
-        }
-        out
-    }
-
-    /// Returns the pre and post visit token streams
-    fn visit(&self, s: TokenStream) -> (Option<TokenStream>, Option<TokenStream>) {
-        let pre_visit = self.with.as_ref().map(|m| {
-            let m = format_ident!("pre_{}", m);
-            quote!(visitor.#m(#s)?;)
-        });
-        let post_visit = self.with.as_ref().map(|m| {
-            let m = format_ident!("post_{}", m);
-            quote!(visitor.#m(#s)?;)
-        });
-        (pre_visit, post_visit)
-    }
-}
-
-// Add a bound `T: Visit` to every type parameter T.
-fn add_trait_bounds(mut generics: Generics, VisitType { visit_trait, .. }: &VisitType) -> Generics {
-    for param in &mut generics.params {
-        if let GenericParam::Type(ref mut type_param) = *param {
-            type_param
-                .bounds
-                .push(parse_quote!(sqlparser::ast::#visit_trait));
-        }
-    }
-    generics
-}
-
-// Generate the body of the visit implementation for the given type
-fn visit_children(
-    data: &Data,
-    VisitType {
-        visit_trait,
-        modifier,
-        ..
-    }: &VisitType,
-) -> TokenStream {
-    match data {
-        Data::Struct(data) => match &data.fields {
-            Fields::Named(fields) => {
-                let recurse = fields.named.iter().map(|f| {
-                    let name = &f.ident;
-                    let is_option = is_option(&f.ty);
-                    let attributes = Attributes::parse(&f.attrs);
-                    if is_option && attributes.with.is_some() {
-                        let (pre_visit, post_visit) = attributes.visit(quote!(value));
-                        quote_spanned!(f.span() =>
-                            if let Some(value) = &#modifier self.#name {
-                                #pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit
-                            }
-                        )
-                    } else {
-                        let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
-                        quote_spanned!(f.span() =>
-                            #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit
-                        )
-                    }
-                });
-                quote! {
-                    #(#recurse)*
-                }
-            }
-            Fields::Unnamed(fields) => {
-                let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
-                    let index = Index::from(i);
-                    let attributes = Attributes::parse(&f.attrs);
-                    let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
-                    quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit)
-                });
-                quote! {
-                    #(#recurse)*
-                }
-            }
-            Fields::Unit => {
-                quote!()
-            }
-        },
-        Data::Enum(data) => {
-            let statements = data.variants.iter().map(|v| {
-                let name = &v.ident;
-                match &v.fields {
-                    Fields::Named(fields) => {
-                        let names = fields.named.iter().map(|f| &f.ident);
-                        let visit = fields.named.iter().map(|f| {
-                            let name = &f.ident;
-                            let attributes = Attributes::parse(&f.attrs);
-                            let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
-                            quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
-                        });
-
-                        quote!(
-                            Self::#name { #(#names),* } => {
-                                #(#visit)*
-                            }
-                        )
-                    }
-                    Fields::Unnamed(fields) => {
-                        let names = fields.unnamed.iter().enumerate().map(|(i, f)| format_ident!("_{}", i, span = f.span()));
-                        let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
-                            let name = format_ident!("_{}", i);
-                            let attributes = Attributes::parse(&f.attrs);
-                            let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
-                            quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
-                        });
-
-                        quote! {
-                            Self::#name ( #(#names),*) => {
-                                #(#visit)*
-                            }
-                        }
-                    }
-                    Fields::Unit => {
-                        quote! {
-                            Self::#name => {}
-                        }
-                    }
-                }
-            });
-
-            quote! {
-                match self {
-                    #(#statements),*
-                }
-            }
-        }
-        Data::Union(_) => unimplemented!(),
-    }
-}
-
-fn is_option(ty: &Type) -> bool {
-    if let Type::Path(TypePath {
-        path: Path { segments, .. },
-        ..
-    }) = ty
-    {
-        if let Some(segment) = segments.last() {
-            if segment.ident == "Option" {
-                if let PathArguments::AngleBracketed(args) = &segment.arguments {
-                    return args.args.len() == 1;
-                }
-            }
-        }
-    }
-    false
+/// Procedural macro for deriving new SQL dialects.
+#[proc_macro]
+pub fn derive_dialect(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let input = parse_macro_input!(input as dialect::DeriveDialectInput);
+    dialect::derive_dialect(input)
 }
diff --git a/derive/src/lib.rs b/derive/src/visit.rs
similarity index 88%
copy from derive/src/lib.rs
copy to derive/src/visit.rs
index 08c5c5db..baf3eb58 100644
--- a/derive/src/lib.rs
+++ b/derive/src/visit.rs
@@ -15,51 +15,28 @@
 // specific language governing permissions and limitations
 // under the License.
 
+//! Implementation of the `Visit` and `VisitMut` derive macros.
+
 use proc_macro2::TokenStream;
 use quote::{format_ident, quote, quote_spanned, ToTokens};
 use syn::spanned::Spanned;
 use syn::{
     parse::{Parse, ParseStream},
-    parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
-    Ident, Index, LitStr, Meta, Token, Type, TypePath,
+    parse_quote, Attribute, Data, Fields, GenericParam, Generics, Ident, Index, LitStr, Meta,
+    Token, Type, TypePath,
 };
 use syn::{Path, PathArguments};
 
-/// Implementation of `[#derive(Visit)]`
-#[proc_macro_derive(VisitMut, attributes(visit))]
-pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
-    derive_visit(
-        input,
-        &VisitType {
-            visit_trait: quote!(VisitMut),
-            visitor_trait: quote!(VisitorMut),
-            modifier: Some(quote!(mut)),
-        },
-    )
-}
-
-/// Implementation of `[#derive(Visit)]`
-#[proc_macro_derive(Visit, attributes(visit))]
-pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
-    derive_visit(
-        input,
-        &VisitType {
-            visit_trait: quote!(Visit),
-            visitor_trait: quote!(Visitor),
-            modifier: None,
-        },
-    )
-}
-
-struct VisitType {
-    visit_trait: TokenStream,
-    visitor_trait: TokenStream,
-    modifier: Option<TokenStream>,
+pub(crate) struct VisitType {
+    pub visit_trait: TokenStream,
+    pub visitor_trait: TokenStream,
+    pub modifier: Option<TokenStream>,
 }
 
-fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_macro::TokenStream {
-    // Parse the input tokens into a syntax tree.
-    let input = parse_macro_input!(input as DeriveInput);
+pub(crate) fn derive_visit(
+    input: syn::DeriveInput,
+    visit_type: &VisitType,
+) -> proc_macro::TokenStream {
     let name = input.ident;
 
     let VisitType {
diff --git a/src/dialect/ansi.rs b/src/dialect/ansi.rs
index ec3c095b..5a54390c 100644
--- a/src/dialect/ansi.rs
+++ b/src/dialect/ansi.rs
@@ -18,7 +18,7 @@
 use crate::dialect::Dialect;
 
 /// A [`Dialect`] for [ANSI SQL](https://en.wikipedia.org/wiki/SQL:2011).
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct AnsiDialect {}
 
 impl Dialect for AnsiDialect {
diff --git a/src/dialect/clickhouse.rs b/src/dialect/clickhouse.rs
index 041b94ec..f8b6807f 100644
--- a/src/dialect/clickhouse.rs
+++ b/src/dialect/clickhouse.rs
@@ -18,7 +18,7 @@
 use crate::dialect::Dialect;
 
 /// A [`Dialect`] for [ClickHouse](https://clickhouse.com/).
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct ClickHouseDialect {}
 
 impl Dialect for ClickHouseDialect {
diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs
index 3e15d395..32a982e9 100644
--- a/src/dialect/hive.rs
+++ b/src/dialect/hive.rs
@@ -18,7 +18,7 @@
 use crate::dialect::Dialect;
 
 /// A [`Dialect`] for [Hive](https://hive.apache.org/).
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct HiveDialect {}
 
 impl Dialect for HiveDialect {
diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs
index ef563fc1..477d60f8 100644
--- a/src/dialect/mod.rs
+++ b/src/dialect/mod.rs
@@ -51,6 +51,82 @@ pub use self::postgresql::PostgreSqlDialect;
 pub use self::redshift::RedshiftSqlDialect;
 pub use self::snowflake::SnowflakeDialect;
 pub use self::sqlite::SQLiteDialect;
+
+/// Macro for streamlining the creation of derived `Dialect` objects.
+/// The generated struct includes `new()` and `default()` constructors.
+/// Requires the `derive-dialect` feature.
+///
+/// # Syntax
+///
+/// ```text
+/// derive_dialect!(NewDialect, BaseDialect);
+/// derive_dialect!(NewDialect, BaseDialect, overrides = { method = value, ... });
+/// derive_dialect!(NewDialect, BaseDialect, preserve_type_id = true);
+/// derive_dialect!(NewDialect, BaseDialect, preserve_type_id = true, overrides = { ... });
+/// ```
+///
+/// # Example
+///
+/// ```
+/// use sqlparser::derive_dialect;
+/// use sqlparser::dialect::{Dialect, GenericDialect};
+///
+/// // Override boolean methods (supports_*, allow_*, etc.)
+/// derive_dialect!(CustomDialect, GenericDialect, overrides = {
+///     supports_order_by_all = true,
+///     supports_nested_comments = true,
+/// });
+///
+/// let dialect = CustomDialect::new();
+/// assert!(dialect.supports_order_by_all());
+/// assert!(dialect.supports_nested_comments());
+/// ```
+///
+/// # Overriding `identifier_quote_style`
+///
+/// Use a char literal or `None`:
+/// ```
+/// use sqlparser::derive_dialect;
+/// use sqlparser::dialect::{Dialect, PostgreSqlDialect};
+///
+/// derive_dialect!(BacktickPostgreSqlDialect, PostgreSqlDialect,
+///     preserve_type_id = true,
+///     overrides = { identifier_quote_style = '`' }
+/// );
+/// let d: &dyn Dialect = &BacktickPostgreSqlDialect::new();
+/// assert_eq!(d.identifier_quote_style("foo"), Some('`'));
+///
+/// derive_dialect!(QuotelessPostgreSqlDialect, PostgreSqlDialect,
+///     preserve_type_id = true,
+///     overrides = { identifier_quote_style = None }
+/// );
+/// let d: &dyn Dialect = &QuotelessPostgreSqlDialect::new();
+/// assert_eq!(d.identifier_quote_style("foo"), None);
+/// ```
+///
+/// # Type Identity
+///
+/// By default, derived dialects have their own `TypeId`. Set `preserve_type_id = true` to
+/// retain the base dialect's identity with respect to the parser's `dialect.is::<T>()` checks:
+/// ```
+/// use sqlparser::derive_dialect;
+/// use sqlparser::dialect::{Dialect, GenericDialect};
+///
+/// derive_dialect!(EnhancedGenericDialect, GenericDialect,
+///     preserve_type_id = true,
+///     overrides = {
+///         supports_order_by_all = true,
+///         supports_nested_comments = true,
+///     }
+/// );
+/// let d: &dyn Dialect = &EnhancedGenericDialect::new();
+/// assert!(d.is::<GenericDialect>());  // still recognized as a GenericDialect
+/// assert!(d.supports_nested_comments());
+/// assert!(d.supports_order_by_all());
+/// ```
+#[cfg(feature = "derive-dialect")]
+pub use sqlparser_derive::derive_dialect;
+
 use crate::ast::{ColumnOption, Expr, GranteesType, Ident, ObjectNamePart, Statement};
 pub use crate::keywords;
 use crate::keywords::Keyword;
@@ -62,14 +138,14 @@ use alloc::boxed::Box;
 
 /// Convenience check if a [`Parser`] uses a certain dialect.
 ///
-/// Note: when possible please the new style, adding a method to the [`Dialect`]
-/// trait rather than using this macro.
+/// Note: when possible, please use the new style, adding a method to
+/// the [`Dialect`] trait rather than using this macro.
 ///
 /// The benefits of adding a method on `Dialect` over this macro are:
 /// 1. user defined [`Dialect`]s can customize the parsing behavior
 /// 2. The differences between dialects can be clearly documented in the trait
 ///
-/// `dialect_of!(parser is SQLiteDialect |  GenericDialect)` evaluates
+/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
 /// to `true` if `parser.dialect` is one of the [`Dialect`]s specified.
 macro_rules! dialect_of {
     ( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => {
@@ -123,9 +199,8 @@ macro_rules! dialect_is {
 pub trait Dialect: Debug + Any {
     /// Determine the [`TypeId`] of this dialect.
     ///
-    /// By default, return the same [`TypeId`] as [`Any::type_id`]. Can be overridden
-    /// by dialects that behave like other dialects
-    /// (for example when wrapping a dialect).
+    /// By default, return the same [`TypeId`] as [`Any::type_id`]. Can be overridden by
+    /// dialects that behave like other dialects (for example, when wrapping a dialect).
     fn dialect(&self) -> TypeId {
         self.type_id()
     }
@@ -1646,6 +1721,27 @@ mod tests {
         dialect_from_str(v).unwrap()
     }
 
+    #[test]
+    #[cfg(feature = "derive-dialect")]
+    fn test_dialect_override() {
+        derive_dialect!(EnhancedGenericDialect, GenericDialect,
+            preserve_type_id = true,
+            overrides = {
+                supports_order_by_all = true,
+                supports_nested_comments = true,
+                supports_triple_quoted_string = true,
+            },
+        );
+        let dialect = EnhancedGenericDialect::new();
+
+        assert!(dialect.supports_order_by_all());
+        assert!(dialect.supports_nested_comments());
+        assert!(dialect.supports_triple_quoted_string());
+
+        let d: &dyn Dialect = &dialect;
+        assert!(d.is::<GenericDialect>());
+    }
+
     #[test]
     fn identifier_quote_style() {
         let tests: Vec<(&dyn Dialect, &str, Option<char>)> = vec![
diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs
index 9f8e7265..24f7c7c4 100644
--- a/src/dialect/mssql.rs
+++ b/src/dialect/mssql.rs
@@ -28,7 +28,7 @@ use crate::tokenizer::Token;
 use alloc::{vec, vec::Vec};
 
 /// A [`Dialect`] for [Microsoft SQL Server](https://www.microsoft.com/en-us/sql-server/)
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct MsSqlDialect {}
 
 impl Dialect for MsSqlDialect {
diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs
index b44001fe..ad3ba6f3 100644
--- a/src/dialect/mysql.rs
+++ b/src/dialect/mysql.rs
@@ -35,7 +35,7 @@ const RESERVED_FOR_TABLE_ALIAS_MYSQL: &[Keyword] = &[
 ];
 
 /// A [`Dialect`] for [MySQL](https://www.mysql.com/)
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct MySqlDialect {}
 
 impl Dialect for MySqlDialect {
diff --git a/src/dialect/oracle.rs b/src/dialect/oracle.rs
index 7ff93262..a72d5d7a 100644
--- a/src/dialect/oracle.rs
+++ b/src/dialect/oracle.rs
@@ -25,7 +25,7 @@ use crate::{
 use super::{Dialect, Precedence};
 
 /// A [`Dialect`] for [Oracle Databases](https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/index.html)
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct OracleDialect;
 
 impl Dialect for OracleDialect {
diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs
index 7c9e7db8..1924a5e3 100644
--- a/src/dialect/postgresql.rs
+++ b/src/dialect/postgresql.rs
@@ -34,7 +34,7 @@ use crate::parser::{Parser, ParserError};
 use crate::tokenizer::Token;
 
 /// A [`Dialect`] for [PostgreSQL](https://www.postgresql.org/)
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct PostgreSqlDialect {}
 
 const PERIOD_PREC: u8 = 200;
diff --git a/src/dialect/redshift.rs b/src/dialect/redshift.rs
index c028061d..7b35848b 100644
--- a/src/dialect/redshift.rs
+++ b/src/dialect/redshift.rs
@@ -22,7 +22,7 @@ use core::str::Chars;
 use super::PostgreSqlDialect;
 
 /// A [`Dialect`] for [RedShift](https://aws.amazon.com/redshift/)
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct RedshiftSqlDialect {}
 
 // In most cases the redshift dialect is identical to [`PostgresSqlDialect`].
diff --git a/src/dialect/sqlite.rs b/src/dialect/sqlite.rs
index ba4cb617..7d1c935f 100644
--- a/src/dialect/sqlite.rs
+++ b/src/dialect/sqlite.rs
@@ -30,7 +30,7 @@ use crate::parser::{Parser, ParserError};
 /// [`CREATE TABLE`](https://sqlite.org/lang_createtable.html) statement with no
 /// type specified, as in `CREATE TABLE t1 (a)`. In the AST, these columns will
 /// have the data type [`Unspecified`](crate::ast::DataType::Unspecified).
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct SQLiteDialect {}
 
 impl Dialect for SQLiteDialect {
diff --git a/src/lib.rs b/src/lib.rs
index f5d23a21..e68d7f93 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -170,6 +170,9 @@ pub mod ast;
 #[macro_use]
 /// Submodules for SQL dialects.
 pub mod dialect;
+
+#[cfg(feature = "derive-dialect")]
+pub use dialect::derive_dialect;
 mod display_utils;
 pub mod keywords;
 pub mod parser;
diff --git a/tests/sqlparser_derive_dialect.rs b/tests/sqlparser_derive_dialect.rs
new file mode 100644
index 00000000..d60fa1e1
--- /dev/null
+++ b/tests/sqlparser_derive_dialect.rs
@@ -0,0 +1,123 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Tests for the `derive_dialect!` macro.
+
+use sqlparser::derive_dialect;
+use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect, PostgreSqlDialect};
+use sqlparser::parser::Parser;
+
+#[test]
+fn test_method_overrides() {
+    derive_dialect!(EnhancedGenericDialect, GenericDialect, overrides = {
+        supports_order_by_all = true,
+        supports_triple_quoted_string = true,
+    });
+    let dialect = EnhancedGenericDialect::new();
+
+    // Overridden methods
+    assert!(dialect.supports_order_by_all());
+    assert!(dialect.supports_triple_quoted_string());
+
+    // Non-overridden retains base behavior
+    assert!(!dialect.supports_factorial_operator());
+
+    // Parsing works with the overrides
+    let result = Parser::new(&dialect)
+        .try_with_sql("SELECT '''value''' FROM t ORDER BY ALL")
+        .unwrap()
+        .parse_statements();
+
+    assert!(result.is_ok());
+}
+
+#[test]
+fn test_preserve_type_id() {
+    // Check the override works and the parser recognizes it as the base type
+    derive_dialect!(
+        PreservedTypeDialect,
+        GenericDialect,
+        preserve_type_id = true,
+        overrides = { supports_order_by_all = true }
+    );
+    let dialect = PreservedTypeDialect::new();
+    let d: &dyn Dialect = &dialect;
+
+    assert!(dialect.supports_order_by_all());
+    assert!(d.is::<GenericDialect>());
+}
+
+#[test]
+fn test_different_base_dialects() {
+    derive_dialect!(
+        EnhancedMySqlDialect,
+        MySqlDialect,
+        overrides = { supports_order_by_all = true }
+    );
+    derive_dialect!(UniquePostgreSqlDialect, PostgreSqlDialect);
+
+    let pg = UniquePostgreSqlDialect::new();
+    let mysql = EnhancedMySqlDialect::new();
+
+    // Inherit different base behaviors
+    assert!(pg.supports_filter_during_aggregation()); // PostgreSQL feature
+    assert!(mysql.supports_string_literal_backslash_escape()); // MySQL feature
+    assert!(mysql.supports_order_by_all()); // Override
+
+    // Each has unique TypeId
+    let pg_ref: &dyn Dialect = &pg;
+    let mysql_ref: &dyn Dialect = &mysql;
+    assert!(pg_ref.is::<UniquePostgreSqlDialect>());
+    assert!(!pg_ref.is::<PostgreSqlDialect>());
+    assert!(mysql_ref.is::<EnhancedMySqlDialect>());
+}
+
+#[test]
+fn test_identifier_quote_style_overrides() {
+    derive_dialect!(
+        BacktickGenericDialect,
+        GenericDialect,
+        overrides = { identifier_quote_style = '`' }
+    );
+    derive_dialect!(
+        AnotherBacktickDialect,
+        GenericDialect,
+        overrides = { identifier_quote_style = '[' }
+    );
+    derive_dialect!(
+        QuotelessPostgreSqlDialect,
+        PostgreSqlDialect,
+        preserve_type_id = true,
+        overrides = { identifier_quote_style = None }
+    );
+
+    // Char literal (auto-wrapped in Some)
+    assert_eq!(
+        BacktickGenericDialect::new().identifier_quote_style("x"),
+        Some('`')
+    );
+    // Another char literal
+    assert_eq!(
+        AnotherBacktickDialect::new().identifier_quote_style("x"),
+        Some('[')
+    );
+    // None (overrides PostgreSQL's default '"')
+    assert_eq!(
+        QuotelessPostgreSqlDialect::new().identifier_quote_style("x"),
+        None
+    );
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to