taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

api.rs (8597B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{
     18     sync::{
     19         Arc,
     20         atomic::{AtomicU32, Ordering},
     21     },
     22     time::Instant,
     23 };
     24 
     25 use axum::{
     26     extract::{Request, State},
     27     middleware::{self, Next},
     28     response::Response,
     29 };
     30 use compact_str::CompactString;
     31 use rand::RngExt as _;
     32 use revenue::Revenue;
     33 use taler_common::{
     34     error_code::ErrorCode,
     35     log::LOG_TASK_ID,
     36     types::amount::{Amount, Currency},
     37 };
     38 use tokio::signal;
     39 use tracing::{Level, debug, info};
     40 use wire::WireGateway;
     41 
     42 use crate::{
     43     Listener, Serve,
     44     api::prepared::PreparedTransfer,
     45     auth::{AuthMethod, AuthMiddlewareState},
     46     error::{ApiResult, LoggedError, failure, failure_code},
     47 };
     48 
     49 pub mod prepared;
     50 pub mod revenue;
     51 pub mod wire;
     52 
     53 pub use axum::Router;
     54 
     55 pub trait Validation {
     56     fn check(&self, currency: &Currency) -> ApiResult<()>;
     57 }
     58 
     59 fn check_currency(currency: &Currency, amount: &Amount) -> ApiResult<()> {
     60     if &amount.currency != currency {
     61         Err(failure(
     62             ErrorCode::GENERIC_CURRENCY_MISMATCH,
     63             format!(
     64                 "wrong currency expected {} got {}",
     65                 currency, amount.currency
     66             ),
     67         ))
     68     } else {
     69         Ok(())
     70     }
     71 }
     72 
     73 pub trait TalerApi: Send + Sync + 'static {
     74     fn currency(&self) -> Currency;
     75     fn implementation(&self) -> &'static str;
     76 }
     77 
     78 pub trait RouterUtils {
     79     fn auth(self, auth: AuthMethod, realm: &str) -> Self;
     80 }
     81 
     82 impl<S: Send + Clone + Sync + 'static> RouterUtils for Router<S> {
     83     fn auth(self, auth: AuthMethod, realm: &str) -> Self {
     84         self.route_layer(middleware::from_fn_with_state(
     85             Arc::new(AuthMiddlewareState::new(auth, realm)),
     86             crate::auth::auth_middleware,
     87         ))
     88     }
     89 }
     90 
     91 pub trait TalerRouter {
     92     fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self;
     93     fn prepared_transfer<T: PreparedTransfer>(self, api: Arc<T>) -> Self;
     94     fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self;
     95     fn finalize(self) -> Self;
     96     fn serve(
     97         self,
     98         serve: &Serve,
     99         lifetime: Option<u32>,
    100     ) -> impl std::future::Future<Output = std::io::Result<()>> + Send;
    101 }
    102 
    103 impl TalerRouter for Router {
    104     fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self {
    105         self.nest("/taler-wire-gateway", wire::router(api, auth))
    106     }
    107 
    108     fn prepared_transfer<T: PreparedTransfer>(self, api: Arc<T>) -> Self {
    109         self.nest("/taler-prepared-transfer", prepared::router(api))
    110     }
    111 
    112     fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self {
    113         self.nest("/taler-revenue", revenue::router(api, auth))
    114     }
    115 
    116     fn finalize(self) -> Router {
    117         self.method_not_allowed_fallback(async || failure_code(ErrorCode::GENERIC_METHOD_INVALID))
    118             .fallback(async || failure_code(ErrorCode::GENERIC_ENDPOINT_UNKNOWN))
    119             .layer(middleware::from_fn(logger_middleware))
    120     }
    121 
    122     async fn serve(mut self, serve: &Serve, lifetime: Option<u32>) -> std::io::Result<()> {
    123         let listener = serve.resolve()?;
    124 
    125         let notify = Arc::new(tokio::sync::Notify::new());
    126         if let Some(lifetime) = lifetime {
    127             self = self.layer(middleware::from_fn_with_state(
    128                 Arc::new(LifetimeMiddlewareState {
    129                     notify: notify.clone(),
    130                     lifetime: AtomicU32::new(lifetime),
    131                 }),
    132                 lifetime_middleware,
    133             ))
    134         }
    135         let router = self.finalize();
    136         let signal = shutdown_signal(notify);
    137         match listener {
    138             Listener::Tcp(tcp_listener) => {
    139                 axum::serve(tcp_listener, router)
    140                     .with_graceful_shutdown(signal)
    141                     .await?;
    142             }
    143             Listener::Unix(unix_listener) => {
    144                 axum::serve(unix_listener, router)
    145                     .with_graceful_shutdown(signal)
    146                     .await?;
    147             }
    148         }
    149 
    150         info!(target: "api", "Server stopped");
    151         Ok(())
    152     }
    153 }
    154 
    155 struct LifetimeMiddlewareState {
    156     lifetime: AtomicU32,
    157     notify: Arc<tokio::sync::Notify>,
    158 }
    159 
    160 async fn lifetime_middleware(
    161     State(state): State<Arc<LifetimeMiddlewareState>>,
    162     request: Request,
    163     next: Next,
    164 ) -> Response {
    165     let mut current = state.lifetime.load(Ordering::Relaxed);
    166     while current != 0 {
    167         match state.lifetime.compare_exchange_weak(
    168             current,
    169             current - 1,
    170             Ordering::Relaxed,
    171             Ordering::Relaxed,
    172         ) {
    173             Ok(_) => break,
    174             Err(new) => current = new,
    175         }
    176     }
    177     if current == 0 {
    178         state.notify.notify_one();
    179     }
    180     next.run(request).await
    181 }
    182 
    183 /** Wait for manual shutdown or system signal shutdown */
    184 async fn shutdown_signal(manual_shutdown: Arc<tokio::sync::Notify>) {
    185     let ctrl_c = async {
    186         signal::ctrl_c()
    187             .await
    188             .expect("failed to install Ctrl+C handler");
    189     };
    190 
    191     #[cfg(unix)]
    192     let terminate = async {
    193         signal::unix::signal(signal::unix::SignalKind::terminate())
    194             .expect("failed to install signal handler")
    195             .recv()
    196             .await;
    197     };
    198 
    199     #[cfg(not(unix))]
    200     let terminate = std::future::pending::<()>();
    201 
    202     let manual = async { manual_shutdown.notified().await };
    203 
    204     tokio::select! {
    205         _ = ctrl_c => {},
    206         _ = terminate => {},
    207         _ = manual => {}
    208     }
    209 }
    210 
    211 #[macro_export]
    212 macro_rules! dyn_event {
    213     ($lvl:ident, $($arg:tt)+) => {
    214         match $lvl {
    215             ::tracing::Level::TRACE => ::tracing::trace!($($arg)+),
    216             ::tracing::Level::DEBUG => ::tracing::debug!($($arg)+),
    217             ::tracing::Level::INFO => ::tracing::info!($($arg)+),
    218             ::tracing::Level::WARN => ::tracing::warn!($($arg)+),
    219             ::tracing::Level::ERROR => ::tracing::error!($($arg)+),
    220         }
    221     };
    222 }
    223 
    224 /** Taler API logger */
    225 async fn logger_middleware(request: Request, next: Next) -> Response {
    226     let now = Instant::now();
    227     let request_id: compact_str::CompactString = {
    228         let charset = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    229         let mut rng = rand::rng();
    230 
    231         let mut ansi = [0u8; 10];
    232         for c in ansi.iter_mut() {
    233             let idx = rng.random_range(0..charset.len());
    234             *c = charset[idx];
    235         }
    236         unsafe { CompactString::from_utf8_unchecked(ansi) }
    237     };
    238     let method = request.method().clone();
    239     let path_and_query = request.uri().path_and_query().cloned();
    240     let path_and_query = path_and_query
    241         .as_ref()
    242         .map(|it| it.as_str())
    243         .unwrap_or_default();
    244     LOG_TASK_ID
    245         .scope(request_id, async {
    246             debug!(target: "api", "{method} {path_and_query}");
    247             let response = next.run(request).await;
    248             let elapsed = now.elapsed();
    249             let status = response.status();
    250             let level = match status.as_u16() {
    251                 400..500 => Level::WARN,
    252                 500..600 => Level::ERROR,
    253                 _ => Level::INFO,
    254             };
    255 
    256             if let Some(log) = response.extensions().get::<LoggedError>() {
    257                 let LoggedError { code, info } = log;
    258                 dyn_event!(level, target: "api",
    259                     "{} {method} {path_and_query} {}ms: {code}{}",
    260                     response.status(),
    261                     elapsed.as_millis(),
    262                     std::fmt::from_fn(|f|{
    263                         if let Some(info) = info {
    264                             write!(f, " {info}")?;
    265                         }
    266                         Ok(())
    267                     })
    268                 );
    269             } else {
    270                 dyn_event!(level, target: "api",
    271                     "{} {method} {path_and_query} {}ms",
    272                     response.status(),
    273                     elapsed.as_millis()
    274                 );
    275             }
    276             response
    277         })
    278         .await
    279 }