taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

commit c052c27629a09fbe2ba15744d43fdbf60e42157b
parent fef23a786747b8c19de617a3b4538913edb9064a
Author: Antoine A <>
Date:   Wed,  6 May 2026 10:08:27 +0200

common: many improvements

Diffstat:
MCargo.toml | 1+
Mcommon/taler-api/src/api/revenue.rs | 7++-----
Mcommon/taler-api/src/api/wire.rs | 8++++----
Mcommon/taler-api/src/constants.rs | 2--
Mcommon/taler-api/src/db.rs | 32++++++++++++++++++++++++++++++--
Mcommon/taler-api/src/error.rs | 8++++++++
Mcommon/taler-api/src/extract.rs | 209++++++++++++++++++++++++++++++++++++++++++-------------------------------------
Mcommon/taler-api/src/subject.rs | 29+++++++++++++++++------------
Mcommon/taler-common/src/api_params.rs | 24++++++++++++++++++++----
Mcommon/taler-common/src/types/base32.rs | 7+++++++
Mcommon/taler-test-utils/src/routine.rs | 2+-
Mcommon/taler-test-utils/src/server.rs | 67+++++++++++++++++++++++++++++++++++++++++++++++--------------------
Mtaler-cyclos/src/worker.rs | 3+--
Mtaler-magnet-bank/src/worker.rs | 3+--
14 files changed, 249 insertions(+), 153 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml @@ -57,3 +57,4 @@ rand = { version = "0.10" } regex = { version = "1" } rustls = "0.23" http = "1.4" + diff --git a/common/taler-api/src/api/revenue.rs b/common/taler-api/src/api/revenue.rs @@ -24,10 +24,7 @@ use taler_common::{ use super::TalerApi; use crate::{ - api::RouterUtils as _, - auth::AuthMethod, - constants::{MAX_PAGE_SIZE, MAX_TIMEOUT_MS, REVENUE_API_VERSION}, - error::ApiResult, + api::RouterUtils as _, auth::AuthMethod, constants::REVENUE_API_VERSION, error::ApiResult, extract::Query, }; @@ -44,7 +41,7 @@ pub fn router<I: Revenue>(state: Arc<I>, auth: AuthMethod) -> Router { "/history", get( async |State(state): State<Arc<I>>, Query(params): Query<HistoryParams>| { - let params = params.check(MAX_PAGE_SIZE, MAX_TIMEOUT_MS)?; + let params = params.check()?; let history = state.history(params).await?; ApiResult::Ok(if history.incoming_transactions.is_empty() { StatusCode::NO_CONTENT.into_response() diff --git a/common/taler-api/src/api/wire.rs b/common/taler-api/src/api/wire.rs @@ -38,7 +38,7 @@ use super::TalerApi; use crate::{ api::RouterUtils as _, auth::AuthMethod, - constants::{MAX_PAGE_SIZE, MAX_TIMEOUT_MS, WIRE_GATEWAY_API_VERSION}, + constants::WIRE_GATEWAY_API_VERSION, error::{ApiResult, failure, failure_code, failure_status}, extract::{Path, Query, Req}, }; @@ -123,7 +123,7 @@ pub fn router<I: WireGateway>(state: Arc<I>, auth: AuthMethod) -> Router { "/transfers", get( async |State(state): State<Arc<I>>, Query(params): Query<TransferParams>| { - let page = params.pagination.check(MAX_PAGE_SIZE)?; + let page = params.pagination.check()?; let list = state.transfer_page(page, params.status).await?; ApiResult::Ok(if list.transfers.is_empty() { StatusCode::NO_CONTENT.into_response() @@ -149,7 +149,7 @@ pub fn router<I: WireGateway>(state: Arc<I>, auth: AuthMethod) -> Router { "/history/incoming", get( async |State(state): State<Arc<I>>, Query(params): Query<HistoryParams>| { - let params = params.check(MAX_PAGE_SIZE, MAX_TIMEOUT_MS)?; + let params = params.check()?; let history = state.incoming_history(params).await?; ApiResult::Ok(if history.incoming_transactions.is_empty() { StatusCode::NO_CONTENT.into_response() @@ -163,7 +163,7 @@ pub fn router<I: WireGateway>(state: Arc<I>, auth: AuthMethod) -> Router { "/history/outgoing", get( async |State(state): State<Arc<I>>, Query(params): Query<HistoryParams>| { - let params = params.check(MAX_PAGE_SIZE, MAX_TIMEOUT_MS)?; + let params = params.check()?; let history = state.outgoing_history(params).await?; ApiResult::Ok(if history.outgoing_transactions.is_empty() { StatusCode::NO_CONTENT.into_response() diff --git a/common/taler-api/src/constants.rs b/common/taler-api/src/constants.rs @@ -17,6 +17,4 @@ pub const WIRE_GATEWAY_API_VERSION: &str = "5:0:0"; pub const PREPARED_TRANSFER_API_VERSION: &str = "1:0:0"; pub const REVENUE_API_VERSION: &str = "1:0:0"; -pub const MAX_PAGE_SIZE: i64 = 1024; -pub const MAX_TIMEOUT_MS: u64 = 60 * 60 * 10; // 1H pub const MAX_BODY_LENGTH: usize = 4 * 1024; // 4kB diff --git a/common/taler-api/src/db.rs b/common/taler-api/src/db.rs @@ -22,8 +22,10 @@ use jiff::{ tz::TimeZone, }; use sqlx::{ - Decode, Error, PgPool, Postgres, QueryBuilder, Row, Type, error::BoxDynError, postgres::PgRow, - query::Query, + Decode, Error, PgPool, Postgres, QueryBuilder, Row, Type, + error::BoxDynError, + postgres::PgRow, + query::{Query, QueryScalar}, }; use taler_common::{ api_common::SafeU64, @@ -32,6 +34,7 @@ use taler_common::{ amount::{Amount, Currency, Decimal}, iban::IBAN, payto::PaytoURI, + timestamp::TalerTimestamp, utils::date_to_utc_ts, }, }; @@ -192,6 +195,18 @@ impl<'q> BindHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Argume } } +impl<'q, T> BindHelper + for QueryScalar<'q, Postgres, T, <Postgres as sqlx::Database>::Arguments<'q>> +{ + fn bind_timestamp(self, timestamp: &Timestamp) -> Self { + self.bind(timestamp.as_microsecond()) + } + + fn bind_date(self, date: &Date) -> Self { + self.bind_timestamp(&date_to_utc_ts(date)) + } +} + /* ----- Get ----- */ pub trait TypeHelper { @@ -254,6 +269,19 @@ pub trait TypeHelper { } }) } + fn try_get_taler_timestamp<I: sqlx::ColumnIndex<Self>>( + &self, + index: I, + ) -> sqlx::Result<TalerTimestamp> { + self.try_get_map(index, |micros: Option<i64>| match micros { + Some(micros) => Ok::<_, String>(TalerTimestamp::Timestamp( + jiff::Timestamp::from_microsecond(micros).map_err(|e| { + format!("expected timestamp micros got overflowing {micros}: {e}") + })?, + )), + None => Ok(TalerTimestamp::Never), + }) + } fn try_get_date<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Date> { let timestamp = self.try_get_timestamp(index)?; let zoned = timestamp.to_zoned(TimeZone::UTC); diff --git a/common/taler-api/src/error.rs b/common/taler-api/src/error.rs @@ -221,3 +221,11 @@ pub fn not_implemented(hint: impl Display) -> ApiError { pub fn unauthorized(hint: impl Display) -> ApiError { ApiError::new(ErrorCode::GENERIC_UNAUTHORIZED).with_hint(hint) } + +pub fn forbidden(hint: impl Display) -> ApiError { + ApiError::new(ErrorCode::GENERIC_FORBIDDEN).with_hint(hint) +} + +pub fn bad_request(hint: impl Display) -> ApiError { + ApiError::new(ErrorCode::GENERIC_JSON_INVALID).with_hint(hint) +} diff --git a/common/taler-api/src/extract.rs b/common/taler-api/src/extract.rs @@ -15,9 +15,9 @@ */ use axum::{ - body::Bytes, + body::{Body, Bytes}, extract::{FromRequest, FromRequestParts, Request}, - http::{StatusCode, header, request::Parts}, + http::{HeaderMap, StatusCode, header, request::Parts}, }; use http_body_util::BodyExt as _; use serde::de::DeserializeOwned; @@ -27,9 +27,106 @@ use zlib_rs::{InflateConfig, ReturnCode}; use crate::{ constants::MAX_BODY_LENGTH, - error::{ApiError, failure, failure_status}, + error::{ApiError, ApiResult, failure, failure_status}, }; +pub async fn decompressed_strict_body(headers: &HeaderMap, body: Body) -> ApiResult<Bytes> { + // Check content type + match headers.get(header::CONTENT_TYPE) { + Some(header) => { + if !header.as_bytes().starts_with(b"application/json") { + return Err(failure_status( + ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, + "Bad Content-Type header", + StatusCode::UNSUPPORTED_MEDIA_TYPE, + )); + } + } + None => { + return Err(failure_status( + ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, + "Missing Content-Type header", + StatusCode::UNSUPPORTED_MEDIA_TYPE, + )); + } + } + + // Check content length if present and well formed + if let Some(length) = headers + .get(header::CONTENT_LENGTH) + .and_then(|it| it.to_str().ok()) + .and_then(|it| it.parse::<usize>().ok()) + && length > MAX_BODY_LENGTH + { + return Err(failure( + ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, + format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"), + )); + } + + // Check compression + let compressed = if let Some(encoding) = headers.get(header::CONTENT_ENCODING) { + if encoding == "deflate" { + true + } else { + return Err(failure_status( + ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, + format!( + "Unsupported encoding '{}'", + String::from_utf8_lossy(encoding.as_bytes()) + ), + StatusCode::UNSUPPORTED_MEDIA_TYPE, + )); + } + } else { + false + }; + + // Buffer body + let body = http_body_util::Limited::new(body, MAX_BODY_LENGTH); + let bytes = match body.collect().await { + Ok(chunks) => chunks.to_bytes(), + Err(it) => match it.downcast::<http_body_util::LengthLimitError>() { + Ok(_) => { + return Err(failure( + ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, + format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"), + )); + } + Err(err) => { + return Err(failure( + ErrorCode::GENERIC_UNEXPECTED_REQUEST_ERROR, + format!("Failed to read body: {err}"), + )); + } + }, + }; + + let bytes = if compressed { + let mut buf = [0; MAX_BODY_LENGTH]; + let (decompressed, code) = + zlib_rs::decompress_slice(&mut buf, &bytes, InflateConfig::default()); + match code { + ReturnCode::Ok => Bytes::copy_from_slice(decompressed), + ReturnCode::BufError => { + return Err(failure( + ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, + format!("Decompressed body is suspiciously big > {MAX_BODY_LENGTH}B"), + )); + } + _ => { + return Err(failure( + ErrorCode::GENERIC_COMPRESSION_INVALID, + "Failed to decompress body: invalid compression", + )); + } + } + } else { + bytes + }; + Ok(bytes) +} + #[derive(Debug, Clone, Copy, Default)] #[must_use] pub struct Req<T>(pub T); @@ -42,110 +139,24 @@ where type Rejection = ApiError; async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> { - // Check content type - match req.headers().get(header::CONTENT_TYPE) { - Some(header) => { - if !header.as_bytes().starts_with(b"application/json") { - return Err(failure_status( - ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, - "Bad Content-Type header", - StatusCode::UNSUPPORTED_MEDIA_TYPE, - )); - } - } - None => { - return Err(failure_status( - ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, - "Missing Content-Type header", - StatusCode::UNSUPPORTED_MEDIA_TYPE, - )); - } - } + let (parts, body) = req.into_parts(); + let bytes = decompressed_strict_body(&parts.headers, body).await?; + Self::try_from(&bytes) + } +} - // Check content length if present and wellformed - if let Some(length) = req - .headers() - .get(header::CONTENT_LENGTH) - .and_then(|it| it.to_str().ok()) - .and_then(|it| it.parse::<usize>().ok()) - && length > MAX_BODY_LENGTH - { - return Err(failure( - ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, - format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"), - )); - } +impl<T: DeserializeOwned> TryFrom<&Bytes> for Req<T> { + type Error = ApiError; - // Check compression - let compressed = if let Some(encoding) = req.headers().get(header::CONTENT_ENCODING) { - if encoding == "deflate" { - true - } else { - return Err(failure_status( - ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, - format!( - "Unsupported encoding '{}'", - String::from_utf8_lossy(encoding.as_bytes()) - ), - StatusCode::UNSUPPORTED_MEDIA_TYPE, - )); - } - } else { - false - }; - - // Buffer body - let (_, body) = req.into_parts(); - let body = http_body_util::Limited::new(body, MAX_BODY_LENGTH); - let bytes = match body.collect().await { - Ok(chunks) => chunks.to_bytes(), - Err(it) => match it.downcast::<http_body_util::LengthLimitError>() { - Ok(_) => { - return Err(failure( - ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, - format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"), - )); - } - Err(err) => { - return Err(failure( - ErrorCode::GENERIC_UNEXPECTED_REQUEST_ERROR, - format!("Failed to read body: {err}"), - )); - } - }, - }; - - let bytes = if compressed { - let mut buf = [0; MAX_BODY_LENGTH]; - let (decompressed, code) = - zlib_rs::decompress_slice(&mut buf, &bytes, InflateConfig::default()); - dbg!(code); - match code { - ReturnCode::Ok => Bytes::copy_from_slice(decompressed), - ReturnCode::BufError => { - return Err(failure( - ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, - format!("Decompressed body is suspiciously big > {MAX_BODY_LENGTH}B"), - )); - } - _ => { - return Err(failure( - ErrorCode::GENERIC_COMPRESSION_INVALID, - "Failed to decompress body: invalid compression", - )); - } - } - } else { - bytes - }; - let mut de = serde_json::de::Deserializer::from_slice(&bytes); + fn try_from(value: &Bytes) -> Result<Self, Self::Error> { + let mut de = serde_json::de::Deserializer::from_slice(value); let parsed = serde_path_to_error::deserialize(&mut de)?; Ok(Req(parsed)) } } #[derive(Debug, Clone, Copy, Default)] -pub struct Path<T>(pub T); +pub struct Path<T: DeserializeOwned + Send>(pub T); impl<T: serde::de::DeserializeOwned + Send, S: Sync + Send> FromRequestParts<S> for Path<T> { type Rejection = ApiError; diff --git a/common/taler-api/src/subject.rs b/common/taler-api/src/subject.rs @@ -121,6 +121,8 @@ pub enum IncomingSubjectResult { pub enum IncomingSubjectErr { #[error("found multiple public keys")] Ambiguous, + #[error("missing reserve public key")] + Missing, } #[derive(Debug, thiserror::Error)] @@ -182,9 +184,7 @@ pub fn fmt_in_subject(ty: IncomingType, key: &EddsaPublicKey) -> String { * To parse them while ignoring user errors, we reconstruct valid keys from key * parts, resolving ambiguities where possible. **/ -pub fn parse_incoming_unstructured( - subject: &str, -) -> Result<Option<IncomingSubject>, IncomingSubjectErr> { +pub fn parse_incoming_unstructured(subject: &str) -> Result<IncomingSubject, IncomingSubjectErr> { // We expect subject to be less than 65KB assert!(subject.len() <= u16::MAX as usize); @@ -287,7 +287,11 @@ pub fn parse_incoming_unstructured( } } - Ok(best.map(|it| it.subject)) + if let Some(it) = best { + Ok(it.subject) + } else { + Err(IncomingSubjectErr::Missing) + } } // Modulo 10 Recursive @@ -395,11 +399,11 @@ mod test { let other_mixed = &format!("{prefix}TEGY6d9mh9pgwvwpgs0z0095z854xegfy7jj202yd0esp8p0za60"); let key = EddsaPublicKey::from_str(key).unwrap(); - let result = Ok(Some(match ty { + let result = Ok(match ty { IncomingType::reserve => IncomingSubject::Reserve(key), IncomingType::kyc => IncomingSubject::Kyc(key), IncomingType::map => IncomingSubject::Map(key), - })); + }); // Check succeed if standard or mixed for case in [standard, mixed] { @@ -480,7 +484,10 @@ mod test { &standard[1..], // Check fail if missing char //"2MZT6RS3RVB3B0E2RDMYW0YRA3Y0VPHYV0CYDE6XBB0YMPFXCEG0", // Check fail if not a valid key ] { - assert_eq!(parse_incoming_unstructured(case), Ok(None)); + assert_eq!( + parse_incoming_unstructured(case), + Err(IncomingSubjectErr::Missing) + ); } if ty == IncomingType::kyc || ty == IncomingType::map { @@ -515,9 +522,9 @@ mod test { ), ] { assert_eq!( - Ok(Some(IncomingSubject::Reserve( + Ok(IncomingSubject::Reserve( EddsaPublicKey::from_str(key).unwrap(), - ))), + )), parse_incoming_unstructured(subject) ) } @@ -527,9 +534,7 @@ mod test { "JW398X85FWPKKMS0EYB6TQ1799RMY5DDXTZFPW4YC3WJ2DWSJT70", )] { assert_eq!( - Ok(Some(IncomingSubject::Kyc( - EddsaPublicKey::from_str(key).unwrap(), - ))), + Ok(IncomingSubject::Kyc(EddsaPublicKey::from_str(key).unwrap(),)), parse_incoming_unstructured(subject) ) } diff --git a/common/taler-common/src/api_params.rs b/common/taler-common/src/api_params.rs @@ -1,6 +1,6 @@ /* This file is part of TALER - Copyright (C) 2024-2025 Taler Systems SA + Copyright (C) 2024, 2025, 2026 Taler Systems SA TALER is free software; you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software @@ -43,7 +43,13 @@ pub struct PageParams { } impl PageParams { - pub fn check(self, max_page_size: i64) -> Result<Page, ParamsErr> { + const MAX_PAGE_SIZE: i64 = 1024; + + pub fn check(self) -> Result<Page, ParamsErr> { + Self::check_custom(self, Self::MAX_PAGE_SIZE) + } + + pub fn check_custom(self, max_page_size: i64) -> Result<Page, ParamsErr> { let limit = self.limit.unwrap_or(-20); if limit == 0 { return Err(param_err("limit", format!("must be non-zero got {limit}"))); @@ -100,10 +106,20 @@ pub struct HistoryParams { } impl HistoryParams { - pub fn check(self, max_page_size: i64, max_timeout_ms: u64) -> Result<History, ParamsErr> { + pub const MAX_TIMEOUT_MS: u64 = 60 * 60 * 10; // 1H + + pub fn check(self) -> Result<History, ParamsErr> { + Self::check_custom(self, PageParams::MAX_PAGE_SIZE, Self::MAX_TIMEOUT_MS) + } + + pub fn check_custom( + self, + max_page_size: i64, + max_timeout_ms: u64, + ) -> Result<History, ParamsErr> { let timeout_ms = self.timeout_ms.map(|it| it.min(max_timeout_ms)); Ok(History { - page: self.pagination.check(max_page_size)?, + page: self.pagination.check_custom(max_page_size)?, timeout_ms, }) } diff --git a/common/taler-common/src/types/base32.rs b/common/taler-common/src/types/base32.rs @@ -16,6 +16,7 @@ use std::{borrow::Cow, fmt::Display, ops::Deref, str::FromStr}; +use rand::{TryRng, rngs::SysRng}; use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error}; use crate::encoding::base32::{Base32Error, decode_static, encode_static, encoded_buf_len}; @@ -27,6 +28,12 @@ impl<const L: usize> Base32<L> { pub fn rand() -> Self { Self(rand::random()) } + + pub fn secure_rand() -> Self { + let mut array = [0; L]; + SysRng.try_fill_bytes(&mut array).unwrap(); + Self(array) + } } impl<const L: usize> From<[u8; L]> for Base32<L> { diff --git a/common/taler-test-utils/src/routine.rs b/common/taler-test-utils/src/routine.rs @@ -274,7 +274,7 @@ impl TestResponse { } let body = self.assert_ok_json::<T>(); let page = body.ids(); - let params = self.query::<PageParams>().check(1024).unwrap(); + let params = self.query::<PageParams>().check().unwrap(); // testing the size is like expected assert_eq!(size, page.len(), "bad page length: {page:?}"); diff --git a/common/taler-test-utils/src/server.rs b/common/taler-test-utils/src/server.rs @@ -34,23 +34,23 @@ use tracing::warn; use url::Url; pub trait TestServer { - fn method(&self, method: Method, path: &str) -> TestRequest; + fn request(&self, method: Method, path: &str) -> TestRequest; fn get(&self, path: &str) -> TestRequest { - self.method(Method::GET, path) + self.request(Method::GET, path) } fn post(&self, path: &str) -> TestRequest { - self.method(Method::POST, path) + self.request(Method::POST, path) } fn delete(&self, path: &str) -> TestRequest { - self.method(Method::DELETE, path) + self.request(Method::DELETE, path) } } impl TestServer for Router { - fn method(&self, method: Method, path: &str) -> TestRequest { + fn request(&self, method: Method, path: &str) -> TestRequest { let url = format!("https://example{path}"); TestRequest { router: self.clone(), @@ -66,7 +66,7 @@ pub struct TestRequest { router: Router, method: Method, pub url: Url, - body: Option<Vec<u8>>, + body: Option<Bytes>, headers: HeaderMap, } @@ -83,7 +83,17 @@ impl TestRequest { pub fn json<T: Serialize>(mut self, body: T) -> Self { assert!(self.body.is_none()); let bytes = serde_json::to_vec(&body).unwrap(); - self.body = Some(bytes); + self.body = Some(bytes.into()); + self.headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + self + } + + pub fn raw_json(mut self, raw: Bytes) -> Self { + assert!(self.body.is_none()); + self.body = Some(raw); self.headers.insert( header::CONTENT_TYPE, HeaderValue::from_static("application/json"), @@ -96,7 +106,7 @@ impl TestRequest { let mut encoder = ZlibEncoder::new(Vec::with_capacity(body.len() / 4), Compression::fast()); encoder.write_all(&body).unwrap(); let compressed = encoder.finish().unwrap(); - self.body = Some(compressed); + self.body = Some(compressed.into()); self.headers.insert( header::CONTENT_ENCODING, HeaderValue::from_static("deflate"), @@ -132,22 +142,24 @@ impl TestRequest { let Self { router, method, - url: uri, - body, + url, + body: req_body, headers, } = self; - let uri = Uri::try_from(uri.as_str()).unwrap(); + let uri = Uri::try_from(url.as_str()).unwrap(); let mut req = axum::http::request::Builder::new() .method(&method) .uri(&uri) - .body(body.map(Body::from).unwrap_or_else(Body::empty)) + .body(req_body.clone().map(Body::from).unwrap_or_else(Body::empty)) .unwrap(); *req.headers_mut() = headers; - let resp = router.oneshot(req).await.unwrap(); + let resp = router.clone().oneshot(req).await.unwrap(); let (parts, body) = resp.into_parts(); let bytes = body.collect().await.unwrap(); TestResponse { - bytes: bytes.to_bytes(), + router, + req_body: req_body.unwrap_or_default(), + res_body: bytes.to_bytes(), method, status: parts.status, uri, @@ -157,7 +169,7 @@ impl TestRequest { impl IntoFuture for TestRequest { type Output = TestResponse; - type IntoFuture = Pin<Box<dyn Future<Output = Self::Output>>>; + type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>; fn into_future(self) -> Self::IntoFuture { Box::pin(self.send()) @@ -166,9 +178,11 @@ impl IntoFuture for TestRequest { #[must_use] pub struct TestResponse { - bytes: Bytes, - method: Method, - uri: Uri, + pub router: Router, + pub req_body: Bytes, + res_body: Bytes, + pub method: Method, + pub uri: Uri, pub status: StatusCode, } @@ -192,9 +206,10 @@ impl TestResponse { pub fn json_parse<'de, T: Deserialize<'de>>(&'de self) -> T { let Self { status, - bytes, + res_body: bytes, method, uri, + .. } = self; match serde_json::from_slice(bytes) { Ok(body) => body, @@ -209,9 +224,10 @@ impl TestResponse { pub fn assert_status(&self, expected: StatusCode) { let Self { status, - bytes, + res_body: bytes, method, uri, + .. } = self; if expected != *status { if status.is_success() || bytes.is_empty() { @@ -234,11 +250,22 @@ impl TestResponse { } #[track_caller] + pub fn assert_accepted_json<'de, T: Deserialize<'de>>(&'de self) -> T { + self.assert_accepted(); + self.json_parse() + } + + #[track_caller] pub fn assert_ok(&self) { self.assert_status(StatusCode::OK); } #[track_caller] + pub fn assert_accepted(&self) { + self.assert_status(StatusCode::ACCEPTED); + } + + #[track_caller] pub fn assert_no_content(&self) { self.assert_status(StatusCode::NO_CONTENT); } diff --git a/taler-cyclos/src/worker.rs b/taler-cyclos/src/worker.rs @@ -352,8 +352,7 @@ impl Worker<'_> { return Ok(()); } match parse_incoming_unstructured(&tx.subject) { - Ok(None) => bounce(self.db, "missing public key").await?, - Ok(Some(subject)) => { + Ok(subject) => { match db::register_tx_in(self.db, &tx, &Some(subject), &Timestamp::now()) .await? { diff --git a/taler-magnet-bank/src/worker.rs b/taler-magnet-bank/src/worker.rs @@ -247,8 +247,7 @@ impl Worker<'_> { match self.account_type { AccountType::Exchange => { match parse_incoming_unstructured(&tx_in.subject) { - Ok(None) => bounce(self.db, "missing public key").await?, - Ok(Some(subject)) => match db::register_tx_in( + Ok(subject) => match db::register_tx_in( self.db, &tx_in, &Some(subject),