taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

extract.rs (6996B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use axum::{
     18     body::{Body, Bytes, HttpBody},
     19     extract::{FromRequest, FromRequestParts, Request},
     20     http::{
     21         HeaderMap, StatusCode,
     22         header::{self, CONTENT_TYPE},
     23         request::Parts,
     24     },
     25 };
     26 use http_body_util::BodyExt as _;
     27 use serde::de::DeserializeOwned;
     28 use taler_common::error_code::ErrorCode;
     29 use tracing::trace;
     30 use url::form_urlencoded;
     31 use zlib_rs::{InflateConfig, ReturnCode};
     32 
     33 use crate::{
     34     constants::MAX_BODY_LENGTH,
     35     error::{ApiError, ApiResult, failure, failure_status},
     36 };
     37 
     38 pub async fn decompressed_strict_body(headers: &HeaderMap, body: Body) -> ApiResult<Bytes> {
     39     // Check content type
     40     match headers.get(header::CONTENT_TYPE) {
     41         Some(header) => {
     42             if !header.as_bytes().starts_with(b"application/json") {
     43                 return Err(failure_status(
     44                     ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
     45                     "Bad Content-Type header",
     46                     StatusCode::UNSUPPORTED_MEDIA_TYPE,
     47                 ));
     48             }
     49         }
     50         None => {
     51             return Err(failure_status(
     52                 ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
     53                 "Missing Content-Type header",
     54                 StatusCode::UNSUPPORTED_MEDIA_TYPE,
     55             ));
     56         }
     57     }
     58 
     59     // Check content length if present and well formed
     60     if let Some(length) = headers
     61         .get(header::CONTENT_LENGTH)
     62         .and_then(|it| it.to_str().ok())
     63         .and_then(|it| it.parse::<usize>().ok())
     64         && length > MAX_BODY_LENGTH
     65     {
     66         return Err(failure(
     67             ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
     68             format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"),
     69         ));
     70     }
     71 
     72     // Check compression
     73     let compressed = if let Some(encoding) = headers.get(header::CONTENT_ENCODING) {
     74         if encoding == "deflate" {
     75             true
     76         } else {
     77             return Err(failure_status(
     78                 ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED,
     79                 format!(
     80                     "Unsupported encoding '{}'",
     81                     String::from_utf8_lossy(encoding.as_bytes())
     82                 ),
     83                 StatusCode::UNSUPPORTED_MEDIA_TYPE,
     84             ));
     85         }
     86     } else {
     87         false
     88     };
     89 
     90     // Buffer body
     91     let body = http_body_util::Limited::new(body, MAX_BODY_LENGTH);
     92     let bytes = match body.collect().await {
     93         Ok(chunks) => chunks.to_bytes(),
     94         Err(it) => match it.downcast::<http_body_util::LengthLimitError>() {
     95             Ok(_) => {
     96                 return Err(failure(
     97                     ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
     98                     format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"),
     99                 ));
    100             }
    101             Err(err) => {
    102                 return Err(failure(
    103                     ErrorCode::GENERIC_UNEXPECTED_REQUEST_ERROR,
    104                     format!("Failed to read body: {err}"),
    105                 ));
    106             }
    107         },
    108     };
    109 
    110     let bytes = if compressed {
    111         let mut buf = [0; MAX_BODY_LENGTH];
    112         let (decompressed, code) =
    113             zlib_rs::decompress_slice(&mut buf, &bytes, InflateConfig::default());
    114         match code {
    115             ReturnCode::Ok => Bytes::copy_from_slice(decompressed),
    116             ReturnCode::BufError => {
    117                 return Err(failure(
    118                     ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT,
    119                     format!("Decompressed body is suspiciously big > {MAX_BODY_LENGTH}B"),
    120                 ));
    121             }
    122             _ => {
    123                 return Err(failure(
    124                     ErrorCode::GENERIC_COMPRESSION_INVALID,
    125                     "Failed to decompress body: invalid compression",
    126                 ));
    127             }
    128         }
    129     } else {
    130         bytes
    131     };
    132     trace!(target: "api", "req {}", String::from_utf8_lossy(&bytes));
    133     Ok(bytes)
    134 }
    135 
    136 #[derive(Debug, Clone, Copy, Default)]
    137 #[must_use]
    138 pub struct Req<T>(pub T);
    139 
    140 impl<T, S> FromRequest<S> for Req<T>
    141 where
    142     T: DeserializeOwned,
    143     S: Send + Sync,
    144 {
    145     type Rejection = ApiError;
    146 
    147     async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
    148         let (parts, body) = req.into_parts();
    149         let bytes = decompressed_strict_body(&parts.headers, body).await?;
    150         Self::try_from(&bytes)
    151     }
    152 }
    153 
    154 impl<T: DeserializeOwned> TryFrom<&Bytes> for Req<T> {
    155     type Error = ApiError;
    156 
    157     fn try_from(value: &Bytes) -> Result<Self, Self::Error> {
    158         let mut de = serde_json::de::Deserializer::from_slice(value);
    159         let parsed = serde_path_to_error::deserialize(&mut de)?;
    160         Ok(Req(parsed))
    161     }
    162 }
    163 
    164 #[derive(Debug, Clone, Copy, Default)]
    165 #[must_use]
    166 pub struct OptReq<T>(pub Option<T>);
    167 
    168 impl<T, S> FromRequest<S> for OptReq<T>
    169 where
    170     T: DeserializeOwned,
    171     S: Send + Sync,
    172 {
    173     type Rejection = ApiError;
    174 
    175     async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> {
    176         let (parts, body) = req.into_parts();
    177         if !parts.headers.contains_key(CONTENT_TYPE) && body.size_hint().exact() == Some(0) {
    178             Ok(Self(None))
    179         } else {
    180             let bytes = decompressed_strict_body(&parts.headers, body).await?;
    181             let req = Req::try_from(&bytes)?;
    182             Ok(Self(Some(req.0)))
    183         }
    184     }
    185 }
    186 
    187 #[derive(Debug, Clone, Copy, Default)]
    188 pub struct Path<T: DeserializeOwned + Send>(pub T);
    189 
    190 impl<T: serde::de::DeserializeOwned + Send, S: Sync + Send> FromRequestParts<S> for Path<T> {
    191     type Rejection = ApiError;
    192 
    193     async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
    194         Ok(Self(
    195             axum::extract::Path::from_request_parts(parts, &()).await?.0,
    196         ))
    197     }
    198 }
    199 
    200 #[derive(Debug, Clone, Copy, Default)]
    201 pub struct Query<T>(pub T);
    202 
    203 impl<T: serde::de::DeserializeOwned, S: Sync + Send> FromRequestParts<S> for Query<T> {
    204     type Rejection = ApiError;
    205 
    206     async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
    207         let query = parts.uri.query().unwrap_or_default();
    208         let deserializer =
    209             serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes()));
    210         let params = serde_path_to_error::deserialize(deserializer)?;
    211         Ok(Query(params))
    212     }
    213 }