server.rs (8277B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{fmt::Debug, io::Write, pin::Pin}; 18 19 use axum::{ 20 Router, 21 body::{Body, Bytes}, 22 extract::Query, 23 http::{ 24 HeaderMap, HeaderValue, Method, StatusCode, Uri, 25 header::{self, AUTHORIZATION, AsHeaderName, IntoHeaderName}, 26 }, 27 }; 28 use flate2::{Compression, write::ZlibEncoder}; 29 use http_body_util::BodyExt as _; 30 use serde::{Deserialize, Serialize, de::DeserializeOwned}; 31 use taler_common::{api_common::ErrorDetail, encoding::base64, error_code::ErrorCode}; 32 use tower::ServiceExt as _; 33 use tracing::warn; 34 use url::Url; 35 36 pub trait TestServer { 37 fn request(&self, method: Method, path: &str) -> TestRequest; 38 39 fn get(&self, path: &str) -> TestRequest { 40 self.request(Method::GET, path) 41 } 42 43 fn post(&self, path: &str) -> TestRequest { 44 self.request(Method::POST, path) 45 } 46 47 fn delete(&self, path: &str) -> TestRequest { 48 self.request(Method::DELETE, path) 49 } 50 } 51 52 impl TestServer for Router { 53 fn request(&self, method: Method, path: &str) -> TestRequest { 54 let url = format!("https://example{path}"); 55 TestRequest { 56 router: self.clone(), 57 method, 58 url: url.parse().unwrap(), 59 body: None, 60 headers: HeaderMap::new(), 61 } 62 } 63 } 64 65 pub struct TestRequest { 66 router: Router, 67 method: Method, 68 pub url: Url, 69 body: Option<Bytes>, 70 headers: HeaderMap, 71 } 72 73 impl TestRequest { 74 #[track_caller] 75 pub fn query<T: Serialize>(mut self, k: &str, v: T) -> Self { 76 let mut pairs = self.url.query_pairs_mut(); 77 let serializer = serde_urlencoded::Serializer::new(&mut pairs); 78 [(k, v)].serialize(serializer).unwrap(); 79 drop(pairs); 80 self 81 } 82 83 pub fn json<T: Serialize>(mut self, body: T) -> Self { 84 assert!(self.body.is_none()); 85 let bytes = serde_json::to_vec(&body).unwrap(); 86 self.body = Some(bytes.into()); 87 self.headers.insert( 88 header::CONTENT_TYPE, 89 HeaderValue::from_static("application/json"), 90 ); 91 self 92 } 93 94 pub fn raw_json(mut self, raw: Bytes) -> Self { 95 assert!(self.body.is_none()); 96 self.body = Some(raw); 97 self.headers.insert( 98 header::CONTENT_TYPE, 99 HeaderValue::from_static("application/json"), 100 ); 101 self 102 } 103 104 pub fn deflate(mut self) -> Self { 105 let body = self.body.unwrap(); 106 let mut encoder = ZlibEncoder::new(Vec::with_capacity(body.len() / 4), Compression::fast()); 107 encoder.write_all(&body).unwrap(); 108 let compressed = encoder.finish().unwrap(); 109 self.body = Some(compressed.into()); 110 self.headers.insert( 111 header::CONTENT_ENCODING, 112 HeaderValue::from_static("deflate"), 113 ); 114 self 115 } 116 117 pub fn remove(mut self, k: impl AsHeaderName) -> Self { 118 self.headers.remove(k); 119 self 120 } 121 122 pub fn header<V>(mut self, k: impl IntoHeaderName, v: V) -> Self 123 where 124 V: TryInto<HeaderValue>, 125 V::Error: Debug, 126 { 127 self.headers.insert(k, v.try_into().unwrap()); 128 self 129 } 130 131 pub fn basic_auth(self, username: &str, password: &str) -> Self { 132 self.header( 133 AUTHORIZATION, 134 format!( 135 "Basic {}", 136 base64::fmt(format!("{username}:{password}").as_bytes()) 137 ), 138 ) 139 } 140 141 async fn send(self) -> TestResponse { 142 let Self { 143 router, 144 method, 145 url, 146 body: req_body, 147 headers, 148 } = self; 149 let uri = Uri::try_from(url.as_str()).unwrap(); 150 let mut req = axum::http::request::Builder::new() 151 .method(&method) 152 .uri(&uri) 153 .body(req_body.clone().map(Body::from).unwrap_or_else(Body::empty)) 154 .unwrap(); 155 *req.headers_mut() = headers; 156 let resp = router.clone().oneshot(req).await.unwrap(); 157 let (parts, body) = resp.into_parts(); 158 let bytes = body.collect().await.unwrap(); 159 TestResponse { 160 router, 161 req_body: req_body.unwrap_or_default(), 162 res_body: bytes.to_bytes(), 163 method, 164 status: parts.status, 165 uri, 166 } 167 } 168 } 169 170 impl IntoFuture for TestRequest { 171 type Output = TestResponse; 172 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>; 173 174 fn into_future(self) -> Self::IntoFuture { 175 Box::pin(self.send()) 176 } 177 } 178 179 #[must_use] 180 pub struct TestResponse { 181 pub router: Router, 182 pub req_body: Bytes, 183 res_body: Bytes, 184 pub method: Method, 185 pub uri: Uri, 186 pub status: StatusCode, 187 } 188 189 impl TestResponse { 190 #[track_caller] 191 pub fn is_implemented(&self) -> bool { 192 if self.status == StatusCode::NOT_IMPLEMENTED { 193 let err: ErrorDetail = self.json_parse(); 194 warn!( 195 "{} is not implemented: {}", 196 self.uri.path(), 197 err.hint.unwrap_or_default() 198 ); 199 false 200 } else { 201 true 202 } 203 } 204 205 #[track_caller] 206 pub fn json_parse<'de, T: Deserialize<'de>>(&'de self) -> T { 207 let Self { 208 status, 209 res_body: bytes, 210 method, 211 uri, 212 .. 213 } = self; 214 match serde_json::from_slice(bytes) { 215 Ok(body) => body, 216 Err(err) => panic!( 217 "{method} {uri} {status} invalid JSON body: {err}\n{}", 218 String::from_utf8_lossy(bytes) 219 ), 220 } 221 } 222 223 #[track_caller] 224 pub fn assert_status(&self, expected: StatusCode) { 225 let Self { 226 status, 227 res_body: bytes, 228 method, 229 uri, 230 .. 231 } = self; 232 if expected != *status { 233 if status.is_success() || bytes.is_empty() { 234 panic!("{method} {uri} expected {expected} got {status}"); 235 } else { 236 let err: ErrorDetail = self.json_parse(); 237 let description = err.hint.unwrap_or_default(); 238 panic!( 239 "{method} {uri} expected {expected} got {status}: {} {description}", 240 err.code 241 ); 242 } 243 } 244 } 245 246 #[track_caller] 247 pub fn assert_ok_json<'de, T: Deserialize<'de>>(&'de self) -> T { 248 self.assert_ok(); 249 self.json_parse() 250 } 251 252 #[track_caller] 253 pub fn assert_accepted_json<'de, T: Deserialize<'de>>(&'de self) -> T { 254 self.assert_accepted(); 255 self.json_parse() 256 } 257 258 #[track_caller] 259 pub fn assert_ok(&self) { 260 self.assert_status(StatusCode::OK); 261 } 262 263 #[track_caller] 264 pub fn assert_accepted(&self) { 265 self.assert_status(StatusCode::ACCEPTED); 266 } 267 268 #[track_caller] 269 pub fn assert_no_content(&self) { 270 self.assert_status(StatusCode::NO_CONTENT); 271 } 272 273 #[track_caller] 274 pub fn assert_error(&self, error_code: ErrorCode) { 275 let (status_code, _) = error_code.metadata(); 276 self.assert_error_status(error_code, StatusCode::from_u16(status_code).unwrap()); 277 } 278 279 #[track_caller] 280 pub fn assert_error_status(&self, error_code: ErrorCode, status: StatusCode) { 281 self.assert_status(status); 282 let err: ErrorDetail = self.json_parse(); 283 assert_eq!(error_code as u32, err.code); 284 } 285 286 #[track_caller] 287 pub fn query<T: DeserializeOwned>(&self) -> T { 288 Query::try_from_uri(&self.uri).unwrap().0 289 } 290 }