api.rs (7469B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{ 18 sync::{ 19 Arc, 20 atomic::{AtomicU32, Ordering}, 21 }, 22 time::Instant, 23 }; 24 25 use axum::{ 26 extract::{Request, State}, 27 middleware::{self, Next}, 28 response::Response, 29 }; 30 use revenue::Revenue; 31 use taler_common::{error_code::ErrorCode, types::amount::Amount}; 32 use tokio::signal; 33 use tracing::{Level, info}; 34 use wire::WireGateway; 35 36 use crate::{ 37 Listener, Serve, 38 api::transfer::WireTransferGateway, 39 auth::{AuthMethod, AuthMiddlewareState}, 40 error::{ApiResult, failure, failure_code}, 41 }; 42 43 pub mod revenue; 44 pub mod transfer; 45 pub mod wire; 46 47 pub use axum::Router; 48 49 pub trait TalerApi: Send + Sync + 'static { 50 fn currency(&self) -> &str; 51 fn implementation(&self) -> &'static str; 52 fn check_currency(&self, amount: &Amount) -> ApiResult<()> { 53 let currency = self.currency(); 54 if amount.currency.as_ref() != currency { 55 Err(failure( 56 ErrorCode::GENERIC_CURRENCY_MISMATCH, 57 format!( 58 "wrong currency expected {} got {}", 59 currency, amount.currency 60 ), 61 )) 62 } else { 63 Ok(()) 64 } 65 } 66 } 67 68 pub trait RouterUtils { 69 fn auth(self, auth: AuthMethod, realm: &str) -> Self; 70 } 71 72 impl<S: Send + Clone + Sync + 'static> RouterUtils for Router<S> { 73 fn auth(self, auth: AuthMethod, realm: &str) -> Self { 74 self.route_layer(middleware::from_fn_with_state( 75 Arc::new(AuthMiddlewareState::new(auth, realm)), 76 crate::auth::auth_middleware, 77 )) 78 } 79 } 80 81 pub trait TalerRouter { 82 fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self; 83 fn wire_transfer_gateway<T: WireTransferGateway>(self, api: Arc<T>) -> Self; 84 fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self; 85 fn finalize(self) -> Self; 86 fn serve( 87 self, 88 serve: Serve, 89 lifetime: Option<u32>, 90 ) -> impl std::future::Future<Output = std::io::Result<()>> + Send; 91 } 92 93 impl TalerRouter for Router { 94 fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self { 95 self.nest("/taler-wire-gateway", wire::router(api, auth)) 96 } 97 98 fn wire_transfer_gateway<T: WireTransferGateway>(self, api: Arc<T>) -> Self { 99 self.nest("/taler-wire-transfer-gateway", transfer::router(api)) 100 } 101 102 fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self { 103 self.nest("/taler-revenue", revenue::router(api, auth)) 104 } 105 106 fn finalize(self) -> Router { 107 self.method_not_allowed_fallback(async || failure_code(ErrorCode::GENERIC_METHOD_INVALID)) 108 .fallback(async || failure_code(ErrorCode::GENERIC_ENDPOINT_UNKNOWN)) 109 .layer(middleware::from_fn(logger_middleware)) 110 } 111 112 async fn serve(mut self, serve: Serve, lifetime: Option<u32>) -> std::io::Result<()> { 113 let listener = serve.resolve()?; 114 115 let notify = Arc::new(tokio::sync::Notify::new()); 116 if let Some(lifetime) = lifetime { 117 self = self.layer(middleware::from_fn_with_state( 118 Arc::new(LifetimeMiddlewareState { 119 notify: notify.clone(), 120 lifetime: AtomicU32::new(lifetime), 121 }), 122 lifetime_middleware, 123 )) 124 } 125 let router = self.finalize(); 126 let signal = shutdown_signal(notify); 127 match listener { 128 Listener::Tcp(tcp_listener) => { 129 axum::serve(tcp_listener, router) 130 .with_graceful_shutdown(signal) 131 .await?; 132 } 133 Listener::Unix(unix_listener) => { 134 axum::serve(unix_listener, router) 135 .with_graceful_shutdown(signal) 136 .await?; 137 } 138 } 139 140 info!(target: "api", "Server stopped"); 141 Ok(()) 142 } 143 } 144 145 struct LifetimeMiddlewareState { 146 lifetime: AtomicU32, 147 notify: Arc<tokio::sync::Notify>, 148 } 149 150 async fn lifetime_middleware( 151 State(state): State<Arc<LifetimeMiddlewareState>>, 152 request: Request, 153 next: Next, 154 ) -> Response { 155 let mut current = state.lifetime.load(Ordering::Relaxed); 156 while current != 0 { 157 match state.lifetime.compare_exchange_weak( 158 current, 159 current - 1, 160 Ordering::Relaxed, 161 Ordering::Relaxed, 162 ) { 163 Ok(_) => break, 164 Err(new) => current = new, 165 } 166 } 167 if current == 0 { 168 state.notify.notify_one(); 169 } 170 next.run(request).await 171 } 172 173 /** Wait for manual shutdown or system signal shutdown */ 174 async fn shutdown_signal(manual_shutdown: Arc<tokio::sync::Notify>) { 175 let ctrl_c = async { 176 signal::ctrl_c() 177 .await 178 .expect("failed to install Ctrl+C handler"); 179 }; 180 181 #[cfg(unix)] 182 let terminate = async { 183 signal::unix::signal(signal::unix::SignalKind::terminate()) 184 .expect("failed to install signal handler") 185 .recv() 186 .await; 187 }; 188 189 #[cfg(not(unix))] 190 let terminate = std::future::pending::<()>(); 191 192 let manual = async { manual_shutdown.notified().await }; 193 194 tokio::select! { 195 _ = ctrl_c => {}, 196 _ = terminate => {}, 197 _ = manual => {} 198 } 199 } 200 201 #[macro_export] 202 macro_rules! dyn_event { 203 ($lvl:ident, $($arg:tt)+) => { 204 match $lvl { 205 ::tracing::Level::TRACE => ::tracing::trace!($($arg)+), 206 ::tracing::Level::DEBUG => ::tracing::debug!($($arg)+), 207 ::tracing::Level::INFO => ::tracing::info!($($arg)+), 208 ::tracing::Level::WARN => ::tracing::warn!($($arg)+), 209 ::tracing::Level::ERROR => ::tracing::error!($($arg)+), 210 } 211 }; 212 } 213 214 /** Taler API logger */ 215 async fn logger_middleware(request: Request, next: Next) -> Response { 216 let request_info = format!( 217 "{} {}", 218 request.method(), 219 request 220 .uri() 221 .path_and_query() 222 .map(|it| it.as_str()) 223 .unwrap_or_default() 224 ); 225 let now = Instant::now(); 226 let response = next.run(request).await; 227 let elapsed = now.elapsed(); 228 let status = response.status(); 229 let level = match status.as_u16() { 230 400..500 => Level::WARN, 231 500..600 => Level::ERROR, 232 _ => Level::INFO, 233 }; 234 235 if let Some(log) = response.extensions().get::<Box<str>>() { 236 dyn_event!(level, target: "api", 237 "{} {request_info} {}ms: {log}", 238 response.status(), 239 elapsed.as_millis() 240 ); 241 } else { 242 dyn_event!(level, target: "api", 243 "{} {request_info} {}ms", 244 response.status(), 245 elapsed.as_millis() 246 ); 247 } 248 response 249 }