server.rs (7198B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{fmt::Debug, io::Write, pin::Pin}; 18 19 use axum::{ 20 Router, 21 body::{Body, Bytes}, 22 extract::Query, 23 http::{ 24 HeaderMap, HeaderValue, Method, StatusCode, Uri, 25 header::{self, AsHeaderName, IntoHeaderName}, 26 }, 27 }; 28 use flate2::{Compression, write::ZlibEncoder}; 29 use http_body_util::BodyExt as _; 30 use serde::{Deserialize, Serialize, de::DeserializeOwned}; 31 use taler_common::{api_common::ErrorDetail, error_code::ErrorCode}; 32 use tower::ServiceExt as _; 33 use tracing::warn; 34 use url::Url; 35 36 pub trait TestServer { 37 fn method(&self, method: Method, path: &str) -> TestRequest; 38 39 fn get(&self, path: &str) -> TestRequest { 40 self.method(Method::GET, path) 41 } 42 43 fn post(&self, path: &str) -> TestRequest { 44 self.method(Method::POST, path) 45 } 46 47 fn delete(&self, path: &str) -> TestRequest { 48 self.method(Method::DELETE, path) 49 } 50 } 51 52 impl TestServer for Router { 53 fn method(&self, method: Method, path: &str) -> TestRequest { 54 let url = format!("https://example{path}"); 55 TestRequest { 56 router: self.clone(), 57 method, 58 url: url.parse().unwrap(), 59 body: None, 60 headers: HeaderMap::new(), 61 } 62 } 63 } 64 65 pub struct TestRequest { 66 router: Router, 67 method: Method, 68 url: Url, 69 body: Option<Vec<u8>>, 70 headers: HeaderMap, 71 } 72 73 impl TestRequest { 74 #[track_caller] 75 pub fn query<T: Serialize>(mut self, k: &str, v: T) -> Self { 76 let mut pairs = self.url.query_pairs_mut(); 77 let serializer = serde_urlencoded::Serializer::new(&mut pairs); 78 [(k, v)].serialize(serializer).unwrap(); 79 drop(pairs); 80 self 81 } 82 83 pub fn json<T: Serialize>(mut self, body: &T) -> Self { 84 assert!(self.body.is_none()); 85 let bytes = serde_json::to_vec(body).unwrap(); 86 self.body = Some(bytes); 87 self.headers.insert( 88 header::CONTENT_TYPE, 89 HeaderValue::from_static("application/json"), 90 ); 91 self 92 } 93 94 pub fn deflate(mut self) -> Self { 95 let body = self.body.unwrap(); 96 let mut encoder = ZlibEncoder::new(Vec::with_capacity(body.len() / 4), Compression::fast()); 97 encoder.write_all(&body).unwrap(); 98 let compressed = encoder.finish().unwrap(); 99 self.body = Some(compressed); 100 self.headers.insert( 101 header::CONTENT_ENCODING, 102 HeaderValue::from_static("deflate"), 103 ); 104 self 105 } 106 107 pub fn remove(mut self, k: impl AsHeaderName) -> Self { 108 self.headers.remove(k); 109 self 110 } 111 112 pub fn header<V>(mut self, k: impl IntoHeaderName, v: V) -> Self 113 where 114 V: TryInto<HeaderValue>, 115 V::Error: Debug, 116 { 117 self.headers.insert(k, v.try_into().unwrap()); 118 self 119 } 120 121 async fn send(self) -> TestResponse { 122 let TestRequest { 123 router, 124 method, 125 url: uri, 126 body, 127 headers, 128 } = self; 129 let uri = Uri::try_from(uri.as_str()).unwrap(); 130 let mut req = axum::http::request::Builder::new() 131 .method(&method) 132 .uri(&uri) 133 .body(body.map(Body::from).unwrap_or_else(Body::empty)) 134 .unwrap(); 135 *req.headers_mut() = headers; 136 let resp = router.oneshot(req).await.unwrap(); 137 let (parts, body) = resp.into_parts(); 138 let bytes = body.collect().await.unwrap(); 139 TestResponse { 140 bytes: bytes.to_bytes(), 141 method, 142 status: parts.status, 143 uri, 144 } 145 } 146 } 147 148 impl IntoFuture for TestRequest { 149 type Output = TestResponse; 150 type IntoFuture = Pin<Box<dyn Future<Output = Self::Output>>>; 151 152 fn into_future(self) -> Self::IntoFuture { 153 Box::pin(self.send()) 154 } 155 } 156 157 pub struct TestResponse { 158 bytes: Bytes, 159 method: Method, 160 uri: Uri, 161 pub status: StatusCode, 162 } 163 164 impl TestResponse { 165 #[track_caller] 166 pub fn is_implemented(&self) -> bool { 167 if self.status == StatusCode::NOT_IMPLEMENTED { 168 let err: ErrorDetail = self.json_parse(); 169 warn!( 170 "{} is not implemented: {}", 171 self.uri.path(), 172 err.hint.unwrap_or_default() 173 ); 174 false 175 } else { 176 true 177 } 178 } 179 180 #[track_caller] 181 pub fn json_parse<'de, T: Deserialize<'de>>(&'de self) -> T { 182 let TestResponse { 183 status, 184 bytes, 185 method, 186 uri, 187 } = self; 188 match serde_json::from_slice(bytes) { 189 Ok(body) => body, 190 Err(err) => panic!( 191 "{method} {uri} {status} invalid JSON body: {err}\n{}", 192 String::from_utf8_lossy(bytes) 193 ), 194 } 195 } 196 197 #[track_caller] 198 pub fn assert_status(&self, expected: StatusCode) { 199 let TestResponse { 200 status, 201 bytes, 202 method, 203 uri, 204 } = self; 205 if expected != *status { 206 if status.is_success() || bytes.is_empty() { 207 panic!("{method} {uri} expected {expected} got {status}"); 208 } else { 209 let err: ErrorDetail = self.json_parse(); 210 let description = err.hint.unwrap_or_default(); 211 panic!( 212 "{method} {uri} expected {expected} got {status}: {} {description}", 213 err.code 214 ); 215 } 216 } 217 } 218 219 #[track_caller] 220 pub fn assert_ok_json<'de, T: Deserialize<'de>>(&'de self) -> T { 221 self.assert_ok(); 222 self.json_parse() 223 } 224 225 #[track_caller] 226 pub fn assert_ok(&self) { 227 self.assert_status(StatusCode::OK); 228 } 229 230 #[track_caller] 231 pub fn assert_no_content(&self) { 232 self.assert_status(StatusCode::NO_CONTENT); 233 } 234 235 #[track_caller] 236 pub fn assert_error(&self, error_code: ErrorCode) { 237 let (status_code, _) = error_code.metadata(); 238 self.assert_error_status(error_code, StatusCode::from_u16(status_code).unwrap()); 239 } 240 241 #[track_caller] 242 pub fn assert_error_status(&self, error_code: ErrorCode, status: StatusCode) { 243 self.assert_status(status); 244 let err: ErrorDetail = self.json_parse(); 245 assert_eq!(error_code as u32, err.code); 246 } 247 248 #[track_caller] 249 pub fn query<T: DeserializeOwned>(&self) -> T { 250 Query::try_from_uri(&self.uri).unwrap().0 251 } 252 }