commit c052c27629a09fbe2ba15744d43fdbf60e42157b
parent fef23a786747b8c19de617a3b4538913edb9064a
Author: Antoine A <>
Date: Wed, 6 May 2026 10:08:27 +0200
common: many improvements
Diffstat:
14 files changed, 249 insertions(+), 153 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
@@ -57,3 +57,4 @@ rand = { version = "0.10" }
regex = { version = "1" }
rustls = "0.23"
http = "1.4"
+
diff --git a/common/taler-api/src/api/revenue.rs b/common/taler-api/src/api/revenue.rs
@@ -24,10 +24,7 @@ use taler_common::{
use super::TalerApi;
use crate::{
- api::RouterUtils as _,
- auth::AuthMethod,
- constants::{MAX_PAGE_SIZE, MAX_TIMEOUT_MS, REVENUE_API_VERSION},
- error::ApiResult,
+ api::RouterUtils as _, auth::AuthMethod, constants::REVENUE_API_VERSION, error::ApiResult,
extract::Query,
};
@@ -44,7 +41,7 @@ pub fn router<I: Revenue>(state: Arc<I>, auth: AuthMethod) -> Router {
"/history",
get(
async |State(state): State<Arc<I>>, Query(params): Query<HistoryParams>| {
- let params = params.check(MAX_PAGE_SIZE, MAX_TIMEOUT_MS)?;
+ let params = params.check()?;
let history = state.history(params).await?;
ApiResult::Ok(if history.incoming_transactions.is_empty() {
StatusCode::NO_CONTENT.into_response()
diff --git a/common/taler-api/src/api/wire.rs b/common/taler-api/src/api/wire.rs
@@ -38,7 +38,7 @@ use super::TalerApi;
use crate::{
api::RouterUtils as _,
auth::AuthMethod,
- constants::{MAX_PAGE_SIZE, MAX_TIMEOUT_MS, WIRE_GATEWAY_API_VERSION},
+ constants::WIRE_GATEWAY_API_VERSION,
error::{ApiResult, failure, failure_code, failure_status},
extract::{Path, Query, Req},
};
@@ -123,7 +123,7 @@ pub fn router<I: WireGateway>(state: Arc<I>, auth: AuthMethod) -> Router {
"/transfers",
get(
async |State(state): State<Arc<I>>, Query(params): Query<TransferParams>| {
- let page = params.pagination.check(MAX_PAGE_SIZE)?;
+ let page = params.pagination.check()?;
let list = state.transfer_page(page, params.status).await?;
ApiResult::Ok(if list.transfers.is_empty() {
StatusCode::NO_CONTENT.into_response()
@@ -149,7 +149,7 @@ pub fn router<I: WireGateway>(state: Arc<I>, auth: AuthMethod) -> Router {
"/history/incoming",
get(
async |State(state): State<Arc<I>>, Query(params): Query<HistoryParams>| {
- let params = params.check(MAX_PAGE_SIZE, MAX_TIMEOUT_MS)?;
+ let params = params.check()?;
let history = state.incoming_history(params).await?;
ApiResult::Ok(if history.incoming_transactions.is_empty() {
StatusCode::NO_CONTENT.into_response()
@@ -163,7 +163,7 @@ pub fn router<I: WireGateway>(state: Arc<I>, auth: AuthMethod) -> Router {
"/history/outgoing",
get(
async |State(state): State<Arc<I>>, Query(params): Query<HistoryParams>| {
- let params = params.check(MAX_PAGE_SIZE, MAX_TIMEOUT_MS)?;
+ let params = params.check()?;
let history = state.outgoing_history(params).await?;
ApiResult::Ok(if history.outgoing_transactions.is_empty() {
StatusCode::NO_CONTENT.into_response()
diff --git a/common/taler-api/src/constants.rs b/common/taler-api/src/constants.rs
@@ -17,6 +17,4 @@
pub const WIRE_GATEWAY_API_VERSION: &str = "5:0:0";
pub const PREPARED_TRANSFER_API_VERSION: &str = "1:0:0";
pub const REVENUE_API_VERSION: &str = "1:0:0";
-pub const MAX_PAGE_SIZE: i64 = 1024;
-pub const MAX_TIMEOUT_MS: u64 = 60 * 60 * 10; // 1H
pub const MAX_BODY_LENGTH: usize = 4 * 1024; // 4kB
diff --git a/common/taler-api/src/db.rs b/common/taler-api/src/db.rs
@@ -22,8 +22,10 @@ use jiff::{
tz::TimeZone,
};
use sqlx::{
- Decode, Error, PgPool, Postgres, QueryBuilder, Row, Type, error::BoxDynError, postgres::PgRow,
- query::Query,
+ Decode, Error, PgPool, Postgres, QueryBuilder, Row, Type,
+ error::BoxDynError,
+ postgres::PgRow,
+ query::{Query, QueryScalar},
};
use taler_common::{
api_common::SafeU64,
@@ -32,6 +34,7 @@ use taler_common::{
amount::{Amount, Currency, Decimal},
iban::IBAN,
payto::PaytoURI,
+ timestamp::TalerTimestamp,
utils::date_to_utc_ts,
},
};
@@ -192,6 +195,18 @@ impl<'q> BindHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Argume
}
}
+impl<'q, T> BindHelper
+ for QueryScalar<'q, Postgres, T, <Postgres as sqlx::Database>::Arguments<'q>>
+{
+ fn bind_timestamp(self, timestamp: &Timestamp) -> Self {
+ self.bind(timestamp.as_microsecond())
+ }
+
+ fn bind_date(self, date: &Date) -> Self {
+ self.bind_timestamp(&date_to_utc_ts(date))
+ }
+}
+
/* ----- Get ----- */
pub trait TypeHelper {
@@ -254,6 +269,19 @@ pub trait TypeHelper {
}
})
}
+ fn try_get_taler_timestamp<I: sqlx::ColumnIndex<Self>>(
+ &self,
+ index: I,
+ ) -> sqlx::Result<TalerTimestamp> {
+ self.try_get_map(index, |micros: Option<i64>| match micros {
+ Some(micros) => Ok::<_, String>(TalerTimestamp::Timestamp(
+ jiff::Timestamp::from_microsecond(micros).map_err(|e| {
+ format!("expected timestamp micros got overflowing {micros}: {e}")
+ })?,
+ )),
+ None => Ok(TalerTimestamp::Never),
+ })
+ }
fn try_get_date<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Date> {
let timestamp = self.try_get_timestamp(index)?;
let zoned = timestamp.to_zoned(TimeZone::UTC);
diff --git a/common/taler-api/src/error.rs b/common/taler-api/src/error.rs
@@ -221,3 +221,11 @@ pub fn not_implemented(hint: impl Display) -> ApiError {
pub fn unauthorized(hint: impl Display) -> ApiError {
ApiError::new(ErrorCode::GENERIC_UNAUTHORIZED).with_hint(hint)
}
+
+pub fn forbidden(hint: impl Display) -> ApiError {
+ ApiError::new(ErrorCode::GENERIC_FORBIDDEN).with_hint(hint)
+}
+
+pub fn bad_request(hint: impl Display) -> ApiError {
+ ApiError::new(ErrorCode::GENERIC_JSON_INVALID).with_hint(hint)
+}
diff --git a/common/taler-api/src/extract.rs b/common/taler-api/src/extract.rs
@@ -15,9 +15,9 @@
*/
use axum::{
- body::Bytes,
+ body::{Body, Bytes},
extract::{FromRequest, FromRequestParts, Request},
- http::{StatusCode, header, request::Parts},
+ http::{HeaderMap, StatusCode, header, request::Parts},
};
use http_body_util::BodyExt as _;
use serde::de::DeserializeOwned;
@@ -27,9 +27,106 @@ use zlib_rs::{InflateConfig, ReturnCode};
use crate::{
constants::MAX_BODY_LENGTH,
- error::{ApiError, failure, failure_status},
+ error::{ApiError, ApiResult, failure, failure_status},
};
+pub async fn decompressed_strict_body(headers: &HeaderMap, body: Body) -> ApiResult<Bytes> {
+ // Check content type
+ match headers.get(header::CONTENT_TYPE) {
+ Some(header) => {
+ if !header.as_bytes().starts_with(b"application/json") {
+ return Err(failure_status(
+ ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
+ "Bad Content-Type header",
+ StatusCode::UNSUPPORTED_MEDIA_TYPE,
+ ));
+ }
+ }
+ None => {
+ return Err(failure_status(
+ ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
+ "Missing Content-Type header",
+ StatusCode::UNSUPPORTED_MEDIA_TYPE,
+ ));
+ }
+ }
+
+ // Check content length if present and well formed
+ if let Some(length) = headers
+ .get(header::CONTENT_LENGTH)
+ .and_then(|it| it.to_str().ok())
+ .and_then(|it| it.parse::<usize>().ok())
+ && length > MAX_BODY_LENGTH
+ {
+ return Err(failure(
+ ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
+ format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"),
+ ));
+ }
+
+ // Check compression
+ let compressed = if let Some(encoding) = headers.get(header::CONTENT_ENCODING) {
+ if encoding == "deflate" {
+ true
+ } else {
+ return Err(failure_status(
+ ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
+ format!(
+ "Unsupported encoding '{}'",
+ String::from_utf8_lossy(encoding.as_bytes())
+ ),
+ StatusCode::UNSUPPORTED_MEDIA_TYPE,
+ ));
+ }
+ } else {
+ false
+ };
+
+ // Buffer body
+ let body = http_body_util::Limited::new(body, MAX_BODY_LENGTH);
+ let bytes = match body.collect().await {
+ Ok(chunks) => chunks.to_bytes(),
+ Err(it) => match it.downcast::<http_body_util::LengthLimitError>() {
+ Ok(_) => {
+ return Err(failure(
+ ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
+ format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"),
+ ));
+ }
+ Err(err) => {
+ return Err(failure(
+ ErrorCode::GENERIC_UNEXPECTED_REQUEST_ERROR,
+ format!("Failed to read body: {err}"),
+ ));
+ }
+ },
+ };
+
+ let bytes = if compressed {
+ let mut buf = [0; MAX_BODY_LENGTH];
+ let (decompressed, code) =
+ zlib_rs::decompress_slice(&mut buf, &bytes, InflateConfig::default());
+ match code {
+ ReturnCode::Ok => Bytes::copy_from_slice(decompressed),
+ ReturnCode::BufError => {
+ return Err(failure(
+ ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
+ format!("Decompressed body is suspiciously big > {MAX_BODY_LENGTH}B"),
+ ));
+ }
+ _ => {
+ return Err(failure(
+ ErrorCode::GENERIC_COMPRESSION_INVALID,
+ "Failed to decompress body: invalid compression",
+ ));
+ }
+ }
+ } else {
+ bytes
+ };
+ Ok(bytes)
+}
+
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Req<T>(pub T);
@@ -42,110 +139,24 @@ where
type Rejection = ApiError;
async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
- // Check content type
- match req.headers().get(header::CONTENT_TYPE) {
- Some(header) => {
- if !header.as_bytes().starts_with(b"application/json") {
- return Err(failure_status(
- ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
- "Bad Content-Type header",
- StatusCode::UNSUPPORTED_MEDIA_TYPE,
- ));
- }
- }
- None => {
- return Err(failure_status(
- ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
- "Missing Content-Type header",
- StatusCode::UNSUPPORTED_MEDIA_TYPE,
- ));
- }
- }
+ let (parts, body) = req.into_parts();
+ let bytes = decompressed_strict_body(&parts.headers, body).await?;
+ Self::try_from(&bytes)
+ }
+}
- // Check content length if present and wellformed
- if let Some(length) = req
- .headers()
- .get(header::CONTENT_LENGTH)
- .and_then(|it| it.to_str().ok())
- .and_then(|it| it.parse::<usize>().ok())
- && length > MAX_BODY_LENGTH
- {
- return Err(failure(
- ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
- format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"),
- ));
- }
+impl<T: DeserializeOwned> TryFrom<&Bytes> for Req<T> {
+ type Error = ApiError;
- // Check compression
- let compressed = if let Some(encoding) = req.headers().get(header::CONTENT_ENCODING) {
- if encoding == "deflate" {
- true
- } else {
- return Err(failure_status(
- ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
- format!(
- "Unsupported encoding '{}'",
- String::from_utf8_lossy(encoding.as_bytes())
- ),
- StatusCode::UNSUPPORTED_MEDIA_TYPE,
- ));
- }
- } else {
- false
- };
-
- // Buffer body
- let (_, body) = req.into_parts();
- let body = http_body_util::Limited::new(body, MAX_BODY_LENGTH);
- let bytes = match body.collect().await {
- Ok(chunks) => chunks.to_bytes(),
- Err(it) => match it.downcast::<http_body_util::LengthLimitError>() {
- Ok(_) => {
- return Err(failure(
- ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
- format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"),
- ));
- }
- Err(err) => {
- return Err(failure(
- ErrorCode::GENERIC_UNEXPECTED_REQUEST_ERROR,
- format!("Failed to read body: {err}"),
- ));
- }
- },
- };
-
- let bytes = if compressed {
- let mut buf = [0; MAX_BODY_LENGTH];
- let (decompressed, code) =
- zlib_rs::decompress_slice(&mut buf, &bytes, InflateConfig::default());
- dbg!(code);
- match code {
- ReturnCode::Ok => Bytes::copy_from_slice(decompressed),
- ReturnCode::BufError => {
- return Err(failure(
- ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
- format!("Decompressed body is suspiciously big > {MAX_BODY_LENGTH}B"),
- ));
- }
- _ => {
- return Err(failure(
- ErrorCode::GENERIC_COMPRESSION_INVALID,
- "Failed to decompress body: invalid compression",
- ));
- }
- }
- } else {
- bytes
- };
- let mut de = serde_json::de::Deserializer::from_slice(&bytes);
+ fn try_from(value: &Bytes) -> Result<Self, Self::Error> {
+ let mut de = serde_json::de::Deserializer::from_slice(value);
let parsed = serde_path_to_error::deserialize(&mut de)?;
Ok(Req(parsed))
}
}
#[derive(Debug, Clone, Copy, Default)]
-pub struct Path<T>(pub T);
+pub struct Path<T: DeserializeOwned + Send>(pub T);
impl<T: serde::de::DeserializeOwned + Send, S: Sync + Send> FromRequestParts<S> for Path<T> {
type Rejection = ApiError;
diff --git a/common/taler-api/src/subject.rs b/common/taler-api/src/subject.rs
@@ -121,6 +121,8 @@ pub enum IncomingSubjectResult {
pub enum IncomingSubjectErr {
#[error("found multiple public keys")]
Ambiguous,
+ #[error("missing reserve public key")]
+ Missing,
}
#[derive(Debug, thiserror::Error)]
@@ -182,9 +184,7 @@ pub fn fmt_in_subject(ty: IncomingType, key: &EddsaPublicKey) -> String {
* To parse them while ignoring user errors, we reconstruct valid keys from key
* parts, resolving ambiguities where possible.
**/
-pub fn parse_incoming_unstructured(
- subject: &str,
-) -> Result<Option<IncomingSubject>, IncomingSubjectErr> {
+pub fn parse_incoming_unstructured(subject: &str) -> Result<IncomingSubject, IncomingSubjectErr> {
// We expect subject to be less than 65KB
assert!(subject.len() <= u16::MAX as usize);
@@ -287,7 +287,11 @@ pub fn parse_incoming_unstructured(
}
}
- Ok(best.map(|it| it.subject))
+ if let Some(it) = best {
+ Ok(it.subject)
+ } else {
+ Err(IncomingSubjectErr::Missing)
+ }
}
// Modulo 10 Recursive
@@ -395,11 +399,11 @@ mod test {
let other_mixed =
&format!("{prefix}TEGY6d9mh9pgwvwpgs0z0095z854xegfy7jj202yd0esp8p0za60");
let key = EddsaPublicKey::from_str(key).unwrap();
- let result = Ok(Some(match ty {
+ let result = Ok(match ty {
IncomingType::reserve => IncomingSubject::Reserve(key),
IncomingType::kyc => IncomingSubject::Kyc(key),
IncomingType::map => IncomingSubject::Map(key),
- }));
+ });
// Check succeed if standard or mixed
for case in [standard, mixed] {
@@ -480,7 +484,10 @@ mod test {
&standard[1..], // Check fail if missing char
//"2MZT6RS3RVB3B0E2RDMYW0YRA3Y0VPHYV0CYDE6XBB0YMPFXCEG0", // Check fail if not a valid key
] {
- assert_eq!(parse_incoming_unstructured(case), Ok(None));
+ assert_eq!(
+ parse_incoming_unstructured(case),
+ Err(IncomingSubjectErr::Missing)
+ );
}
if ty == IncomingType::kyc || ty == IncomingType::map {
@@ -515,9 +522,9 @@ mod test {
),
] {
assert_eq!(
- Ok(Some(IncomingSubject::Reserve(
+ Ok(IncomingSubject::Reserve(
EddsaPublicKey::from_str(key).unwrap(),
- ))),
+ )),
parse_incoming_unstructured(subject)
)
}
@@ -527,9 +534,7 @@ mod test {
"JW398X85FWPKKMS0EYB6TQ1799RMY5DDXTZFPW4YC3WJ2DWSJT70",
)] {
assert_eq!(
- Ok(Some(IncomingSubject::Kyc(
- EddsaPublicKey::from_str(key).unwrap(),
- ))),
+ Ok(IncomingSubject::Kyc(EddsaPublicKey::from_str(key).unwrap(),)),
parse_incoming_unstructured(subject)
)
}
diff --git a/common/taler-common/src/api_params.rs b/common/taler-common/src/api_params.rs
@@ -1,6 +1,6 @@
/*
This file is part of TALER
- Copyright (C) 2024-2025 Taler Systems SA
+ Copyright (C) 2024, 2025, 2026 Taler Systems SA
TALER is free software; you can redistribute it and/or modify it under the
terms of the GNU Affero General Public License as published by the Free Software
@@ -43,7 +43,13 @@ pub struct PageParams {
}
impl PageParams {
- pub fn check(self, max_page_size: i64) -> Result<Page, ParamsErr> {
+ const MAX_PAGE_SIZE: i64 = 1024;
+
+ pub fn check(self) -> Result<Page, ParamsErr> {
+ Self::check_custom(self, Self::MAX_PAGE_SIZE)
+ }
+
+ pub fn check_custom(self, max_page_size: i64) -> Result<Page, ParamsErr> {
let limit = self.limit.unwrap_or(-20);
if limit == 0 {
return Err(param_err("limit", format!("must be non-zero got {limit}")));
@@ -100,10 +106,20 @@ pub struct HistoryParams {
}
impl HistoryParams {
- pub fn check(self, max_page_size: i64, max_timeout_ms: u64) -> Result<History, ParamsErr> {
+ pub const MAX_TIMEOUT_MS: u64 = 60 * 60 * 10; // 36_000 ms = 36s — NOTE(review): the old "1H" label was wrong for this value; 1 hour in ms would be 60 * 60 * 1000. Confirm which was intended.
+
+ pub fn check(self) -> Result<History, ParamsErr> {
+ Self::check_custom(self, PageParams::MAX_PAGE_SIZE, Self::MAX_TIMEOUT_MS)
+ }
+
+ pub fn check_custom(
+ self,
+ max_page_size: i64,
+ max_timeout_ms: u64,
+ ) -> Result<History, ParamsErr> {
let timeout_ms = self.timeout_ms.map(|it| it.min(max_timeout_ms));
Ok(History {
- page: self.pagination.check(max_page_size)?,
+ page: self.pagination.check_custom(max_page_size)?,
timeout_ms,
})
}
diff --git a/common/taler-common/src/types/base32.rs b/common/taler-common/src/types/base32.rs
@@ -16,6 +16,7 @@
use std::{borrow::Cow, fmt::Display, ops::Deref, str::FromStr};
+use rand::{TryRng, rngs::SysRng};
use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error};
use crate::encoding::base32::{Base32Error, decode_static, encode_static, encoded_buf_len};
@@ -27,6 +28,12 @@ impl<const L: usize> Base32<L> {
pub fn rand() -> Self {
Self(rand::random())
}
+
+ pub fn secure_rand() -> Self {
+ let mut array = [0; L];
+ SysRng.try_fill_bytes(&mut array).unwrap();
+ Self(array)
+ }
}
impl<const L: usize> From<[u8; L]> for Base32<L> {
diff --git a/common/taler-test-utils/src/routine.rs b/common/taler-test-utils/src/routine.rs
@@ -274,7 +274,7 @@ impl TestResponse {
}
let body = self.assert_ok_json::<T>();
let page = body.ids();
- let params = self.query::<PageParams>().check(1024).unwrap();
+ let params = self.query::<PageParams>().check().unwrap();
// testing the size is like expected
assert_eq!(size, page.len(), "bad page length: {page:?}");
diff --git a/common/taler-test-utils/src/server.rs b/common/taler-test-utils/src/server.rs
@@ -34,23 +34,23 @@ use tracing::warn;
use url::Url;
pub trait TestServer {
- fn method(&self, method: Method, path: &str) -> TestRequest;
+ fn request(&self, method: Method, path: &str) -> TestRequest;
fn get(&self, path: &str) -> TestRequest {
- self.method(Method::GET, path)
+ self.request(Method::GET, path)
}
fn post(&self, path: &str) -> TestRequest {
- self.method(Method::POST, path)
+ self.request(Method::POST, path)
}
fn delete(&self, path: &str) -> TestRequest {
- self.method(Method::DELETE, path)
+ self.request(Method::DELETE, path)
}
}
impl TestServer for Router {
- fn method(&self, method: Method, path: &str) -> TestRequest {
+ fn request(&self, method: Method, path: &str) -> TestRequest {
let url = format!("https://example{path}");
TestRequest {
router: self.clone(),
@@ -66,7 +66,7 @@ pub struct TestRequest {
router: Router,
method: Method,
pub url: Url,
- body: Option<Vec<u8>>,
+ body: Option<Bytes>,
headers: HeaderMap,
}
@@ -83,7 +83,17 @@ impl TestRequest {
pub fn json<T: Serialize>(mut self, body: T) -> Self {
assert!(self.body.is_none());
let bytes = serde_json::to_vec(&body).unwrap();
- self.body = Some(bytes);
+ self.body = Some(bytes.into());
+ self.headers.insert(
+ header::CONTENT_TYPE,
+ HeaderValue::from_static("application/json"),
+ );
+ self
+ }
+
+ pub fn raw_json(mut self, raw: Bytes) -> Self {
+ assert!(self.body.is_none());
+ self.body = Some(raw);
self.headers.insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
@@ -96,7 +106,7 @@ impl TestRequest {
let mut encoder = ZlibEncoder::new(Vec::with_capacity(body.len() / 4), Compression::fast());
encoder.write_all(&body).unwrap();
let compressed = encoder.finish().unwrap();
- self.body = Some(compressed);
+ self.body = Some(compressed.into());
self.headers.insert(
header::CONTENT_ENCODING,
HeaderValue::from_static("deflate"),
@@ -132,22 +142,24 @@ impl TestRequest {
let Self {
router,
method,
- url: uri,
- body,
+ url,
+ body: req_body,
headers,
} = self;
- let uri = Uri::try_from(uri.as_str()).unwrap();
+ let uri = Uri::try_from(url.as_str()).unwrap();
let mut req = axum::http::request::Builder::new()
.method(&method)
.uri(&uri)
- .body(body.map(Body::from).unwrap_or_else(Body::empty))
+ .body(req_body.clone().map(Body::from).unwrap_or_else(Body::empty))
.unwrap();
*req.headers_mut() = headers;
- let resp = router.oneshot(req).await.unwrap();
+ let resp = router.clone().oneshot(req).await.unwrap();
let (parts, body) = resp.into_parts();
let bytes = body.collect().await.unwrap();
TestResponse {
- bytes: bytes.to_bytes(),
+ router,
+ req_body: req_body.unwrap_or_default(),
+ res_body: bytes.to_bytes(),
method,
status: parts.status,
uri,
@@ -157,7 +169,7 @@ impl TestRequest {
impl IntoFuture for TestRequest {
type Output = TestResponse;
- type IntoFuture = Pin<Box<dyn Future<Output = Self::Output>>>;
+ type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(self.send())
@@ -166,9 +178,11 @@ impl IntoFuture for TestRequest {
#[must_use]
pub struct TestResponse {
- bytes: Bytes,
- method: Method,
- uri: Uri,
+ pub router: Router,
+ pub req_body: Bytes,
+ res_body: Bytes,
+ pub method: Method,
+ pub uri: Uri,
pub status: StatusCode,
}
@@ -192,9 +206,10 @@ impl TestResponse {
pub fn json_parse<'de, T: Deserialize<'de>>(&'de self) -> T {
let Self {
status,
- bytes,
+ res_body: bytes,
method,
uri,
+ ..
} = self;
match serde_json::from_slice(bytes) {
Ok(body) => body,
@@ -209,9 +224,10 @@ impl TestResponse {
pub fn assert_status(&self, expected: StatusCode) {
let Self {
status,
- bytes,
+ res_body: bytes,
method,
uri,
+ ..
} = self;
if expected != *status {
if status.is_success() || bytes.is_empty() {
@@ -234,11 +250,22 @@ impl TestResponse {
}
#[track_caller]
+ pub fn assert_accepted_json<'de, T: Deserialize<'de>>(&'de self) -> T {
+ self.assert_accepted();
+ self.json_parse()
+ }
+
+ #[track_caller]
pub fn assert_ok(&self) {
self.assert_status(StatusCode::OK);
}
#[track_caller]
+ pub fn assert_accepted(&self) {
+ self.assert_status(StatusCode::ACCEPTED);
+ }
+
+ #[track_caller]
pub fn assert_no_content(&self) {
self.assert_status(StatusCode::NO_CONTENT);
}
diff --git a/taler-cyclos/src/worker.rs b/taler-cyclos/src/worker.rs
@@ -352,8 +352,7 @@ impl Worker<'_> {
return Ok(());
}
match parse_incoming_unstructured(&tx.subject) {
- Ok(None) => bounce(self.db, "missing public key").await?,
- Ok(Some(subject)) => {
+ Ok(subject) => {
match db::register_tx_in(self.db, &tx, &Some(subject), &Timestamp::now())
.await?
{
diff --git a/taler-magnet-bank/src/worker.rs b/taler-magnet-bank/src/worker.rs
@@ -247,8 +247,7 @@ impl Worker<'_> {
match self.account_type {
AccountType::Exchange => {
match parse_incoming_unstructured(&tx_in.subject) {
- Ok(None) => bounce(self.db, "missing public key").await?,
- Ok(Some(subject)) => match db::register_tx_in(
+ Ok(subject) => match db::register_tx_in(
self.db,
&tx_in,
&Some(subject),