taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

server.rs (8277B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{fmt::Debug, io::Write, pin::Pin};
     18 
     19 use axum::{
     20     Router,
     21     body::{Body, Bytes},
     22     extract::Query,
     23     http::{
     24         HeaderMap, HeaderValue, Method, StatusCode, Uri,
     25         header::{self, AUTHORIZATION, AsHeaderName, IntoHeaderName},
     26     },
     27 };
     28 use flate2::{Compression, write::ZlibEncoder};
     29 use http_body_util::BodyExt as _;
     30 use serde::{Deserialize, Serialize, de::DeserializeOwned};
     31 use taler_common::{api_common::ErrorDetail, encoding::base64, error_code::ErrorCode};
     32 use tower::ServiceExt as _;
     33 use tracing::warn;
     34 use url::Url;
     35 
     36 pub trait TestServer {
     37     fn request(&self, method: Method, path: &str) -> TestRequest;
     38 
     39     fn get(&self, path: &str) -> TestRequest {
     40         self.request(Method::GET, path)
     41     }
     42 
     43     fn post(&self, path: &str) -> TestRequest {
     44         self.request(Method::POST, path)
     45     }
     46 
     47     fn delete(&self, path: &str) -> TestRequest {
     48         self.request(Method::DELETE, path)
     49     }
     50 }
     51 
     52 impl TestServer for Router {
     53     fn request(&self, method: Method, path: &str) -> TestRequest {
     54         let url = format!("https://example{path}");
     55         TestRequest {
     56             router: self.clone(),
     57             method,
     58             url: url.parse().unwrap(),
     59             body: None,
     60             headers: HeaderMap::new(),
     61         }
     62     }
     63 }
     64 
     65 pub struct TestRequest {
     66     router: Router,
     67     method: Method,
     68     pub url: Url,
     69     body: Option<Bytes>,
     70     headers: HeaderMap,
     71 }
     72 
     73 impl TestRequest {
     74     #[track_caller]
     75     pub fn query<T: Serialize>(mut self, k: &str, v: T) -> Self {
     76         let mut pairs = self.url.query_pairs_mut();
     77         let serializer = serde_urlencoded::Serializer::new(&mut pairs);
     78         [(k, v)].serialize(serializer).unwrap();
     79         drop(pairs);
     80         self
     81     }
     82 
     83     pub fn json<T: Serialize>(mut self, body: T) -> Self {
     84         assert!(self.body.is_none());
     85         let bytes = serde_json::to_vec(&body).unwrap();
     86         self.body = Some(bytes.into());
     87         self.headers.insert(
     88             header::CONTENT_TYPE,
     89             HeaderValue::from_static("application/json"),
     90         );
     91         self
     92     }
     93 
     94     pub fn raw_json(mut self, raw: Bytes) -> Self {
     95         assert!(self.body.is_none());
     96         self.body = Some(raw);
     97         self.headers.insert(
     98             header::CONTENT_TYPE,
     99             HeaderValue::from_static("application/json"),
    100         );
    101         self
    102     }
    103 
    104     pub fn deflate(mut self) -> Self {
    105         let body = self.body.unwrap();
    106         let mut encoder = ZlibEncoder::new(Vec::with_capacity(body.len() / 4), Compression::fast());
    107         encoder.write_all(&body).unwrap();
    108         let compressed = encoder.finish().unwrap();
    109         self.body = Some(compressed.into());
    110         self.headers.insert(
    111             header::CONTENT_ENCODING,
    112             HeaderValue::from_static("deflate"),
    113         );
    114         self
    115     }
    116 
    117     pub fn remove(mut self, k: impl AsHeaderName) -> Self {
    118         self.headers.remove(k);
    119         self
    120     }
    121 
    122     pub fn header<V>(mut self, k: impl IntoHeaderName, v: V) -> Self
    123     where
    124         V: TryInto<HeaderValue>,
    125         V::Error: Debug,
    126     {
    127         self.headers.insert(k, v.try_into().unwrap());
    128         self
    129     }
    130 
    131     pub fn basic_auth(self, username: &str, password: &str) -> Self {
    132         self.header(
    133             AUTHORIZATION,
    134             format!(
    135                 "Basic {}",
    136                 base64::fmt(format!("{username}:{password}").as_bytes())
    137             ),
    138         )
    139     }
    140 
    141     async fn send(self) -> TestResponse {
    142         let Self {
    143             router,
    144             method,
    145             url,
    146             body: req_body,
    147             headers,
    148         } = self;
    149         let uri = Uri::try_from(url.as_str()).unwrap();
    150         let mut req = axum::http::request::Builder::new()
    151             .method(&method)
    152             .uri(&uri)
    153             .body(req_body.clone().map(Body::from).unwrap_or_else(Body::empty))
    154             .unwrap();
    155         *req.headers_mut() = headers;
    156         let resp = router.clone().oneshot(req).await.unwrap();
    157         let (parts, body) = resp.into_parts();
    158         let bytes = body.collect().await.unwrap();
    159         TestResponse {
    160             router,
    161             req_body: req_body.unwrap_or_default(),
    162             res_body: bytes.to_bytes(),
    163             method,
    164             status: parts.status,
    165             uri,
    166         }
    167     }
    168 }
    169 
    170 impl IntoFuture for TestRequest {
    171     type Output = TestResponse;
    172     type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
    173 
    174     fn into_future(self) -> Self::IntoFuture {
    175         Box::pin(self.send())
    176     }
    177 }
    178 
    179 #[must_use]
    180 pub struct TestResponse {
    181     pub router: Router,
    182     pub req_body: Bytes,
    183     res_body: Bytes,
    184     pub method: Method,
    185     pub uri: Uri,
    186     pub status: StatusCode,
    187 }
    188 
    189 impl TestResponse {
    190     #[track_caller]
    191     pub fn is_implemented(&self) -> bool {
    192         if self.status == StatusCode::NOT_IMPLEMENTED {
    193             let err: ErrorDetail = self.json_parse();
    194             warn!(
    195                 "{} is not implemented: {}",
    196                 self.uri.path(),
    197                 err.hint.unwrap_or_default()
    198             );
    199             false
    200         } else {
    201             true
    202         }
    203     }
    204 
    205     #[track_caller]
    206     pub fn json_parse<'de, T: Deserialize<'de>>(&'de self) -> T {
    207         let Self {
    208             status,
    209             res_body: bytes,
    210             method,
    211             uri,
    212             ..
    213         } = self;
    214         match serde_json::from_slice(bytes) {
    215             Ok(body) => body,
    216             Err(err) => panic!(
    217                 "{method} {uri} {status} invalid JSON body: {err}\n{}",
    218                 String::from_utf8_lossy(bytes)
    219             ),
    220         }
    221     }
    222 
    223     #[track_caller]
    224     pub fn assert_status(&self, expected: StatusCode) {
    225         let Self {
    226             status,
    227             res_body: bytes,
    228             method,
    229             uri,
    230             ..
    231         } = self;
    232         if expected != *status {
    233             if status.is_success() || bytes.is_empty() {
    234                 panic!("{method} {uri} expected {expected} got {status}");
    235             } else {
    236                 let err: ErrorDetail = self.json_parse();
    237                 let description = err.hint.unwrap_or_default();
    238                 panic!(
    239                     "{method} {uri} expected {expected} got {status}: {} {description}",
    240                     err.code
    241                 );
    242             }
    243         }
    244     }
    245 
    246     #[track_caller]
    247     pub fn assert_ok_json<'de, T: Deserialize<'de>>(&'de self) -> T {
    248         self.assert_ok();
    249         self.json_parse()
    250     }
    251 
    252     #[track_caller]
    253     pub fn assert_accepted_json<'de, T: Deserialize<'de>>(&'de self) -> T {
    254         self.assert_accepted();
    255         self.json_parse()
    256     }
    257 
    258     #[track_caller]
    259     pub fn assert_ok(&self) {
    260         self.assert_status(StatusCode::OK);
    261     }
    262 
    263     #[track_caller]
    264     pub fn assert_accepted(&self) {
    265         self.assert_status(StatusCode::ACCEPTED);
    266     }
    267 
    268     #[track_caller]
    269     pub fn assert_no_content(&self) {
    270         self.assert_status(StatusCode::NO_CONTENT);
    271     }
    272 
    273     #[track_caller]
    274     pub fn assert_error(&self, error_code: ErrorCode) {
    275         let (status_code, _) = error_code.metadata();
    276         self.assert_error_status(error_code, StatusCode::from_u16(status_code).unwrap());
    277     }
    278 
    279     #[track_caller]
    280     pub fn assert_error_status(&self, error_code: ErrorCode, status: StatusCode) {
    281         self.assert_status(status);
    282         let err: ErrorDetail = self.json_parse();
    283         assert_eq!(error_code as u32, err.code);
    284     }
    285 
    286     #[track_caller]
    287     pub fn query<T: DeserializeOwned>(&self) -> T {
    288         Query::try_from_uri(&self.uri).unwrap().0
    289     }
    290 }