taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

api.rs (7469B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{
     18     sync::{
     19         Arc,
     20         atomic::{AtomicU32, Ordering},
     21     },
     22     time::Instant,
     23 };
     24 
     25 use axum::{
     26     extract::{Request, State},
     27     middleware::{self, Next},
     28     response::Response,
     29 };
     30 use revenue::Revenue;
     31 use taler_common::{error_code::ErrorCode, types::amount::Amount};
     32 use tokio::signal;
     33 use tracing::{Level, info};
     34 use wire::WireGateway;
     35 
     36 use crate::{
     37     Listener, Serve,
     38     api::transfer::WireTransferGateway,
     39     auth::{AuthMethod, AuthMiddlewareState},
     40     error::{ApiResult, failure, failure_code},
     41 };
     42 
     43 pub mod revenue;
     44 pub mod transfer;
     45 pub mod wire;
     46 
     47 pub use axum::Router;
     48 
     49 pub trait TalerApi: Send + Sync + 'static {
     50     fn currency(&self) -> &str;
     51     fn implementation(&self) -> &'static str;
     52     fn check_currency(&self, amount: &Amount) -> ApiResult<()> {
     53         let currency = self.currency();
     54         if amount.currency.as_ref() != currency {
     55             Err(failure(
     56                 ErrorCode::GENERIC_CURRENCY_MISMATCH,
     57                 format!(
     58                     "wrong currency expected {} got {}",
     59                     currency, amount.currency
     60                 ),
     61             ))
     62         } else {
     63             Ok(())
     64         }
     65     }
     66 }
     67 
     68 pub trait RouterUtils {
     69     fn auth(self, auth: AuthMethod, realm: &str) -> Self;
     70 }
     71 
     72 impl<S: Send + Clone + Sync + 'static> RouterUtils for Router<S> {
     73     fn auth(self, auth: AuthMethod, realm: &str) -> Self {
     74         self.route_layer(middleware::from_fn_with_state(
     75             Arc::new(AuthMiddlewareState::new(auth, realm)),
     76             crate::auth::auth_middleware,
     77         ))
     78     }
     79 }
     80 
     81 pub trait TalerRouter {
     82     fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self;
     83     fn wire_transfer_gateway<T: WireTransferGateway>(self, api: Arc<T>) -> Self;
     84     fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self;
     85     fn finalize(self) -> Self;
     86     fn serve(
     87         self,
     88         serve: Serve,
     89         lifetime: Option<u32>,
     90     ) -> impl std::future::Future<Output = std::io::Result<()>> + Send;
     91 }
     92 
     93 impl TalerRouter for Router {
     94     fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self {
     95         self.nest("/taler-wire-gateway", wire::router(api, auth))
     96     }
     97 
     98     fn wire_transfer_gateway<T: WireTransferGateway>(self, api: Arc<T>) -> Self {
     99         self.nest("/taler-wire-transfer-gateway", transfer::router(api))
    100     }
    101 
    102     fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self {
    103         self.nest("/taler-revenue", revenue::router(api, auth))
    104     }
    105 
    106     fn finalize(self) -> Router {
    107         self.method_not_allowed_fallback(async || failure_code(ErrorCode::GENERIC_METHOD_INVALID))
    108             .fallback(async || failure_code(ErrorCode::GENERIC_ENDPOINT_UNKNOWN))
    109             .layer(middleware::from_fn(logger_middleware))
    110     }
    111 
    112     async fn serve(mut self, serve: Serve, lifetime: Option<u32>) -> std::io::Result<()> {
    113         let listener = serve.resolve()?;
    114 
    115         let notify = Arc::new(tokio::sync::Notify::new());
    116         if let Some(lifetime) = lifetime {
    117             self = self.layer(middleware::from_fn_with_state(
    118                 Arc::new(LifetimeMiddlewareState {
    119                     notify: notify.clone(),
    120                     lifetime: AtomicU32::new(lifetime),
    121                 }),
    122                 lifetime_middleware,
    123             ))
    124         }
    125         let router = self.finalize();
    126         let signal = shutdown_signal(notify);
    127         match listener {
    128             Listener::Tcp(tcp_listener) => {
    129                 axum::serve(tcp_listener, router)
    130                     .with_graceful_shutdown(signal)
    131                     .await?;
    132             }
    133             Listener::Unix(unix_listener) => {
    134                 axum::serve(unix_listener, router)
    135                     .with_graceful_shutdown(signal)
    136                     .await?;
    137             }
    138         }
    139 
    140         info!(target: "api", "Server stopped");
    141         Ok(())
    142     }
    143 }
    144 
    145 struct LifetimeMiddlewareState {
    146     lifetime: AtomicU32,
    147     notify: Arc<tokio::sync::Notify>,
    148 }
    149 
    150 async fn lifetime_middleware(
    151     State(state): State<Arc<LifetimeMiddlewareState>>,
    152     request: Request,
    153     next: Next,
    154 ) -> Response {
    155     let mut current = state.lifetime.load(Ordering::Relaxed);
    156     while current != 0 {
    157         match state.lifetime.compare_exchange_weak(
    158             current,
    159             current - 1,
    160             Ordering::Relaxed,
    161             Ordering::Relaxed,
    162         ) {
    163             Ok(_) => break,
    164             Err(new) => current = new,
    165         }
    166     }
    167     if current == 0 {
    168         state.notify.notify_one();
    169     }
    170     next.run(request).await
    171 }
    172 
    173 /** Wait for manual shutdown or system signal shutdown */
    174 async fn shutdown_signal(manual_shutdown: Arc<tokio::sync::Notify>) {
    175     let ctrl_c = async {
    176         signal::ctrl_c()
    177             .await
    178             .expect("failed to install Ctrl+C handler");
    179     };
    180 
    181     #[cfg(unix)]
    182     let terminate = async {
    183         signal::unix::signal(signal::unix::SignalKind::terminate())
    184             .expect("failed to install signal handler")
    185             .recv()
    186             .await;
    187     };
    188 
    189     #[cfg(not(unix))]
    190     let terminate = std::future::pending::<()>();
    191 
    192     let manual = async { manual_shutdown.notified().await };
    193 
    194     tokio::select! {
    195         _ = ctrl_c => {},
    196         _ = terminate => {},
    197         _ = manual => {}
    198     }
    199 }
    200 
    201 #[macro_export]
    202 macro_rules! dyn_event {
    203     ($lvl:ident, $($arg:tt)+) => {
    204         match $lvl {
    205             ::tracing::Level::TRACE => ::tracing::trace!($($arg)+),
    206             ::tracing::Level::DEBUG => ::tracing::debug!($($arg)+),
    207             ::tracing::Level::INFO => ::tracing::info!($($arg)+),
    208             ::tracing::Level::WARN => ::tracing::warn!($($arg)+),
    209             ::tracing::Level::ERROR => ::tracing::error!($($arg)+),
    210         }
    211     };
    212 }
    213 
    214 /** Taler API logger */
    215 async fn logger_middleware(request: Request, next: Next) -> Response {
    216     let request_info = format!(
    217         "{} {}",
    218         request.method(),
    219         request
    220             .uri()
    221             .path_and_query()
    222             .map(|it| it.as_str())
    223             .unwrap_or_default()
    224     );
    225     let now = Instant::now();
    226     let response = next.run(request).await;
    227     let elapsed = now.elapsed();
    228     let status = response.status();
    229     let level = match status.as_u16() {
    230         400..500 => Level::WARN,
    231         500..600 => Level::ERROR,
    232         _ => Level::INFO,
    233     };
    234 
    235     if let Some(log) = response.extensions().get::<Box<str>>() {
    236         dyn_event!(level, target: "api",
    237             "{} {request_info} {}ms: {log}",
    238             response.status(),
    239             elapsed.as_millis()
    240         );
    241     } else {
    242         dyn_event!(level, target: "api",
    243             "{} {request_info} {}ms",
    244             response.status(),
    245             elapsed.as_millis()
    246         );
    247     }
    248     response
    249 }