api.rs (8597B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{ 18 sync::{ 19 Arc, 20 atomic::{AtomicU32, Ordering}, 21 }, 22 time::Instant, 23 }; 24 25 use axum::{ 26 extract::{Request, State}, 27 middleware::{self, Next}, 28 response::Response, 29 }; 30 use compact_str::CompactString; 31 use rand::RngExt as _; 32 use revenue::Revenue; 33 use taler_common::{ 34 error_code::ErrorCode, 35 log::LOG_TASK_ID, 36 types::amount::{Amount, Currency}, 37 }; 38 use tokio::signal; 39 use tracing::{Level, debug, info}; 40 use wire::WireGateway; 41 42 use crate::{ 43 Listener, Serve, 44 api::prepared::PreparedTransfer, 45 auth::{AuthMethod, AuthMiddlewareState}, 46 error::{ApiResult, LoggedError, failure, failure_code}, 47 }; 48 49 pub mod prepared; 50 pub mod revenue; 51 pub mod wire; 52 53 pub use axum::Router; 54 55 pub trait Validation { 56 fn check(&self, currency: &Currency) -> ApiResult<()>; 57 } 58 59 fn check_currency(currency: &Currency, amount: &Amount) -> ApiResult<()> { 60 if &amount.currency != currency { 61 Err(failure( 62 ErrorCode::GENERIC_CURRENCY_MISMATCH, 63 format!( 64 "wrong currency expected {} got {}", 65 currency, amount.currency 66 ), 67 )) 68 } else { 69 Ok(()) 70 } 71 } 72 73 pub trait TalerApi: Send + Sync + 'static { 74 fn currency(&self) -> Currency; 75 fn implementation(&self) -> &'static str; 76 } 77 78 pub trait RouterUtils { 79 fn auth(self, auth: AuthMethod, realm: &str) -> Self; 80 } 81 82 impl<S: Send + Clone + Sync + 'static> RouterUtils for Router<S> { 83 fn auth(self, auth: AuthMethod, realm: &str) -> Self { 84 self.route_layer(middleware::from_fn_with_state( 85 Arc::new(AuthMiddlewareState::new(auth, realm)), 86 crate::auth::auth_middleware, 87 )) 88 } 89 } 90 91 pub trait TalerRouter { 92 fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self; 93 fn prepared_transfer<T: PreparedTransfer>(self, api: Arc<T>) -> Self; 94 fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self; 95 fn finalize(self) -> Self; 96 fn serve( 97 self, 98 serve: &Serve, 99 lifetime: Option<u32>, 100 ) -> impl std::future::Future<Output = std::io::Result<()>> + Send; 101 } 102 103 impl TalerRouter for Router { 104 fn wire_gateway<T: WireGateway>(self, api: Arc<T>, auth: AuthMethod) -> Self { 105 self.nest("/taler-wire-gateway", wire::router(api, auth)) 106 } 107 108 fn prepared_transfer<T: PreparedTransfer>(self, api: Arc<T>) -> Self { 109 self.nest("/taler-prepared-transfer", prepared::router(api)) 110 } 111 112 fn revenue<T: Revenue>(self, api: Arc<T>, auth: AuthMethod) -> Self { 113 self.nest("/taler-revenue", revenue::router(api, auth)) 114 } 115 116 fn finalize(self) -> Router { 117 self.method_not_allowed_fallback(async || failure_code(ErrorCode::GENERIC_METHOD_INVALID)) 118 .fallback(async || failure_code(ErrorCode::GENERIC_ENDPOINT_UNKNOWN)) 119 .layer(middleware::from_fn(logger_middleware)) 120 } 121 122 async fn serve(mut self, serve: &Serve, lifetime: Option<u32>) -> std::io::Result<()> { 123 let listener = serve.resolve()?; 124 125 let notify = Arc::new(tokio::sync::Notify::new()); 126 if let Some(lifetime) = lifetime { 127 self = self.layer(middleware::from_fn_with_state( 128 Arc::new(LifetimeMiddlewareState { 129 notify: notify.clone(), 130 lifetime: AtomicU32::new(lifetime), 131 }), 132 lifetime_middleware, 133 )) 134 } 135 let router = self.finalize(); 136 let signal = shutdown_signal(notify); 137 match listener { 138 Listener::Tcp(tcp_listener) => { 139 axum::serve(tcp_listener, router) 140 .with_graceful_shutdown(signal) 141 .await?; 142 } 143 Listener::Unix(unix_listener) => { 144 axum::serve(unix_listener, router) 145 .with_graceful_shutdown(signal) 146 .await?; 147 } 148 } 149 150 info!(target: "api", "Server stopped"); 151 Ok(()) 152 } 153 } 154 155 struct LifetimeMiddlewareState { 156 lifetime: AtomicU32, 157 notify: Arc<tokio::sync::Notify>, 158 } 159 160 async fn lifetime_middleware( 161 State(state): State<Arc<LifetimeMiddlewareState>>, 162 request: Request, 163 next: Next, 164 ) -> Response { 165 let mut current = state.lifetime.load(Ordering::Relaxed); 166 while current != 0 { 167 match state.lifetime.compare_exchange_weak( 168 current, 169 current - 1, 170 Ordering::Relaxed, 171 Ordering::Relaxed, 172 ) { 173 Ok(_) => break, 174 Err(new) => current = new, 175 } 176 } 177 if current == 0 { 178 state.notify.notify_one(); 179 } 180 next.run(request).await 181 } 182 183 /** Wait for manual shutdown or system signal shutdown */ 184 async fn shutdown_signal(manual_shutdown: Arc<tokio::sync::Notify>) { 185 let ctrl_c = async { 186 signal::ctrl_c() 187 .await 188 .expect("failed to install Ctrl+C handler"); 189 }; 190 191 #[cfg(unix)] 192 let terminate = async { 193 signal::unix::signal(signal::unix::SignalKind::terminate()) 194 .expect("failed to install signal handler") 195 .recv() 196 .await; 197 }; 198 199 #[cfg(not(unix))] 200 let terminate = std::future::pending::<()>(); 201 202 let manual = async { manual_shutdown.notified().await }; 203 204 tokio::select! { 205 _ = ctrl_c => {}, 206 _ = terminate => {}, 207 _ = manual => {} 208 } 209 } 210 211 #[macro_export] 212 macro_rules! dyn_event { 213 ($lvl:ident, $($arg:tt)+) => { 214 match $lvl { 215 ::tracing::Level::TRACE => ::tracing::trace!($($arg)+), 216 ::tracing::Level::DEBUG => ::tracing::debug!($($arg)+), 217 ::tracing::Level::INFO => ::tracing::info!($($arg)+), 218 ::tracing::Level::WARN => ::tracing::warn!($($arg)+), 219 ::tracing::Level::ERROR => ::tracing::error!($($arg)+), 220 } 221 }; 222 } 223 224 /** Taler API logger */ 225 async fn logger_middleware(request: Request, next: Next) -> Response { 226 let now = Instant::now(); 227 let request_id: compact_str::CompactString = { 228 let charset = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; 229 let mut rng = rand::rng(); 230 231 let mut ansi = [0u8; 10]; 232 for c in ansi.iter_mut() { 233 let idx = rng.random_range(0..charset.len()); 234 *c = charset[idx]; 235 } 236 unsafe { CompactString::from_utf8_unchecked(ansi) } 237 }; 238 let method = request.method().clone(); 239 let path_and_query = request.uri().path_and_query().cloned(); 240 let path_and_query = path_and_query 241 .as_ref() 242 .map(|it| it.as_str()) 243 .unwrap_or_default(); 244 LOG_TASK_ID 245 .scope(request_id, async { 246 debug!(target: "api", "{method} {path_and_query}"); 247 let response = next.run(request).await; 248 let elapsed = now.elapsed(); 249 let status = response.status(); 250 let level = match status.as_u16() { 251 400..500 => Level::WARN, 252 500..600 => Level::ERROR, 253 _ => Level::INFO, 254 }; 255 256 if let Some(log) = response.extensions().get::<LoggedError>() { 257 let LoggedError { code, info } = log; 258 dyn_event!(level, target: "api", 259 "{} {method} {path_and_query} {}ms: {code}{}", 260 response.status(), 261 elapsed.as_millis(), 262 std::fmt::from_fn(|f|{ 263 if let Some(info) = info { 264 write!(f, " {info}")?; 265 } 266 Ok(()) 267 }) 268 ); 269 } else { 270 dyn_event!(level, target: "api", 271 "{} {method} {path_and_query} {}ms", 272 response.status(), 273 elapsed.as_millis() 274 ); 275 } 276 response 277 }) 278 .await 279 }