extract.rs (6180B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use axum::{ 18 body::{Body, Bytes}, 19 extract::{FromRequest, FromRequestParts, Request}, 20 http::{HeaderMap, StatusCode, header, request::Parts}, 21 }; 22 use http_body_util::BodyExt as _; 23 use serde::de::DeserializeOwned; 24 use taler_common::error_code::ErrorCode; 25 use url::form_urlencoded; 26 use zlib_rs::{InflateConfig, ReturnCode}; 27 28 use crate::{ 29 constants::MAX_BODY_LENGTH, 30 error::{ApiError, ApiResult, failure, failure_status}, 31 }; 32 33 pub async fn decompressed_strict_body(headers: &HeaderMap, body: Body) -> ApiResult<Bytes> { 34 // Check content type 35 match headers.get(header::CONTENT_TYPE) { 36 Some(header) => { 37 if !header.as_bytes().starts_with(b"application/json") { 38 return Err(failure_status( 39 ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, 40 "Bad Content-Type header", 41 StatusCode::UNSUPPORTED_MEDIA_TYPE, 42 )); 43 } 44 } 45 None => { 46 return Err(failure_status( 47 ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, 48 "Missing Content-Type header", 49 StatusCode::UNSUPPORTED_MEDIA_TYPE, 50 )); 51 } 52 } 53 54 // Check content length if present and well formed 55 if let Some(length) = headers 56 .get(header::CONTENT_LENGTH) 57 .and_then(|it| it.to_str().ok()) 58 .and_then(|it| it.parse::<usize>().ok()) 59 && length > MAX_BODY_LENGTH 60 { 61 return Err(failure( 62 ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, 63 format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"), 64 )); 65 } 66 67 // Check compression 68 let compressed = if let Some(encoding) = headers.get(header::CONTENT_ENCODING) { 69 if encoding == "deflate" { 70 true 71 } else { 72 return Err(failure_status( 73 ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, 74 format!( 75 "Unsupported encoding '{}'", 76 String::from_utf8_lossy(encoding.as_bytes()) 77 ), 78 StatusCode::UNSUPPORTED_MEDIA_TYPE, 79 )); 80 } 81 } else { 82 false 83 }; 84 85 // Buffer body 86 let body = http_body_util::Limited::new(body, MAX_BODY_LENGTH); 87 let bytes = match body.collect().await { 88 Ok(chunks) => chunks.to_bytes(), 89 Err(it) => match it.downcast::<http_body_util::LengthLimitError>() { 90 Ok(_) => { 91 return Err(failure( 92 ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, 93 format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"), 94 )); 95 } 96 Err(err) => { 97 return Err(failure( 98 ErrorCode::GENERIC_UNEXPECTED_REQUEST_ERROR, 99 format!("Failed to read body: {err}"), 100 )); 101 } 102 }, 103 }; 104 105 let bytes = if compressed { 106 let mut buf = [0; MAX_BODY_LENGTH]; 107 let (decompressed, code) = 108 zlib_rs::decompress_slice(&mut buf, &bytes, InflateConfig::default()); 109 match code { 110 ReturnCode::Ok => Bytes::copy_from_slice(decompressed), 111 ReturnCode::BufError => { 112 return Err(failure( 113 ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, 114 format!("Decompressed body is suspiciously big > {MAX_BODY_LENGTH}B"), 115 )); 116 } 117 _ => { 118 return Err(failure( 119 ErrorCode::GENERIC_COMPRESSION_INVALID, 120 "Failed to decompress body: invalid compression", 121 )); 122 } 123 } 124 } else { 125 bytes 126 }; 127 Ok(bytes) 128 } 129 130 #[derive(Debug, Clone, Copy, Default)] 131 #[must_use] 132 pub struct Req<T>(pub T); 133 134 impl<T, S> FromRequest<S> for Req<T> 135 where 136 T: DeserializeOwned, 137 S: Send + Sync, 138 { 139 type Rejection = ApiError; 140 141 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> { 142 let (parts, body) = req.into_parts(); 143 let bytes = decompressed_strict_body(&parts.headers, body).await?; 144 Self::try_from(&bytes) 145 } 146 } 147 148 impl<T: DeserializeOwned> TryFrom<&Bytes> for Req<T> { 149 type Error = ApiError; 150 151 fn try_from(value: &Bytes) -> Result<Self, Self::Error> { 152 let mut de = serde_json::de::Deserializer::from_slice(value); 153 let parsed = serde_path_to_error::deserialize(&mut de)?; 154 Ok(Req(parsed)) 155 } 156 } 157 158 #[derive(Debug, Clone, Copy, Default)] 159 pub struct Path<T: DeserializeOwned + Send>(pub T); 160 161 impl<T: serde::de::DeserializeOwned + Send, S: Sync + Send> FromRequestParts<S> for Path<T> { 162 type Rejection = ApiError; 163 164 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> { 165 Ok(Self( 166 axum::extract::Path::from_request_parts(parts, &()).await?.0, 167 )) 168 } 169 } 170 171 #[derive(Debug, Clone, Copy, Default)] 172 pub struct Query<T>(pub T); 173 174 impl<T: serde::de::DeserializeOwned, S: Sync + Send> FromRequestParts<S> for Query<T> { 175 type Rejection = ApiError; 176 177 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> { 178 let query = parts.uri.query().unwrap_or_default(); 179 let deserializer = 180 serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes())); 181 let params = serde_path_to_error::deserialize(deserializer)?; 182 Ok(Query(params)) 183 } 184 }