From: Gary Guo <[email protected]>

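Parse the `#[kunit_tests]` attribute argument and the annotated module
with `syn` instead of scanning raw token trees and generating code from
strings. This allows significant cleanups, and errors are now reported
as `compile_error!` diagnostics instead of panics.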

Signed-off-by: Gary Guo <[email protected]>
---
 rust/macros/kunit.rs | 274 +++++++++++++++++++------------------------
 rust/macros/lib.rs   |   6 +-
 2 files changed, 123 insertions(+), 157 deletions(-)

diff --git a/rust/macros/kunit.rs b/rust/macros/kunit.rs
index 7427c17ee5f5c..516219f5b1356 100644
--- a/rust/macros/kunit.rs
+++ b/rust/macros/kunit.rs
@@ -4,81 +4,50 @@
 //!
 //! Copyright (c) 2023 José Expósito <[email protected]>
 
-use std::collections::HashMap;
-use std::fmt::Write;
-
-use proc_macro2::{Delimiter, Group, TokenStream, TokenTree};
-
-pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
-    let attr = attr.to_string();
-
-    if attr.is_empty() {
-        panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
-    }
-
-    if attr.len() > 255 {
-        panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes")
+use std::ffi::CString;
+
+use proc_macro2::TokenStream;
+use quote::{
+    format_ident,
+    quote,
+    ToTokens, //
+};
+use syn::{
+    parse_quote,
+    Error,
+    Ident,
+    Item,
+    ItemMod,
+    LitCStr,
+    Result, //
+};
+
+pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> {
+    if test_suite.to_string().len() > 255 {
+        return Err(Error::new_spanned(
+            test_suite,
+            "test suite names cannot exceed the maximum length of 255 bytes",
+        ));
     }
 
-    let mut tokens: Vec<_> = ts.into_iter().collect();
-
-    // Scan for the `mod` keyword.
-    tokens
-        .iter()
-        .find_map(|token| match token {
-            TokenTree::Ident(ident) => match ident.to_string().as_str() {
-                "mod" => Some(true),
-                _ => None,
-            },
-            _ => None,
-        })
-        .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
-
-    // Retrieve the main body. The main body should be the last token tree.
-    let body = match tokens.pop() {
-        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
-        _ => panic!("Cannot locate main body of module"),
+    // We cannot handle modules that defer to another file (e.g. `mod foo;`).
+    let Some((module_brace, module_items)) = module.content.take() else {
+        return Err(Error::new_spanned(
+            module,
+            "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules",
+        ));
     };
 
-    // Get the functions set as tests. Search for `[test]` -> `fn`.
-    let mut body_it = body.stream().into_iter();
-    let mut tests = Vec::new();
-    let mut attributes: HashMap<String, TokenStream> = HashMap::new();
-    while let Some(token) = body_it.next() {
-        match token {
-            TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {
-                Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
-                    if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {
-                        // Collect attributes because we need to find which are tests. We also
-                        // need to copy `cfg` attributes so tests can be conditionally enabled.
-                        attributes
-                            .entry(name.to_string())
-                            .or_default()
-                            .extend([token, TokenTree::Group(g)]);
-                    }
-                    continue;
-                }
-                _ => (),
-            },
-            TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => {
-                if let Some(TokenTree::Ident(test_name)) = body_it.next() {
-                    tests.push((test_name, attributes.remove("cfg").unwrap_or_default()))
-                }
-            }
-
-            _ => (),
-        }
-        attributes.clear();
-    }
+    // Make the entire module gated behind `CONFIG_KUNIT`.
+    module
+        .attrs
+        .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")]));
 
-    // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration.
-    let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap();
-    tokens.insert(
-        0,
-        TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
-    );
+    let mut processed_items = Vec::new();
+    let mut test_cases = Vec::new();
 
     // Generate the test KUnit test suite and a test case for each `#[test]`.
+    //
     // The code generated for the following test module:
     //
     // ```
@@ -110,98 +79,93 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
     //
     // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
     // ```
-    let mut kunit_macros = "".to_owned();
-    let mut test_cases = "".to_owned();
-    let mut assert_macros = "".to_owned();
-    let path = crate::helpers::file();
-    let num_tests = tests.len();
-    for (test, cfg_attr) in tests {
-        let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");
-        // Append any `cfg` attributes the user might have written on their tests so we don't
-        // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce
-        // the length of the assert message.
-        let kunit_wrapper = format!(
-            r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit)
-            {{
-                (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
-                {cfg_attr} {{
-                    (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
-                    use ::kernel::kunit::is_test_result_ok;
-                    assert!(is_test_result_ok({test}()));
+    //
+    // Non-function items (e.g. imports) are preserved.
+    for item in module_items {
+        let Item::Fn(mut f) = item else {
+            processed_items.push(item);
+            continue;
+        };
+
+        // TODO: Replace below with `extract_if` when MSRV is bumped above 1.85.
+        // Skip functions without `#[test]`; otherwise strip all `#[test]` attributes from it.
+        if !f.attrs.iter().any(|attr| attr.path().is_ident("test")) {
+            processed_items.push(Item::Fn(f));
+            continue;
+        }
+        f.attrs.retain(|attr| !attr.path().is_ident("test"));
+
+        let test = f.sig.ident.clone();
+
+        // Retrieve `#[cfg]` applied on the function which needs to be present on derived items too.
+        let cfg_attrs: Vec<_> = f
+            .attrs
+            .iter()
+            .filter(|attr| attr.path().is_ident("cfg"))
+            .cloned()
+            .collect();
+
+        // Before the test, override usual `assert!` and `assert_eq!` macros with ones that call
+        // KUnit instead.
+        let test_str = test.to_string();
+        let path = crate::helpers::file();
+        processed_items.push(parse_quote! {
+            #[allow(unused)]
+            macro_rules! assert {
+                ($cond:expr $(,)?) => {{
+                    kernel::kunit_assert!(#test_str, #path, 0, $cond);
+                }}
+            }
+        });
+        processed_items.push(parse_quote! {
+            #[allow(unused)]
+            macro_rules! assert_eq {
+                ($left:expr, $right:expr $(,)?) => {{
+                    kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right);
                 }}
-            }}"#,
+            }
+        });
+
+        // Add back the test item.
+        processed_items.push(Item::Fn(f));
+
+        let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}");
+        let test_cstr = LitCStr::new(
+            &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"),
+            test.span(),
         );
-        writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
-        writeln!(
-            test_cases,
-            "    ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name}),"
-        )
-        .unwrap();
-        writeln!(
-            assert_macros,
-            r#"
-/// Overrides the usual [`assert!`] macro with one that calls KUnit instead.
-#[allow(unused)]
-macro_rules! assert {{
-    ($cond:expr $(,)?) => {{{{
-        kernel::kunit_assert!("{test}", "{path}", 0, $cond);
-    }}}}
-}}
-
-/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead.
-#[allow(unused)]
-macro_rules! assert_eq {{
-    ($left:expr, $right:expr $(,)?) => {{{{
-        kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right);
-    }}}}
-}}
-        "#
-        )
-        .unwrap();
-    }
+        processed_items.push(parse_quote! {
+            unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) {
+                (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
 
-    writeln!(kunit_macros).unwrap();
-    writeln!(
-        kunit_macros,
-        "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases}    ::kernel::kunit::kunit_case_null(),\n];",
-        num_tests + 1
-    )
-    .unwrap();
-
-    writeln!(
-        kunit_macros,
-        "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"
-    )
-    .unwrap();
-
-    // Remove the `#[test]` macros.
-    // We do this at a token level, in order to preserve span information.
-    let mut new_body = vec![];
-    let mut body_it = body.stream().into_iter();
-
-    while let Some(token) = body_it.next() {
-        match token {
-            TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {
-                Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),
-                Some(next) => {
-                    new_body.extend([token, next]);
-                }
-                _ => {
-                    new_body.push(token);
+                // Append any `cfg` attributes the user might have written on their tests so we
+                // don't attempt to call them when they are `cfg`'d out. An extra `use` is used
+                // here to reduce the length of the assert message.
+                #(#cfg_attrs)*
+                {
+                    (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
+                    use ::kernel::kunit::is_test_result_ok;
+                    assert!(is_test_result_ok(#test()));
                 }
-            },
-            _ => {
-                new_body.push(token);
             }
-        }
-    }
-
-    let mut final_body = TokenStream::new();
-    final_body.extend::<TokenStream>(assert_macros.parse().unwrap());
-    final_body.extend(new_body);
-    final_body.extend::<TokenStream>(kunit_macros.parse().unwrap());
+        });
 
-    tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body)));
+        test_cases.push(quote!(
+            ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name)
+        ));
+    }
 
-    tokens.into_iter().collect()
+    let num_tests_plus_1 = test_cases.len() + 1;
+    processed_items.push(parse_quote! {
+        static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [
+            #(#test_cases,)*
+            ::kernel::kunit::kunit_case_null(),
+        ];
+    });
+    processed_items.push(parse_quote! {
+        ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES);
+    });
+
+    module.content = Some((module_brace, processed_items));
+    Ok(module.to_token_stream())
 }
diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs
index bb2dfd4a4dafc..9cfac9fce0d36 100644
--- a/rust/macros/lib.rs
+++ b/rust/macros/lib.rs
@@ -453,6 +453,8 @@ pub fn paste(input: TokenStream) -> TokenStream {
 /// }
 /// ```
 #[proc_macro_attribute]
-pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
-    kunit::kunit_tests(attr.into(), ts.into()).into()
+pub fn kunit_tests(attr: TokenStream, input: TokenStream) -> TokenStream {
+    let expanded = kunit::kunit_tests(parse_macro_input!(attr), parse_macro_input!(input));
+    let tokens = expanded.unwrap_or_else(|err| err.into_compile_error());
+    tokens.into()
 }
-- 
2.51.2


Reply via email to