taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

extract.rs (6180B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use axum::{
     18     body::{Body, Bytes},
     19     extract::{FromRequest, FromRequestParts, Request},
     20     http::{HeaderMap, StatusCode, header, request::Parts},
     21 };
     22 use http_body_util::BodyExt as _;
     23 use serde::de::DeserializeOwned;
     24 use taler_common::error_code::ErrorCode;
     25 use url::form_urlencoded;
     26 use zlib_rs::{InflateConfig, ReturnCode};
     27 
     28 use crate::{
     29     constants::MAX_BODY_LENGTH,
     30     error::{ApiError, ApiResult, failure, failure_status},
     31 };
     32 
     33 pub async fn decompressed_strict_body(headers: &HeaderMap, body: Body) -> ApiResult<Bytes> {
     34     // Check content type
     35     match headers.get(header::CONTENT_TYPE) {
     36         Some(header) => {
     37             if !header.as_bytes().starts_with(b"application/json") {
     38                 return Err(failure_status(
     39                     ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
     40                     "Bad Content-Type header",
     41                     StatusCode::UNSUPPORTED_MEDIA_TYPE,
     42                 ));
     43             }
     44         }
     45         None => {
     46             return Err(failure_status(
     47                 ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
     48                 "Missing Content-Type header",
     49                 StatusCode::UNSUPPORTED_MEDIA_TYPE,
     50             ));
     51         }
     52     }
     53 
     54     // Check content length if present and well formed
     55     if let Some(length) = headers
     56         .get(header::CONTENT_LENGTH)
     57         .and_then(|it| it.to_str().ok())
     58         .and_then(|it| it.parse::<usize>().ok())
     59         && length > MAX_BODY_LENGTH
     60     {
     61         return Err(failure(
     62             ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
     63             format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"),
     64         ));
     65     }
     66 
     67     // Check compression
     68     let compressed = if let Some(encoding) = headers.get(header::CONTENT_ENCODING) {
     69         if encoding == "deflate" {
     70             true
     71         } else {
     72             return Err(failure_status(
     73                 ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
     74                 format!(
     75                     "Unsupported encoding '{}'",
     76                     String::from_utf8_lossy(encoding.as_bytes())
     77                 ),
     78                 StatusCode::UNSUPPORTED_MEDIA_TYPE,
     79             ));
     80         }
     81     } else {
     82         false
     83     };
     84 
     85     // Buffer body
     86     let body = http_body_util::Limited::new(body, MAX_BODY_LENGTH);
     87     let bytes = match body.collect().await {
     88         Ok(chunks) => chunks.to_bytes(),
     89         Err(it) => match it.downcast::<http_body_util::LengthLimitError>() {
     90             Ok(_) => {
     91                 return Err(failure(
     92                     ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
     93                     format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"),
     94                 ));
     95             }
     96             Err(err) => {
     97                 return Err(failure(
     98                     ErrorCode::GENERIC_UNEXPECTED_REQUEST_ERROR,
     99                     format!("Failed to read body: {err}"),
    100                 ));
    101             }
    102         },
    103     };
    104 
    105     let bytes = if compressed {
    106         let mut buf = [0; MAX_BODY_LENGTH];
    107         let (decompressed, code) =
    108             zlib_rs::decompress_slice(&mut buf, &bytes, InflateConfig::default());
    109         match code {
    110             ReturnCode::Ok => Bytes::copy_from_slice(decompressed),
    111             ReturnCode::BufError => {
    112                 return Err(failure(
    113                     ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
    114                     format!("Decompressed body is suspiciously big > {MAX_BODY_LENGTH}B"),
    115                 ));
    116             }
    117             _ => {
    118                 return Err(failure(
    119                     ErrorCode::GENERIC_COMPRESSION_INVALID,
    120                     "Failed to decompress body: invalid compression",
    121                 ));
    122             }
    123         }
    124     } else {
    125         bytes
    126     };
    127     Ok(bytes)
    128 }
    129 
    130 #[derive(Debug, Clone, Copy, Default)]
    131 #[must_use]
    132 pub struct Req<T>(pub T);
    133 
    134 impl<T, S> FromRequest<S> for Req<T>
    135 where
    136     T: DeserializeOwned,
    137     S: Send + Sync,
    138 {
    139     type Rejection = ApiError;
    140 
    141     async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
    142         let (parts, body) = req.into_parts();
    143         let bytes = decompressed_strict_body(&parts.headers, body).await?;
    144         Self::try_from(&bytes)
    145     }
    146 }
    147 
    148 impl<T: DeserializeOwned> TryFrom<&Bytes> for Req<T> {
    149     type Error = ApiError;
    150 
    151     fn try_from(value: &Bytes) -> Result<Self, Self::Error> {
    152         let mut de = serde_json::de::Deserializer::from_slice(value);
    153         let parsed = serde_path_to_error::deserialize(&mut de)?;
    154         Ok(Req(parsed))
    155     }
    156 }
    157 
    158 #[derive(Debug, Clone, Copy, Default)]
    159 pub struct Path<T: DeserializeOwned + Send>(pub T);
    160 
    161 impl<T: serde::de::DeserializeOwned + Send, S: Sync + Send> FromRequestParts<S> for Path<T> {
    162     type Rejection = ApiError;
    163 
    164     async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
    165         Ok(Self(
    166             axum::extract::Path::from_request_parts(parts, &()).await?.0,
    167         ))
    168     }
    169 }
    170 
    171 #[derive(Debug, Clone, Copy, Default)]
    172 pub struct Query<T>(pub T);
    173 
    174 impl<T: serde::de::DeserializeOwned, S: Sync + Send> FromRequestParts<S> for Query<T> {
    175     type Rejection = ApiError;
    176 
    177     async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
    178         let query = parts.uri.query().unwrap_or_default();
    179         let deserializer =
    180             serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
    181         let params = serde_path_to_error::deserialize(deserializer)?;
    182         Ok(Query(params))
    183     }
    184 }