taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

lib.rs (11592B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use proc_macro::TokenStream;
     18 use quote::quote;
     19 use syn::{
     20     Data, DeriveInput, Error, Expr, Field, Fields, ItemStruct, Lit, LitStr, Meta, parse::Parser,
     21     parse_macro_input,
     22 };
     23 
     24 #[proc_macro_derive(EnumMeta, attributes(enum_meta, code))]
     25 pub fn derive_domain_code(input: TokenStream) -> TokenStream {
     26     let input = parse_macro_input!(input as DeriveInput);
     27     let name = &input.ident;
     28 
     29     // Parse features
     30     let mut enabled_doc = false;
     31     let mut enabled_domain_code = false;
     32     let mut enabled_str = false;
     33 
     34     for attr in &input.attrs {
     35         if attr.path().is_ident("enum_meta")
     36             && let Err(e) = attr.parse_nested_meta(|meta| {
     37                 if meta.path.is_ident("Description") {
     38                     enabled_doc = true;
     39                 } else if meta.path.is_ident("DomainCode") {
     40                     enabled_domain_code = true;
     41                 } else if meta.path.is_ident("Str") {
     42                     enabled_str = true;
     43                 } else {
     44                     return Err(meta.error("unknown enum_meta option"));
     45                 }
     46                 Ok(())
     47             })
     48         {
     49             return e.to_compile_error().into();
     50         }
     51     }
     52 
     53     let repr_type = input.attrs.iter().find_map(|attr| {
     54         if !attr.path().is_ident("repr") {
     55             return None;
     56         }
     57         let mut found = None;
     58         attr.parse_nested_meta(|it| {
     59             found = it.path.get_ident().cloned();
     60             Ok(())
     61         })
     62         .unwrap();
     63         found
     64     });
     65 
     66     let variants = if let Data::Enum(data) = &input.data {
     67         &data.variants
     68     } else {
     69         return Error::new(input.ident.span(), "EnumMeta only supports enums")
     70             .to_compile_error()
     71             .into();
     72     };
     73 
     74     // Helper: extract the first string literal from a name-value attribute.
     75     let extract_str_attr = |variant: &syn::Variant, ident: &str| -> Option<String> {
     76         variant.attrs.iter().find_map(|a| {
     77             if a.path().is_ident(ident)
     78                 && let Meta::NameValue(nv) = &a.meta
     79                 && let Expr::Lit(expr) = &nv.value
     80                 && let Lit::Str(s) = &expr.lit
     81             {
     82                 Some(s.value())
     83             } else {
     84                 None
     85             }
     86         })
     87     };
     88 
     89     let mut entries = Vec::new();
     90     let mut description_arms = Vec::new();
     91     let mut code_arms = Vec::new();
     92     let mut from_str_arms = Vec::new();
     93     let mut as_ref_arms = Vec::new();
     94     let mut try_from_arms = Vec::new();
     95 
     96     for variant in variants {
     97         let v_ident = &variant.ident;
     98         let v_str = variant
     99             .attrs
    100             .iter()
    101             .find_map(|attr| {
    102                 if attr.path().is_ident("enum_meta") {
    103                     let mut res = None;
    104                     let _ = attr.parse_nested_meta(|meta| {
    105                         if meta.path.is_ident("rename") {
    106                             let value: LitStr = meta.value()?.parse()?;
    107                             res = Some(value.value());
    108                         }
    109                         Ok(())
    110                     });
    111                     res
    112                 } else {
    113                     None
    114                 }
    115             })
    116             .unwrap_or_else(|| v_ident.to_string());
    117 
    118         if repr_type.is_some() {
    119             if let Some((_, discriminant)) = &variant.discriminant {
    120                 try_from_arms.push(quote! { #discriminant => Ok(Self::#v_ident) });
    121             } else {
    122                 return Error::new(v_ident.span(), "missing discriminant expression")
    123                     .to_compile_error()
    124                     .into();
    125             };
    126         }
    127 
    128         // Single pass: collect doc and code in one go, then use what's needed.
    129         let doc =
    130             enabled_doc.then(|| extract_str_attr(variant, "doc").map(|s| s.trim().to_string()));
    131         let code = (enabled_domain_code).then(|| extract_str_attr(variant, "code"));
    132 
    133         if let Some(doc) = doc {
    134             let doc = match doc {
    135                 Some(d) => d,
    136                 None => {
    137                     return Error::new(
    138                         v_ident.span(),
    139                         format!("variant `{v_str}` is missing `/// documentation`"),
    140                     )
    141                     .to_compile_error()
    142                     .into();
    143                 }
    144             };
    145             description_arms.push(quote! { Self::#v_ident => #doc });
    146         }
    147 
    148         if let Some(code) = code {
    149             let code = match code {
    150                 Some(c) => c,
    151                 None => {
    152                     return Error::new(
    153                         v_ident.span(),
    154                         format!("variant `{v_str}` is missing `#[code = \"...\"]`"),
    155                     )
    156                     .to_compile_error()
    157                     .into();
    158                 }
    159             };
    160             from_str_arms.push(quote! { #code => Ok(Self::#v_ident) });
    161             code_arms.push(quote! { Self::#v_ident => #code });
    162         } else if enabled_str {
    163             from_str_arms.push(quote! { #v_str => Ok(Self::#v_ident) });
    164         }
    165 
    166         if enabled_str {
    167             as_ref_arms.push(quote! { Self::#v_ident => #v_str });
    168         }
    169         entries.push(quote! { Self::#v_ident, });
    170     }
    171 
    172     let mut expanded = quote! {
    173         impl #name {
    174             /// Returns a slice of all enum variants.
    175             pub const entries: &'static [Self] = &[#(#entries)*];
    176         }
    177     };
    178 
    179     if enabled_doc {
    180         expanded.extend(quote! {
    181             impl #name {
    182                 /// Returns the documentation description associated
    183                 pub fn description(&self) -> &'static str {
    184                     match self { #(#description_arms),* }
    185                 }
    186             }
    187         });
    188     }
    189 
    190     if enabled_domain_code {
    191         expanded.extend(quote! {
    192             impl #name {
    193                 /// Returns the domain code associated
    194                 pub fn code(&self) -> &'static str {
    195                     match self { #(#code_arms),* }
    196                 }
    197             }
    198         });
    199     }
    200 
    201     if enabled_str {
    202         expanded.extend(quote! {
    203             impl AsRef<str> for #name {
    204                 fn as_ref(&self) -> &str {
    205                     match self { #(#as_ref_arms),* }
    206                 }
    207             }
    208 
    209             impl std::fmt::Display for #name {
    210                 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    211                     f.write_str(self.as_ref())
    212                 }
    213             }
    214         });
    215     }
    216 
    217     if enabled_domain_code || enabled_str {
    218         let unknown_label = if enabled_domain_code { "code" } else { "name" };
    219         expanded.extend(quote! {
    220             impl std::str::FromStr for #name {
    221                 type Err = String;
    222                 fn from_str(s: &str) -> Result<Self, Self::Err> {
    223                     match s {
    224                         #(#from_str_arms,)*
    225                         _ => Err(format!("Unknown {0} for {1}: {2}", #unknown_label, stringify!(#name), s))
    226                     }
    227                 }
    228             }
    229         });
    230     }
    231 
    232     if let Some(repr) = repr_type {
    233         expanded.extend(quote! {
    234             impl TryFrom<#repr> for #name {
    235                 type Error = #repr;
    236                 fn try_from(value: #repr) -> Result<Self, Self::Error> {
    237                     match value {
    238                         #(#try_from_arms,)*
    239                         _ => Err(value),
    240                     }
    241                 }
    242             }
    243         });
    244     }
    245 
    246     TokenStream::from(expanded)
    247 }
    248 
    249 #[proc_macro_attribute]
    250 pub fn api_config(attr: TokenStream, item: TokenStream) -> TokenStream {
    251     // 1. Cleanly parse the string attribute argument
    252     let api_name_lit = parse_macro_input!(attr as LitStr);
    253     let api_name_value = api_name_lit.value();
    254 
    255     // 2. Parse the target struct
    256     let mut input_struct = parse_macro_input!(item as ItemStruct);
    257     let struct_name = &input_struct.ident;
    258 
    259     // 3. Generate deterministic function names for serialization AND deserialization
    260     let struct_lower = struct_name.to_string().to_lowercase();
    261     let serialize_fn_ident = syn::Ident::new(
    262         &format!("_serialize_api_name_for_{}", struct_lower),
    263         proc_macro2::Span::call_site(),
    264     );
    265     let deserialize_fn_ident = syn::Ident::new(
    266         &format!("_deserialize_api_name_for_{}", struct_lower),
    267         proc_macro2::Span::call_site(),
    268     );
    269 
    270     // 4. Convert identifiers to string paths for Serde attributes
    271     let serialize_fn_str = serialize_fn_ident.to_string();
    272     let deserialize_fn_str = deserialize_fn_ident.to_string();
    273 
    274     let crate_path = if std::env::var("CARGO_CRATE_NAME").unwrap_or_default() == "taler_common" {
    275         quote! { crate }
    276     } else {
    277         quote! { ::taler_common }
    278     };
    279 
    280     // 5. Inject the `name` field with BOTH serialization and deserialization hooks
    281     if let Fields::Named(ref mut fields) = input_struct.fields {
    282         fields.named.insert(
    283             0,
    284             Field::parse_named
    285                 .parse2(quote! {
    286                     #[serde(
    287                         serialize_with = #serialize_fn_str,
    288                         deserialize_with = #deserialize_fn_str
    289                     )]
    290                     pub name: ()
    291                 })
    292                 .unwrap(),
    293         );
    294         fields.named.insert(
    295             1,
    296             Field::parse_named
    297                 .parse2(quote! {
    298                     pub version: #crate_path::api::LibtoolVersion
    299                 })
    300                 .unwrap(),
    301         );
    302         fields.named.insert(
    303             2,
    304             Field::parse_named
    305                 .parse2(quote! {
    306                     pub implementation: Option<&'a str>
    307                 })
    308                 .unwrap(),
    309         );
    310     } else {
    311         return syn::Error::new_spanned(
    312             input_struct,
    313             "#[api_config] only works on structs with named fields",
    314         )
    315         .to_compile_error()
    316         .into();
    317     }
    318 
    319     let expanded = quote! {
    320         #input_struct
    321 
    322         #[doc(hidden)]
    323         #[allow(non_snake_case)]
    324         pub fn #serialize_fn_ident<S>(_: &(), s: S) -> ::std::result::Result<S::Ok, S::Error>
    325         where
    326             S: ::serde::Serializer
    327         {
    328             s.serialize_str(#api_name_value)
    329         }
    330 
    331         #[doc(hidden)]
    332         #[allow(non_snake_case)]
    333         pub fn #deserialize_fn_ident<'de, D>(deserializer: D) -> ::std::result::Result<(), D::Error>
    334         where
    335             D: ::serde::Deserializer<'de>,
    336         {
    337             let s: ::std::string::String = ::serde::Deserialize::deserialize(deserializer)?;
    338             if s == #api_name_value {
    339                 Ok(())
    340             } else {
    341                 Err(::serde::de::Error::custom(::std::format!(
    342                     "invalid API name: expected '{}', found '{}'",
    343                     #api_name_value, s
    344                 )))
    345             }
    346         }
    347     };
    348 
    349     TokenStream::from(expanded)
    350 }