https://users.rust-lang.org/latest

⚓ Rust    📅 2025-10-22    👤 surdeus    👁️ 4      

surdeus

// macro:
mod imp {
    use proc_macro2::TokenStream;
    use quote::{quote, quote_spanned};
    use syn::spanned::Spanned;
    use syn::{parse2, parse_quote, Data, DeriveInput, Fields, GenericParam, Generics, Index};

    pub fn derive_heap_size(input: TokenStream) -> TokenStream {
        // Parse the input tokens into a syntax tree.
        let input: DeriveInput = parse2(input).unwrap();

        // Used in the quasi-quotation below as `#name`.
        let name = input.ident;

        // Add a bound `T: HeapSize` to every type parameter T.
        let generics = add_trait_bounds(input.generics);
        let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

        // Generate an expression to sum up the heap size of each field.
        let sum = heap_size_sum(&input.data);

        let expanded = quote! {
            // The generated impl.
            impl #impl_generics heapsize::HeapSize for #name #ty_generics #where_clause {
                fn heap_size_of_children(&self) -> usize {
                    #sum
                }
            }
        };

        // Hand the output tokens back to the compiler.
        expanded
    }

    // Add a bound `T: HeapSize` to every type parameter T.
    fn add_trait_bounds(mut generics: Generics) -> Generics {
        for param in &mut generics.params {
            if let GenericParam::Type(ref mut type_param) = *param {
                type_param.bounds.push(parse_quote!(heapsize::HeapSize));
            }
        }
        generics
    }

    // Generate an expression to sum up the heap size of each field.
    fn heap_size_sum(data: &Data) -> TokenStream {
        match *data {
            Data::Struct(ref data) => {
                match data.fields {
                    Fields::Named(ref fields) => {
                        // Expands to an expression like
                        //
                        //     0 + self.x.heap_size() + self.y.heap_size() + self.z.heap_size()
                        //
                        // but using fully qualified function call syntax.
                        //
                        // We take some care to use the span of each `syn::Field` as
                        // the span of the corresponding `heap_size_of_children`
                        // call. This way if one of the field types does not
                        // implement `HeapSize` then the compiler's error message
                        // underlines which field it is. An example is shown in the
                        // readme of the parent directory.
                        let recurse = fields.named.iter().map(|f| {
                            let name = &f.ident;
                            quote_spanned! {f.span()=>
                                heapsize::HeapSize::heap_size_of_children(&self.#name)
                            }
                        });
                        quote! {
                            0 #(+ #recurse)*
                        }
                    }
                    Fields::Unnamed(ref fields) => {
                        // Expands to an expression like
                        //
                        //     0 + self.0.heap_size() + self.1.heap_size() + self.2.heap_size()
                        let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
                            let index = Index::from(i);
                            quote_spanned! {f.span()=>
                                heapsize::HeapSize::heap_size_of_children(&self.#index)
                            }
                        });
                        quote! {
                            0 #(+ #recurse)*
                        }
                    }
                    Fields::Unit => {
                        // Unit structs cannot own more than 0 bytes of heap memory.
                        quote!(0)
                    }
                }
            }
            Data::Enum(_) | Data::Union(_) => unimplemented!(),
        }
    }
}

// usage:
use imp::derive_heap_size;
use proc_macro2::TokenStream;

/// Demo driver: feeds a hard-coded struct definition through `derive_heap_size`.
///
/// It then prints the original source followed by the generated tokens. The output is the
/// raw `TokenStream` rendering, all on one line.
fn main() {
    let source = r#"
#[derive(HeapSize)]
struct Demo<'a, T: ?Sized> {
    a: Box<T>,
    b: u8,
    c: &'a str,
    d: String,
}
    "#;
    // The sample is a fixed literal, so a tokenization failure would be a bug in this demo.
    let tokens = source.parse::<TokenStream>().unwrap();
    let generated = derive_heap_size(tokens).to_string();

    println!("in:\n{}\n\n---\n", source.trim());
    println!("out:\n{}", generated);
}

(Playground)

1 post - 1 participant

Read full topic

🏷️ Rust_feed