This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-sqlparser-rs.git
The following commit(s) were added to refs/heads/main by this push:
new 2ac82e94 Streamlined derivation of new `Dialect` objects (#2174)
2ac82e94 is described below
commit 2ac82e946e5f6513b51b747f1783c1cf5f4a733d
Author: Alexander Beedie <[email protected]>
AuthorDate: Tue Feb 3 20:09:11 2026 +0700
Streamlined derivation of new `Dialect` objects (#2174)
---
Cargo.toml | 7 +-
derive/Cargo.toml | 2 +-
derive/src/dialect.rs | 305 ++++++++++++++++++++++++++++++++++++++
derive/src/lib.rs | 276 +++-------------------------------
derive/src/{lib.rs => visit.rs} | 47 ++----
src/dialect/ansi.rs | 2 +-
src/dialect/clickhouse.rs | 2 +-
src/dialect/hive.rs | 2 +-
src/dialect/mod.rs | 108 +++++++++++++-
src/dialect/mssql.rs | 2 +-
src/dialect/mysql.rs | 2 +-
src/dialect/oracle.rs | 2 +-
src/dialect/postgresql.rs | 2 +-
src/dialect/redshift.rs | 2 +-
src/dialect/sqlite.rs | 2 +-
src/lib.rs | 3 +
tests/sqlparser_derive_dialect.rs | 123 +++++++++++++++
17 files changed, 585 insertions(+), 304 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index 177ab3db..8945adef 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -42,6 +42,7 @@ std = []
recursive-protection = ["std", "recursive"]
# Enable JSON output in the `cli` example:
json_example = ["serde_json", "serde"]
+derive-dialect = ["sqlparser_derive"]
visitor = ["sqlparser_derive"]
[dependencies]
@@ -61,6 +62,10 @@ simple_logger = "5.0"
matches = "0.1"
pretty_assertions = "1"
+[[test]]
+name = "sqlparser_derive_dialect"
+required-features = ["derive-dialect"]
+
[package.metadata.docs.rs]
# Document these features on docs.rs
-features = ["serde", "visitor"]
+features = ["serde", "visitor", "derive-dialect"]
diff --git a/derive/Cargo.toml b/derive/Cargo.toml
index 54947704..f2f54926 100644
--- a/derive/Cargo.toml
+++ b/derive/Cargo.toml
@@ -36,6 +36,6 @@ edition = "2021"
proc-macro = true
[dependencies]
-syn = { version = "2.0", default-features = false, features = ["printing", "parsing", "derive", "proc-macro"] }
+syn = { version = "2.0", default-features = false, features = ["full", "printing", "parsing", "derive", "proc-macro", "clone-impls"] }
proc-macro2 = "1.0"
quote = "1.0"
diff --git a/derive/src/dialect.rs b/derive/src/dialect.rs
new file mode 100644
index 00000000..9873e4f7
--- /dev/null
+++ b/derive/src/dialect.rs
@@ -0,0 +1,305 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Implementation of the `derive_dialect!` macro for creating custom SQL dialects.
+
+use proc_macro2::TokenStream;
+use quote::{quote, quote_spanned};
+use std::collections::HashSet;
+use syn::{
+ braced,
+ parse::{Parse, ParseStream},
+ Error, File, FnArg, Ident, Item, LitBool, LitChar, Pat, ReturnType,
Signature, Token,
+ TraitItem, Type,
+};
+
+/// A replacement value for one `Dialect` method, as written inside `overrides = { ... }`.
+pub(crate) enum Override {
+    /// `true` / `false`: becomes the body of a `fn name(&self) -> bool` method.
+    Bool(LitBool),
+    /// A char literal: `identifier_quote_style` returns `Some(<char>)`.
+    Char(LitChar),
+    /// The identifier `None`: `identifier_quote_style` returns `None`.
+    None,
+}
+
+/// Arguments accepted by `derive_dialect!`, as produced by its `Parse` implementation.
+pub(crate) struct DeriveDialectInput {
+    /// Name of the struct to generate.
+    pub name: Ident,
+    /// Dialect type the generated struct wraps and forwards to.
+    pub base: Type,
+    /// When `true`, `Dialect::dialect()` is forwarded to the base dialect instead of
+    /// reporting the generated struct's own `TypeId`.
+    pub preserve_type_id: bool,
+    /// `(method name, value)` pairs in the order they were written.
+    pub overrides: Vec<(Ident, Override)>,
+}
+
+/// One method of the `Dialect` trait, as read from `dialect/mod.rs`.
+struct DialectMethod {
+    /// The method's name (identical to `signature.ident`).
+    name: Ident,
+    /// The method's full signature, used to emit the generated method.
+    signature: Signature,
+}
+
+impl Parse for DeriveDialectInput {
+    /// Parse `Name, Base[, preserve_type_id = <bool>][, overrides = { key = value, ... }][,]`.
+    ///
+    /// The named arguments may appear in any order, each at most once. Override entries must
+    /// be separated by commas (a trailing comma is allowed), and a method may be overridden
+    /// only once.
+    fn parse(input: ParseStream) -> syn::Result<Self> {
+        let name: Ident = input.parse()?;
+        input.parse::<Token![,]>()?;
+        let base: Type = input.parse()?;
+
+        let mut preserve_type_id: Option<bool> = None;
+        let mut overrides: Option<Vec<(Ident, Override)>> = None;
+
+        while input.peek(Token![,]) {
+            input.parse::<Token![,]>()?;
+            // Allow a trailing comma after the last argument.
+            if input.is_empty() {
+                break;
+            }
+            let ident: Ident = input.parse()?;
+            match ident.to_string().as_str() {
+                "preserve_type_id" => {
+                    if preserve_type_id.is_some() {
+                        return Err(duplicate_argument(&ident));
+                    }
+                    input.parse::<Token![=]>()?;
+                    preserve_type_id = Some(input.parse::<LitBool>()?.value());
+                }
+                "overrides" => {
+                    if overrides.is_some() {
+                        return Err(duplicate_argument(&ident));
+                    }
+                    input.parse::<Token![=]>()?;
+                    let content;
+                    braced!(content in input);
+                    overrides = Some(parse_override_entries(&content)?);
+                }
+                other => {
+                    return Err(Error::new(
+                        ident.span(),
+                        format!(
+                            "Unknown argument `{other}`. Expected `preserve_type_id` or `overrides`."
+                        ),
+                    ));
+                }
+            }
+        }
+
+        Ok(DeriveDialectInput {
+            name,
+            base,
+            preserve_type_id: preserve_type_id.unwrap_or(false),
+            overrides: overrides.unwrap_or_default(),
+        })
+    }
+}
+
+/// Build the error reported when a named `derive_dialect!` argument is given twice.
+fn duplicate_argument(ident: &Ident) -> Error {
+    Error::new(ident.span(), format!("Duplicate argument `{ident}`"))
+}
+
+/// Parse the comma-separated `key = value` entries inside `overrides = { ... }`.
+///
+/// Rejects a repeated key (it would otherwise be silently ignored, because only the first
+/// matching entry is used during code generation) and entries not separated by commas.
+fn parse_override_entries(content: ParseStream) -> syn::Result<Vec<(Ident, Override)>> {
+    let mut entries = Vec::new();
+    let mut seen: HashSet<String> = HashSet::new();
+    while !content.is_empty() {
+        let key: Ident = content.parse()?;
+        content.parse::<Token![=]>()?;
+        let value = parse_override_value(content)?;
+        if !seen.insert(key.to_string()) {
+            return Err(Error::new(
+                key.span(),
+                format!("Duplicate override for `{key}`"),
+            ));
+        }
+        entries.push((key, value));
+        // Entries must be comma-separated; the comma after the last entry is optional.
+        if content.is_empty() {
+            break;
+        }
+        content.parse::<Token![,]>()?;
+    }
+    Ok(entries)
+}
+
+/// Parse one override value: `true`, `false`, a char literal, or `None`.
+fn parse_override_value(content: ParseStream) -> syn::Result<Override> {
+    if content.peek(LitBool) {
+        return Ok(Override::Bool(content.parse()?));
+    }
+    if content.peek(LitChar) {
+        return Ok(Override::Char(content.parse()?));
+    }
+    if content.peek(Ident) {
+        let ident: Ident = content.parse()?;
+        return if ident == "None" {
+            Ok(Override::None)
+        } else {
+            Err(Error::new(
+                ident.span(),
+                format!("Expected `true`, `false`, a char, or `None`, found `{ident}`"),
+            ))
+        };
+    }
+    Err(content.error("Expected `true`, `false`, a char, or `None`"))
+}
+
+/// Entry point for the `derive_dialect!` macro.
+///
+/// Loads and parses the `Dialect` trait definition, checks the requested overrides against
+/// it, and expands to the generated dialect. Every failure is emitted as a compile error
+/// rather than a panic.
+pub(crate) fn derive_dialect(input: DeriveDialectInput) -> proc_macro::TokenStream {
+    match expand_derive_dialect(&input) {
+        Ok(tokens) => tokens.into(),
+        Err(error) => error.to_compile_error().into(),
+    }
+}
+
+/// Fallible body of `derive_dialect`: load the trait, validate `input`, generate code.
+fn expand_derive_dialect(input: &DeriveDialectInput) -> syn::Result<TokenStream> {
+    let at_call_site = |msg: String| Error::new(proc_macro2::Span::call_site(), msg);
+
+    let source = read_dialect_mod_file()
+        .map_err(|e| at_call_site(format!("Failed to read dialect/mod.rs: {e}")))?;
+    let file: File = syn::parse_str(&source)
+        .map_err(|e| at_call_site(format!("Failed to parse source: {e}")))?;
+    let methods = extract_dialect_methods(&file)?;
+
+    // A boolean override must name a `fn(&self) -> bool` trait method; char / `None`
+    // overrides only make sense for `identifier_quote_style`.
+    let bool_names: HashSet<String> = methods
+        .iter()
+        .filter(|m| is_bool_method(&m.signature))
+        .map(|m| m.name.to_string())
+        .collect();
+    for (key, value) in &input.overrides {
+        let key_str = key.to_string();
+        let problem = match value {
+            Override::Bool(_) if !bool_names.contains(&key_str) => {
+                Some(format!("Unknown boolean method `{key_str}`"))
+            }
+            Override::Char(_) | Override::None if key_str != "identifier_quote_style" => {
+                Some(format!(
+                    "Char/None only valid for `identifier_quote_style`, not `{key_str}`"
+                ))
+            }
+            _ => None,
+        };
+        if let Some(msg) = problem {
+            return Err(Error::new(key.span(), msg));
+        }
+    }
+
+    Ok(generate_derived_dialect(input, &methods))
+}
+
+/// Generate the complete derived `Dialect` implementation.
+///
+/// The output contains:
+/// - a `pub struct #name` wrapping a `#base` value, deriving `Debug` and `Default`, with a
+///   `new()` constructor;
+/// - an `impl Dialect for #name` inside a `const _: () = { ... };` block, so the imports
+///   needed by the trait's signatures stay scoped to the impl.
+///
+/// Each trait method either returns its override value or forwards to the wrapped base
+/// dialect. `dialect()` reports `#name`'s own `TypeId` unless `preserve_type_id` is set,
+/// in which case it forwards to the base dialect.
+fn generate_derived_dialect(input: &DeriveDialectInput, methods: &[DialectMethod]) -> TokenStream {
+    let name = &input.name;
+    let base = &input.base;
+
+    // Helper to find an override by method name
+    let find_override = |method_name: &str| {
+        input
+            .overrides
+            .iter()
+            .find(|(k, _)| k == method_name)
+            .map(|(_, v)| v)
+    };
+
+    // Helper to generate delegation to the base dialect. Parameters are first re-bound to
+    // plain identifiers so that every argument can be forwarded verbatim, whatever pattern
+    // the trait declaration used.
+    let delegate = |method: &DialectMethod| {
+        let sig = rebind_params(&method.signature);
+        let method_name = &method.name;
+        let params = extract_param_names(&sig);
+        quote_spanned! { method_name.span() => #sig { self.dialect.#method_name(#(#params),*) } }
+    };
+
+    // Generate the struct
+    let struct_def = quote_spanned! { name.span() =>
+        #[derive(Debug, Default)]
+        pub struct #name {
+            dialect: #base,
+        }
+        impl #name {
+            pub fn new() -> Self { Self::default() }
+        }
+    };
+
+    // Generate TypeId method body
+    let type_id_body = if input.preserve_type_id {
+        quote! { Dialect::dialect(&self.dialect) }
+    } else {
+        quote! { ::core::any::TypeId::of::<#name>() }
+    };
+
+    // Generate method implementations
+    let method_impls = methods.iter().map(|method| {
+        let method_name = &method.name;
+        match find_override(&method_name.to_string()) {
+            Some(Override::Bool(value)) => {
+                quote_spanned! { method_name.span() => fn #method_name(&self) -> bool { #value } }
+            }
+            Some(Override::Char(c)) => {
+                quote_spanned! { method_name.span() =>
+                    fn identifier_quote_style(&self, _: &str) -> Option<char> { Some(#c) }
+                }
+            }
+            Some(Override::None) => {
+                quote_spanned! { method_name.span() =>
+                    fn identifier_quote_style(&self, _: &str) -> Option<char> { None }
+                }
+            }
+            None => delegate(method),
+        }
+    });
+
+    // Wrap impl in a const block with scoped imports so types resolve without qualification
+    quote! {
+        #struct_def
+        const _: () = {
+            use ::core::iter::Peekable;
+            use ::core::str::Chars;
+            use sqlparser::ast::{ColumnOption, Expr, GranteesType, Ident, ObjectNamePart, Statement};
+            use sqlparser::dialect::{Dialect, Precedence};
+            use sqlparser::keywords::Keyword;
+            use sqlparser::parser::{Parser, ParserError};
+
+            impl Dialect for #name {
+                fn dialect(&self) -> ::core::any::TypeId { #type_id_body }
+                #(#method_impls)*
+            }
+        };
+    }
+}
+
+/// Return a copy of `sig` whose typed (non-receiver) parameters are bound to the fresh
+/// identifiers `__arg0`, `__arg1`, ... (no `mut`, no `ref`, no sub-pattern).
+///
+/// Trait declarations may bind parameters with `mut`, `_`, or destructuring patterns.
+/// `extract_param_names` only reports plain identifier bindings, so without this
+/// normalisation a wildcard or destructured parameter would be missing from the forwarded
+/// call and the generated code would fail to compile.
+fn rebind_params(sig: &Signature) -> Signature {
+    let mut sig = sig.clone();
+    let typed_params = sig.inputs.iter_mut().filter_map(|arg| match arg {
+        FnArg::Typed(pat_type) => Some(pat_type),
+        FnArg::Receiver(_) => None,
+    });
+    for (index, pat_type) in typed_params.enumerate() {
+        *pat_type.pat = Pat::Ident(syn::PatIdent {
+            attrs: Vec::new(),
+            by_ref: None,
+            mutability: None,
+            ident: quote::format_ident!("__arg{}", index),
+            subpat: None,
+        });
+    }
+    sig
+}
+
+/// Collect the identifiers bound by the non-receiver parameters of `sig`, in order.
+///
+/// The `self` receiver is skipped, as is any parameter whose pattern is not a plain
+/// identifier binding (e.g. `_` or a destructuring pattern).
+fn extract_param_names(sig: &Signature) -> Vec<&Ident> {
+    let mut names = Vec::new();
+    for input in &sig.inputs {
+        if let FnArg::Typed(typed) = input {
+            if let Pat::Ident(binding) = typed.pat.as_ref() {
+                names.push(&binding.ident);
+            }
+        }
+    }
+    names
+}
+
+/// Locate and read the `src/dialect/mod.rs` file that defines the `Dialect` trait.
+///
+/// `CARGO_MANIFEST_DIR` read at *expansion* time names the crate that invokes the macro,
+/// which is `sqlparser` itself only for its own tests and doctests. To also support crates
+/// depending on a source checkout of `sqlparser` (path or git dependency), the sources next
+/// to this derive crate (directory captured when the derive crate was built) are tried
+/// first. The invoking crate's directory is then tried as before.
+///
+/// The first candidate that can be read and contains a `trait Dialect` definition is used.
+///
+/// # Errors
+/// Returns a message describing every attempted path if no candidate is usable, or
+/// `"CARGO_MANIFEST_DIR not set"` if no candidate path could be formed at all.
+fn read_dialect_mod_file() -> Result<String, String> {
+    let mut candidates = Vec::new();
+    // Repository layout: `<root>/derive` (this crate) next to `<root>/src/dialect/mod.rs`.
+    if let Some(derive_dir) = option_env!("CARGO_MANIFEST_DIR") {
+        candidates.push(std::path::Path::new(derive_dir).join("../src/dialect/mod.rs"));
+    }
+    // The crate invoking the macro (the original lookup location).
+    if let Ok(caller_dir) = std::env::var("CARGO_MANIFEST_DIR") {
+        candidates.push(std::path::Path::new(&caller_dir).join("src/dialect/mod.rs"));
+    }
+    if candidates.is_empty() {
+        return Err("CARGO_MANIFEST_DIR not set".to_string());
+    }
+
+    let mut failures = Vec::new();
+    for path in &candidates {
+        match std::fs::read_to_string(path) {
+            // Guard against picking up an unrelated `src/dialect/mod.rs`.
+            Ok(source) if source.contains("trait Dialect") => return Ok(source),
+            Ok(_) => failures.push(format!(
+                "{} does not define the `Dialect` trait",
+                path.display()
+            )),
+            Err(e) => failures.push(format!("Failed to read {}: {e}", path.display())),
+        }
+    }
+    Err(failures.join("; "))
+}
+
+/// Extract all methods from the `Dialect` trait (excluding `dialect` for
TypeId)
+fn extract_dialect_methods(file: &File) -> Result<Vec<DialectMethod>, Error> {
+ let dialect_trait = file
+ .items
+ .iter()
+ .find_map(|item| match item {
+ Item::Trait(t) if t.ident == "Dialect" => Some(t),
+ _ => None,
+ })
+ .ok_or_else(|| Error::new(proc_macro2::Span::call_site(), "Dialect
trait not found"))?;
+
+ let mut methods: Vec<_> = dialect_trait
+ .items
+ .iter()
+ .filter_map(|item| match item {
+ TraitItem::Fn(m) if m.sig.ident != "dialect" => Some(DialectMethod
{
+ name: m.sig.ident.clone(),
+ signature: m.sig.clone(),
+ }),
+ _ => None,
+ })
+ .collect();
+ methods.sort_by_key(|m| m.name.to_string());
+ Ok(methods)
+}
+
+/// Return `true` when `sig` has exactly the shape `fn name(&self) -> bool`.
+///
+/// The only input must be a shared `&self` receiver (not `&mut self` or `self`), and the
+/// return type must be the bare path `bool`.
+fn is_bool_method(sig: &Signature) -> bool {
+    if sig.inputs.len() != 1 {
+        return false;
+    }
+    let takes_shared_self = match sig.inputs.first() {
+        Some(FnArg::Receiver(receiver)) => {
+            receiver.reference.is_some() && receiver.mutability.is_none()
+        }
+        _ => false,
+    };
+    if !takes_shared_self {
+        return false;
+    }
+    match &sig.output {
+        ReturnType::Type(_, ty) => {
+            matches!(ty.as_ref(), Type::Path(p) if p.path.is_ident("bool"))
+        }
+        ReturnType::Default => false,
+    }
+}
diff --git a/derive/src/lib.rs b/derive/src/lib.rs
index 08c5c5db..e3eaeea6 100644
--- a/derive/src/lib.rs
+++ b/derive/src/lib.rs
@@ -15,22 +15,25 @@
// specific language governing permissions and limitations
// under the License.
-use proc_macro2::TokenStream;
-use quote::{format_ident, quote, quote_spanned, ToTokens};
-use syn::spanned::Spanned;
-use syn::{
- parse::{Parse, ParseStream},
- parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields,
GenericParam, Generics,
- Ident, Index, LitStr, Meta, Token, Type, TypePath,
-};
-use syn::{Path, PathArguments};
+//! Procedural macros for sqlparser.
+//!
+//! This crate provides:
+//! - [`Visit`] and [`VisitMut`] derive macros for AST traversal.
+//! - [`derive_dialect!`] macro for creating custom SQL dialects.
-/// Implementation of `[#derive(Visit)]`
+use quote::quote;
+use syn::parse_macro_input;
+
+mod dialect;
+mod visit;
+
+/// Implementation of `#[derive(VisitMut)]`
#[proc_macro_derive(VisitMut, attributes(visit))]
pub fn derive_visit_mut(input: proc_macro::TokenStream) ->
proc_macro::TokenStream {
- derive_visit(
+ let input = parse_macro_input!(input as syn::DeriveInput);
+ visit::derive_visit(
input,
- &VisitType {
+ &visit::VisitType {
visit_trait: quote!(VisitMut),
visitor_trait: quote!(VisitorMut),
modifier: Some(quote!(mut)),
@@ -38,12 +41,13 @@ pub fn derive_visit_mut(input: proc_macro::TokenStream) ->
proc_macro::TokenStre
)
}
-/// Implementation of `[#derive(Visit)]`
+/// Implementation of `#[derive(Visit)]`
#[proc_macro_derive(Visit, attributes(visit))]
pub fn derive_visit_immutable(input: proc_macro::TokenStream) ->
proc_macro::TokenStream {
- derive_visit(
+ let input = parse_macro_input!(input as syn::DeriveInput);
+ visit::derive_visit(
input,
- &VisitType {
+ &visit::VisitType {
visit_trait: quote!(Visit),
visitor_trait: quote!(Visitor),
modifier: None,
@@ -51,241 +55,9 @@ pub fn derive_visit_immutable(input:
proc_macro::TokenStream) -> proc_macro::Tok
)
}
-struct VisitType {
- visit_trait: TokenStream,
- visitor_trait: TokenStream,
- modifier: Option<TokenStream>,
-}
-
-fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) ->
proc_macro::TokenStream {
- // Parse the input tokens into a syntax tree.
- let input = parse_macro_input!(input as DeriveInput);
- let name = input.ident;
-
- let VisitType {
- visit_trait,
- visitor_trait,
- modifier,
- } = visit_type;
-
- let attributes = Attributes::parse(&input.attrs);
- // Add a bound `T: Visit` to every type parameter T.
- let generics = add_trait_bounds(input.generics, visit_type);
- let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
-
- let (pre_visit, post_visit) = attributes.visit(quote!(self));
- let children = visit_children(&input.data, visit_type);
-
- let expanded = quote! {
- // The generated impl.
- // Note that it uses [`recursive::recursive`] to protect from stack
overflow.
- // See tests in
https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info.
- impl #impl_generics sqlparser::ast::#visit_trait for #name
#ty_generics #where_clause {
- #[cfg_attr(feature = "recursive-protection",
recursive::recursive)]
- fn visit<V: sqlparser::ast::#visitor_trait>(
- &#modifier self,
- visitor: &mut V
- ) -> ::std::ops::ControlFlow<V::Break> {
- #pre_visit
- #children
- #post_visit
- ::std::ops::ControlFlow::Continue(())
- }
- }
- };
-
- proc_macro::TokenStream::from(expanded)
-}
-
-/// Parses attributes that can be provided to this macro
-///
-/// `#[visit(leaf, with = "visit_expr")]`
-#[derive(Default)]
-struct Attributes {
- /// Content for the `with` attribute
- with: Option<Ident>,
-}
-
-struct WithIdent {
- with: Option<Ident>,
-}
-impl Parse for WithIdent {
- fn parse(input: ParseStream) -> Result<Self, syn::Error> {
- let mut result = WithIdent { with: None };
- let ident = input.parse::<Ident>()?;
- if ident != "with" {
- return Err(syn::Error::new(
- ident.span(),
- "Expected identifier to be `with`",
- ));
- }
- input.parse::<Token!(=)>()?;
- let s = input.parse::<LitStr>()?;
- result.with = Some(format_ident!("{}", s.value(), span = s.span()));
- Ok(result)
- }
-}
-
-impl Attributes {
- fn parse(attrs: &[Attribute]) -> Self {
- let mut out = Self::default();
- for attr in attrs {
- if let Meta::List(ref metalist) = attr.meta {
- if metalist.path.is_ident("visit") {
- match syn::parse2::<WithIdent>(metalist.tokens.clone()) {
- Ok(with_ident) => {
- out.with = with_ident.with;
- }
- Err(e) => {
- panic!("{}", e);
- }
- }
- }
- }
- }
- out
- }
-
- /// Returns the pre and post visit token streams
- fn visit(&self, s: TokenStream) -> (Option<TokenStream>,
Option<TokenStream>) {
- let pre_visit = self.with.as_ref().map(|m| {
- let m = format_ident!("pre_{}", m);
- quote!(visitor.#m(#s)?;)
- });
- let post_visit = self.with.as_ref().map(|m| {
- let m = format_ident!("post_{}", m);
- quote!(visitor.#m(#s)?;)
- });
- (pre_visit, post_visit)
- }
-}
-
-// Add a bound `T: Visit` to every type parameter T.
-fn add_trait_bounds(mut generics: Generics, VisitType { visit_trait, .. }:
&VisitType) -> Generics {
- for param in &mut generics.params {
- if let GenericParam::Type(ref mut type_param) = *param {
- type_param
- .bounds
- .push(parse_quote!(sqlparser::ast::#visit_trait));
- }
- }
- generics
-}
-
-// Generate the body of the visit implementation for the given type
-fn visit_children(
- data: &Data,
- VisitType {
- visit_trait,
- modifier,
- ..
- }: &VisitType,
-) -> TokenStream {
- match data {
- Data::Struct(data) => match &data.fields {
- Fields::Named(fields) => {
- let recurse = fields.named.iter().map(|f| {
- let name = &f.ident;
- let is_option = is_option(&f.ty);
- let attributes = Attributes::parse(&f.attrs);
- if is_option && attributes.with.is_some() {
- let (pre_visit, post_visit) =
attributes.visit(quote!(value));
- quote_spanned!(f.span() =>
- if let Some(value) = &#modifier self.#name {
- #pre_visit
sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit
- }
- )
- } else {
- let (pre_visit, post_visit) =
attributes.visit(quote!(&#modifier self.#name));
- quote_spanned!(f.span() =>
- #pre_visit
sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?;
#post_visit
- )
- }
- });
- quote! {
- #(#recurse)*
- }
- }
- Fields::Unnamed(fields) => {
- let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
- let index = Index::from(i);
- let attributes = Attributes::parse(&f.attrs);
- let (pre_visit, post_visit) =
attributes.visit(quote!(&self.#index));
- quote_spanned!(f.span() => #pre_visit
sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?;
#post_visit)
- });
- quote! {
- #(#recurse)*
- }
- }
- Fields::Unit => {
- quote!()
- }
- },
- Data::Enum(data) => {
- let statements = data.variants.iter().map(|v| {
- let name = &v.ident;
- match &v.fields {
- Fields::Named(fields) => {
- let names = fields.named.iter().map(|f| &f.ident);
- let visit = fields.named.iter().map(|f| {
- let name = &f.ident;
- let attributes = Attributes::parse(&f.attrs);
- let (pre_visit, post_visit) =
attributes.visit(name.to_token_stream());
- quote_spanned!(f.span() => #pre_visit
sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
- });
-
- quote!(
- Self::#name { #(#names),* } => {
- #(#visit)*
- }
- )
- }
- Fields::Unnamed(fields) => {
- let names = fields.unnamed.iter().enumerate().map(|(i,
f)| format_ident!("_{}", i, span = f.span()));
- let visit = fields.unnamed.iter().enumerate().map(|(i,
f)| {
- let name = format_ident!("_{}", i);
- let attributes = Attributes::parse(&f.attrs);
- let (pre_visit, post_visit) =
attributes.visit(name.to_token_stream());
- quote_spanned!(f.span() => #pre_visit
sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
- });
-
- quote! {
- Self::#name ( #(#names),*) => {
- #(#visit)*
- }
- }
- }
- Fields::Unit => {
- quote! {
- Self::#name => {}
- }
- }
- }
- });
-
- quote! {
- match self {
- #(#statements),*
- }
- }
- }
- Data::Union(_) => unimplemented!(),
- }
-}
-
-fn is_option(ty: &Type) -> bool {
- if let Type::Path(TypePath {
- path: Path { segments, .. },
- ..
- }) = ty
- {
- if let Some(segment) = segments.last() {
- if segment.ident == "Option" {
- if let PathArguments::AngleBracketed(args) =
&segment.arguments {
- return args.args.len() == 1;
- }
- }
- }
- }
- false
+/// Function-like procedural macro that generates a new SQL dialect type.
+///
+/// `derive_dialect!(Name, Base, ...)` defines a struct `Name` that wraps `Base` and
+/// implements `Dialect` by forwarding to it, apart from any requested overrides. See the
+/// `sqlparser::derive_dialect` re-export for the accepted syntax and examples.
+#[proc_macro]
+pub fn derive_dialect(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    dialect::derive_dialect(parse_macro_input!(input as dialect::DeriveDialectInput))
}
diff --git a/derive/src/lib.rs b/derive/src/visit.rs
similarity index 88%
copy from derive/src/lib.rs
copy to derive/src/visit.rs
index 08c5c5db..baf3eb58 100644
--- a/derive/src/lib.rs
+++ b/derive/src/visit.rs
@@ -15,51 +15,28 @@
// specific language governing permissions and limitations
// under the License.
+//! Implementation of the `Visit` and `VisitMut` derive macros.
+
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{
parse::{Parse, ParseStream},
- parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields,
GenericParam, Generics,
- Ident, Index, LitStr, Meta, Token, Type, TypePath,
+ parse_quote, Attribute, Data, Fields, GenericParam, Generics, Ident,
Index, LitStr, Meta,
+ Token, Type, TypePath,
};
use syn::{Path, PathArguments};
-/// Implementation of `[#derive(Visit)]`
-#[proc_macro_derive(VisitMut, attributes(visit))]
-pub fn derive_visit_mut(input: proc_macro::TokenStream) ->
proc_macro::TokenStream {
- derive_visit(
- input,
- &VisitType {
- visit_trait: quote!(VisitMut),
- visitor_trait: quote!(VisitorMut),
- modifier: Some(quote!(mut)),
- },
- )
-}
-
-/// Implementation of `[#derive(Visit)]`
-#[proc_macro_derive(Visit, attributes(visit))]
-pub fn derive_visit_immutable(input: proc_macro::TokenStream) ->
proc_macro::TokenStream {
- derive_visit(
- input,
- &VisitType {
- visit_trait: quote!(Visit),
- visitor_trait: quote!(Visitor),
- modifier: None,
- },
- )
-}
-
-struct VisitType {
- visit_trait: TokenStream,
- visitor_trait: TokenStream,
- modifier: Option<TokenStream>,
+pub(crate) struct VisitType {
+ pub visit_trait: TokenStream,
+ pub visitor_trait: TokenStream,
+ pub modifier: Option<TokenStream>,
}
-fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) ->
proc_macro::TokenStream {
- // Parse the input tokens into a syntax tree.
- let input = parse_macro_input!(input as DeriveInput);
+pub(crate) fn derive_visit(
+ input: syn::DeriveInput,
+ visit_type: &VisitType,
+) -> proc_macro::TokenStream {
let name = input.ident;
let VisitType {
diff --git a/src/dialect/ansi.rs b/src/dialect/ansi.rs
index ec3c095b..5a54390c 100644
--- a/src/dialect/ansi.rs
+++ b/src/dialect/ansi.rs
@@ -18,7 +18,7 @@
use crate::dialect::Dialect;
/// A [`Dialect`] for [ANSI SQL](https://en.wikipedia.org/wiki/SQL:2011).
-#[derive(Debug)]
+#[derive(Debug, Default)]
pub struct AnsiDialect {}
impl Dialect for AnsiDialect {
diff --git a/src/dialect/clickhouse.rs b/src/dialect/clickhouse.rs
index 041b94ec..f8b6807f 100644
--- a/src/dialect/clickhouse.rs
+++ b/src/dialect/clickhouse.rs
@@ -18,7 +18,7 @@
use crate::dialect::Dialect;
/// A [`Dialect`] for [ClickHouse](https://clickhouse.com/).
-#[derive(Debug)]
+#[derive(Debug, Default)]
pub struct ClickHouseDialect {}
impl Dialect for ClickHouseDialect {
diff --git a/src/dialect/hive.rs b/src/dialect/hive.rs
index 3e15d395..32a982e9 100644
--- a/src/dialect/hive.rs
+++ b/src/dialect/hive.rs
@@ -18,7 +18,7 @@
use crate::dialect::Dialect;
/// A [`Dialect`] for [Hive](https://hive.apache.org/).
-#[derive(Debug)]
+#[derive(Debug, Default)]
pub struct HiveDialect {}
impl Dialect for HiveDialect {
diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs
index ef563fc1..477d60f8 100644
--- a/src/dialect/mod.rs
+++ b/src/dialect/mod.rs
@@ -51,6 +51,82 @@ pub use self::postgresql::PostgreSqlDialect;
pub use self::redshift::RedshiftSqlDialect;
pub use self::snowflake::SnowflakeDialect;
pub use self::sqlite::SQLiteDialect;
+
+/// Macro for streamlining the creation of derived `Dialect` objects.
+///
+/// `derive_dialect!` defines a new `pub` struct that wraps a value of the base dialect and
+/// implements [`Dialect`] by forwarding every method to it, except for the methods listed
+/// in `overrides`. The generated struct includes `new()` and `default()` constructors.
+/// Requires the `derive-dialect` feature.
+///
+/// The base dialect must implement [`Default`]: the generated struct derives `Default`
+/// and stores the base dialect in a field.
+///
+/// # Syntax
+///
+/// ```text
+/// derive_dialect!(NewDialect, BaseDialect);
+/// derive_dialect!(NewDialect, BaseDialect, overrides = { method = value, ... });
+/// derive_dialect!(NewDialect, BaseDialect, preserve_type_id = true);
+/// derive_dialect!(NewDialect, BaseDialect, preserve_type_id = true, overrides = { ... });
+/// ```
+///
+/// Accepted override values are `true` / `false` for boolean methods (methods of the form
+/// `fn name(&self) -> bool`), and a `char` literal or `None` for `identifier_quote_style`
+/// (a `char` is returned wrapped in `Some`).
+///
+/// # Example
+///
+/// ```
+/// use sqlparser::derive_dialect;
+/// use sqlparser::dialect::{Dialect, GenericDialect};
+///
+/// // Override boolean methods (supports_*, allow_*, etc.)
+/// derive_dialect!(CustomDialect, GenericDialect, overrides = {
+///     supports_order_by_all = true,
+///     supports_nested_comments = true,
+/// });
+///
+/// let dialect = CustomDialect::new();
+/// assert!(dialect.supports_order_by_all());
+/// assert!(dialect.supports_nested_comments());
+/// ```
+///
+/// # Overriding `identifier_quote_style`
+///
+/// Use a char literal or `None`:
+/// ```
+/// use sqlparser::derive_dialect;
+/// use sqlparser::dialect::{Dialect, PostgreSqlDialect};
+///
+/// derive_dialect!(BacktickPostgreSqlDialect, PostgreSqlDialect,
+///     preserve_type_id = true,
+///     overrides = { identifier_quote_style = '`' }
+/// );
+/// let d: &dyn Dialect = &BacktickPostgreSqlDialect::new();
+/// assert_eq!(d.identifier_quote_style("foo"), Some('`'));
+///
+/// derive_dialect!(QuotelessPostgreSqlDialect, PostgreSqlDialect,
+///     preserve_type_id = true,
+///     overrides = { identifier_quote_style = None }
+/// );
+/// let d: &dyn Dialect = &QuotelessPostgreSqlDialect::new();
+/// assert_eq!(d.identifier_quote_style("foo"), None);
+/// ```
+///
+/// # Type Identity
+///
+/// By default, derived dialects have their own `TypeId`. Set `preserve_type_id = true`
+/// to retain the base dialect's identity with respect to the parser's
+/// `dialect.is::<T>()` checks:
+/// ```
+/// use sqlparser::derive_dialect;
+/// use sqlparser::dialect::{Dialect, GenericDialect};
+///
+/// derive_dialect!(EnhancedGenericDialect, GenericDialect,
+///     preserve_type_id = true,
+///     overrides = {
+///         supports_order_by_all = true,
+///         supports_nested_comments = true,
+///     }
+/// );
+/// let d: &dyn Dialect = &EnhancedGenericDialect::new();
+/// assert!(d.is::<GenericDialect>());  // still recognized as a GenericDialect
+/// assert!(d.supports_nested_comments());
+/// assert!(d.supports_order_by_all());
+/// ```
+#[cfg(feature = "derive-dialect")]
+pub use sqlparser_derive::derive_dialect;
+
use crate::ast::{ColumnOption, Expr, GranteesType, Ident, ObjectNamePart,
Statement};
pub use crate::keywords;
use crate::keywords::Keyword;
@@ -62,14 +138,14 @@ use alloc::boxed::Box;
/// Convenience check if a [`Parser`] uses a certain dialect.
///
-/// Note: when possible please the new style, adding a method to the
[`Dialect`]
-/// trait rather than using this macro.
+/// Note: when possible, please use the new style, adding a method to
+/// the [`Dialect`] trait rather than using this macro.
///
/// The benefits of adding a method on `Dialect` over this macro are:
/// 1. user defined [`Dialect`]s can customize the parsing behavior
/// 2. The differences between dialects can be clearly documented in the trait
///
-/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
+/// `dialect_of!(parser is SQLiteDialect | GenericDialect)` evaluates
/// to `true` if `parser.dialect` is one of the [`Dialect`]s specified.
macro_rules! dialect_of {
( $parsed_dialect: ident is $($dialect_type: ty)|+ ) => {
@@ -123,9 +199,8 @@ macro_rules! dialect_is {
pub trait Dialect: Debug + Any {
/// Determine the [`TypeId`] of this dialect.
///
- /// By default, return the same [`TypeId`] as [`Any::type_id`]. Can be
overridden
- /// by dialects that behave like other dialects
- /// (for example when wrapping a dialect).
+ /// By default, return the same [`TypeId`] as [`Any::type_id`]. Can be
overridden by
+ /// dialects that behave like other dialects (for example, when wrapping a
dialect).
fn dialect(&self) -> TypeId {
self.type_id()
}
@@ -1646,6 +1721,27 @@ mod tests {
dialect_from_str(v).unwrap()
}
+    #[test]
+    #[cfg(feature = "derive-dialect")]
+    fn test_dialect_override() {
+        // A GenericDialect variant that keeps GenericDialect's TypeId but turns on three
+        // optional features.
+        derive_dialect!(EnhancedGenericDialect, GenericDialect,
+            preserve_type_id = true,
+            overrides = {
+                supports_order_by_all = true,
+                supports_nested_comments = true,
+                supports_triple_quoted_string = true,
+            },
+        );
+        let enhanced = EnhancedGenericDialect::new();
+        let as_dyn: &dyn Dialect = &enhanced;
+
+        // The overrides take effect...
+        assert!(as_dyn.supports_order_by_all());
+        assert!(as_dyn.supports_nested_comments());
+        assert!(as_dyn.supports_triple_quoted_string());
+
+        // ...while dialect identity checks still see a GenericDialect.
+        assert!(as_dyn.is::<GenericDialect>());
+    }
+
#[test]
fn identifier_quote_style() {
let tests: Vec<(&dyn Dialect, &str, Option<char>)> = vec![
diff --git a/src/dialect/mssql.rs b/src/dialect/mssql.rs
index 9f8e7265..24f7c7c4 100644
--- a/src/dialect/mssql.rs
+++ b/src/dialect/mssql.rs
@@ -28,7 +28,7 @@ use crate::tokenizer::Token;
use alloc::{vec, vec::Vec};
/// A [`Dialect`] for [Microsoft SQL
Server](https://www.microsoft.com/en-us/sql-server/)
-#[derive(Debug)]
+#[derive(Debug, Default)]
pub struct MsSqlDialect {}
impl Dialect for MsSqlDialect {
diff --git a/src/dialect/mysql.rs b/src/dialect/mysql.rs
index b44001fe..ad3ba6f3 100644
--- a/src/dialect/mysql.rs
+++ b/src/dialect/mysql.rs
@@ -35,7 +35,7 @@ const RESERVED_FOR_TABLE_ALIAS_MYSQL: &[Keyword] = &[
];
/// A [`Dialect`] for [MySQL](https://www.mysql.com/)
-#[derive(Debug)]
+#[derive(Debug, Default)]
pub struct MySqlDialect {}
impl Dialect for MySqlDialect {
diff --git a/src/dialect/oracle.rs b/src/dialect/oracle.rs
index 7ff93262..a72d5d7a 100644
--- a/src/dialect/oracle.rs
+++ b/src/dialect/oracle.rs
@@ -25,7 +25,7 @@ use crate::{
use super::{Dialect, Precedence};
/// A [`Dialect`] for [Oracle
Databases](https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/index.html)
-#[derive(Debug)]
+#[derive(Debug, Default)]
pub struct OracleDialect;
impl Dialect for OracleDialect {
diff --git a/src/dialect/postgresql.rs b/src/dialect/postgresql.rs
index 7c9e7db8..1924a5e3 100644
--- a/src/dialect/postgresql.rs
+++ b/src/dialect/postgresql.rs
@@ -34,7 +34,7 @@ use crate::parser::{Parser, ParserError};
use crate::tokenizer::Token;
/// A [`Dialect`] for [PostgreSQL](https://www.postgresql.org/)
-#[derive(Debug)]
+#[derive(Debug, Default)]
pub struct PostgreSqlDialect {}
const PERIOD_PREC: u8 = 200;
diff --git a/src/dialect/redshift.rs b/src/dialect/redshift.rs
index c028061d..7b35848b 100644
--- a/src/dialect/redshift.rs
+++ b/src/dialect/redshift.rs
@@ -22,7 +22,7 @@ use core::str::Chars;
use super::PostgreSqlDialect;
/// A [`Dialect`] for [RedShift](https://aws.amazon.com/redshift/)
-#[derive(Debug)]
+#[derive(Debug, Default)]
pub struct RedshiftSqlDialect {}
// In most cases the redshift dialect is identical to [`PostgresSqlDialect`].
diff --git a/src/dialect/sqlite.rs b/src/dialect/sqlite.rs
index ba4cb617..7d1c935f 100644
--- a/src/dialect/sqlite.rs
+++ b/src/dialect/sqlite.rs
@@ -30,7 +30,7 @@ use crate::parser::{Parser, ParserError};
/// [`CREATE TABLE`](https://sqlite.org/lang_createtable.html) statement with
no
/// type specified, as in `CREATE TABLE t1 (a)`. In the AST, these columns will
/// have the data type [`Unspecified`](crate::ast::DataType::Unspecified).
-#[derive(Debug)]
+#[derive(Debug, Default)]
pub struct SQLiteDialect {}
impl Dialect for SQLiteDialect {
diff --git a/src/lib.rs b/src/lib.rs
index f5d23a21..e68d7f93 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -170,6 +170,9 @@ pub mod ast;
#[macro_use]
/// Submodules for SQL dialects.
pub mod dialect;
+
+#[cfg(feature = "derive-dialect")]
+pub use dialect::derive_dialect;
mod display_utils;
pub mod keywords;
pub mod parser;
diff --git a/tests/sqlparser_derive_dialect.rs b/tests/sqlparser_derive_dialect.rs
new file mode 100644
index 00000000..d60fa1e1
--- /dev/null
+++ b/tests/sqlparser_derive_dialect.rs
@@ -0,0 +1,123 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Tests for the `derive_dialect!` macro.
+
+use sqlparser::derive_dialect;
+use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect,
PostgreSqlDialect};
+use sqlparser::parser::Parser;
+
+#[test]
+fn test_method_overrides() {
+    derive_dialect!(EnhancedGenericDialect, GenericDialect, overrides = {
+        supports_order_by_all = true,
+        supports_triple_quoted_string = true,
+    });
+    let dialect = EnhancedGenericDialect::new();
+
+    // Overridden methods
+    assert!(dialect.supports_order_by_all());
+    assert!(dialect.supports_triple_quoted_string());
+
+    // Non-overridden methods forward to the base dialect. Compare against the base itself
+    // rather than hard-coding its current answer.
+    assert_eq!(
+        dialect.supports_factorial_operator(),
+        GenericDialect::default().supports_factorial_operator()
+    );
+
+    // Parsing works with the overrides; report the parser error if it does not.
+    let result = Parser::new(&dialect)
+        .try_with_sql("SELECT '''value''' FROM t ORDER BY ALL")
+        .unwrap()
+        .parse_statements();
+
+    if let Err(e) = result {
+        panic!("failed to parse with the derived dialect: {e}");
+    }
+}
+
+#[test]
+fn test_preserve_type_id() {
+    // The override applies, yet the parser still identifies the dialect as its base type.
+    derive_dialect!(
+        PreservedTypeDialect,
+        GenericDialect,
+        preserve_type_id = true,
+        overrides = { supports_order_by_all = true }
+    );
+    let preserved = PreservedTypeDialect::new();
+    assert!(preserved.supports_order_by_all());
+
+    let as_dyn: &dyn Dialect = &preserved;
+    assert!(as_dyn.is::<GenericDialect>());
+}
+
+#[test]
+fn test_different_base_dialects() {
+    derive_dialect!(
+        EnhancedMySqlDialect,
+        MySqlDialect,
+        overrides = { supports_order_by_all = true }
+    );
+    derive_dialect!(UniquePostgreSqlDialect, PostgreSqlDialect);
+
+    let mysql = EnhancedMySqlDialect::new();
+    let postgres = UniquePostgreSqlDialect::new();
+
+    // Each derived dialect inherits its own base's behavior...
+    assert!(postgres.supports_filter_during_aggregation()); // PostgreSQL feature
+    assert!(mysql.supports_string_literal_backslash_escape()); // MySQL feature
+    // ...plus any overrides.
+    assert!(mysql.supports_order_by_all());
+
+    // Without `preserve_type_id`, each derived dialect has its own TypeId.
+    let postgres_dyn: &dyn Dialect = &postgres;
+    let mysql_dyn: &dyn Dialect = &mysql;
+    assert!(postgres_dyn.is::<UniquePostgreSqlDialect>());
+    assert!(!postgres_dyn.is::<PostgreSqlDialect>());
+    assert!(mysql_dyn.is::<EnhancedMySqlDialect>());
+}
+
+#[test]
+fn test_identifier_quote_style_overrides() {
+    derive_dialect!(
+        BacktickGenericDialect,
+        GenericDialect,
+        overrides = { identifier_quote_style = '`' }
+    );
+    derive_dialect!(
+        BracketGenericDialect,
+        GenericDialect,
+        overrides = { identifier_quote_style = '[' }
+    );
+    derive_dialect!(
+        QuotelessPostgreSqlDialect,
+        PostgreSqlDialect,
+        preserve_type_id = true,
+        overrides = { identifier_quote_style = None }
+    );
+
+    // Char literal (auto-wrapped in Some)
+    assert_eq!(
+        BacktickGenericDialect::new().identifier_quote_style("x"),
+        Some('`')
+    );
+    // A different char literal
+    assert_eq!(
+        BracketGenericDialect::new().identifier_quote_style("x"),
+        Some('[')
+    );
+    // None replaces PostgreSQL's default '"'; check the base first so the override is
+    // shown to actually change the result.
+    assert_eq!(
+        PostgreSqlDialect::default().identifier_quote_style("x"),
+        Some('"')
+    );
+    assert_eq!(
+        QuotelessPostgreSqlDialect::new().identifier_quote_style("x"),
+        None
+    );
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]