extract.rs (6996B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use axum::{ 18 body::{Body, Bytes, HttpBody}, 19 extract::{FromRequest, FromRequestParts, Request}, 20 http::{ 21 HeaderMap, StatusCode, 22 header::{self, CONTENT_TYPE}, 23 request::Parts, 24 }, 25 }; 26 use http_body_util::BodyExt as _; 27 use serde::de::DeserializeOwned; 28 use taler_common::error_code::ErrorCode; 29 use tracing::trace; 30 use url::form_urlencoded; 31 use zlib_rs::{InflateConfig, ReturnCode}; 32 33 use crate::{ 34 constants::MAX_BODY_LENGTH, 35 error::{ApiError, ApiResult, failure, failure_status}, 36 }; 37 38 pub async fn decompressed_strict_body(headers: &HeaderMap, body: Body) -> ApiResult<Bytes> { 39 // Check content type 40 match headers.get(header::CONTENT_TYPE) { 41 Some(header) => { 42 if !header.as_bytes().starts_with(b"application/json") { 43 return Err(failure_status( 44 ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, 45 "Bad Content-Type header", 46 StatusCode::UNSUPPORTED_MEDIA_TYPE, 47 )); 48 } 49 } 50 None => { 51 return Err(failure_status( 52 ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, 53 "Missing Content-Type header", 54 StatusCode::UNSUPPORTED_MEDIA_TYPE, 55 )); 56 } 57 } 58 59 // Check content length if present and well formed 60 if let Some(length) = headers 61 .get(header::CONTENT_LENGTH) 62 .and_then(|it| it.to_str().ok()) 63 .and_then(|it| it.parse::<usize>().ok()) 64 && length > MAX_BODY_LENGTH 65 { 66 return Err(failure( 67 ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, 68 format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"), 69 )); 70 } 71 72 // Check compression 73 let compressed = if let Some(encoding) = headers.get(header::CONTENT_ENCODING) { 74 if encoding == "deflate" { 75 true 76 } else { 77 return Err(failure_status( 78 ErrorCode::GENERIC_HTTP_HEADERS_MALFORMED, 79 format!( 80 "Unsupported encoding '{}'", 81 String::from_utf8_lossy(encoding.as_bytes()) 82 ), 83 StatusCode::UNSUPPORTED_MEDIA_TYPE, 84 )); 85 } 86 } else { 87 false 88 }; 89 90 // Buffer body 91 let body = http_body_util::Limited::new(body, MAX_BODY_LENGTH); 92 let bytes = match body.collect().await { 93 Ok(chunks) => chunks.to_bytes(), 94 Err(it) => match it.downcast::<http_body_util::LengthLimitError>() { 95 Ok(_) => { 96 return Err(failure( 97 ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, 98 format!("Body is suspiciously big > {MAX_BODY_LENGTH}B"), 99 )); 100 } 101 Err(err) => { 102 return Err(failure( 103 ErrorCode::GENERIC_UNEXPECTED_REQUEST_ERROR, 104 format!("Failed to read body: {err}"), 105 )); 106 } 107 }, 108 }; 109 110 let bytes = if compressed { 111 let mut buf = [0; MAX_BODY_LENGTH]; 112 let (decompressed, code) = 113 zlib_rs::decompress_slice(&mut buf, &bytes, InflateConfig::default()); 114 match code { 115 ReturnCode::Ok => Bytes::copy_from_slice(decompressed), 116 ReturnCode::BufError => { 117 return Err(failure( 118 ErrorCode::GENERIC_UPLOAD_EXCEEDS_LIMIT, 119 format!("Decompressed body is suspiciously big > {MAX_BODY_LENGTH}B"), 120 )); 121 } 122 _ => { 123 return Err(failure( 124 ErrorCode::GENERIC_COMPRESSION_INVALID, 125 "Failed to decompress body: invalid compression", 126 )); 127 } 128 } 129 } else { 130 bytes 131 }; 132 trace!(target: "api", "req {}", String::from_utf8_lossy(&bytes)); 133 Ok(bytes) 134 } 135 136 #[derive(Debug, Clone, Copy, Default)] 137 #[must_use] 138 pub struct Req<T>(pub T); 139 140 impl<T, S> FromRequest<S> for Req<T> 141 where 142 T: DeserializeOwned, 143 S: Send + Sync, 144 { 145 type Rejection = ApiError; 146 147 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> { 148 let (parts, body) = req.into_parts(); 149 let bytes = decompressed_strict_body(&parts.headers, body).await?; 150 Self::try_from(&bytes) 151 } 152 } 153 154 impl<T: DeserializeOwned> TryFrom<&Bytes> for Req<T> { 155 type Error = ApiError; 156 157 fn try_from(value: &Bytes) -> Result<Self, Self::Error> { 158 let mut de = serde_json::de::Deserializer::from_slice(value); 159 let parsed = serde_path_to_error::deserialize(&mut de)?; 160 Ok(Req(parsed)) 161 } 162 } 163 164 #[derive(Debug, Clone, Copy, Default)] 165 #[must_use] 166 pub struct OptReq<T>(pub Option<T>); 167 168 impl<T, S> FromRequest<S> for OptReq<T> 169 where 170 T: DeserializeOwned, 171 S: Send + Sync, 172 { 173 type Rejection = ApiError; 174 175 async fn from_request(req: Request, _state: &S) -> Result<Self, Self::Rejection> { 176 let (parts, body) = req.into_parts(); 177 if !parts.headers.contains_key(CONTENT_TYPE) && body.size_hint().exact() == Some(0) { 178 Ok(Self(None)) 179 } else { 180 let bytes = decompressed_strict_body(&parts.headers, body).await?; 181 let req = Req::try_from(&bytes)?; 182 Ok(Self(Some(req.0))) 183 } 184 } 185 } 186 187 #[derive(Debug, Clone, Copy, Default)] 188 pub struct Path<T: DeserializeOwned + Send>(pub T); 189 190 impl<T: serde::de::DeserializeOwned + Send, S: Sync + Send> FromRequestParts<S> for Path<T> { 191 type Rejection = ApiError; 192 193 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> { 194 Ok(Self( 195 axum::extract::Path::from_request_parts(parts, &()).await?.0, 196 )) 197 } 198 } 199 200 #[derive(Debug, Clone, Copy, Default)] 201 pub struct Query<T>(pub T); 202 203 impl<T: serde::de::DeserializeOwned, S: Sync + Send> FromRequestParts<S> for Query<T> { 204 type Rejection = ApiError; 205 206 async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> { 207 let query = parts.uri.query().unwrap_or_default(); 208 let deserializer = 209 serde_urlencoded::Deserializer::new(form_urlencoded::parse(query.as_bytes())); 210 let params = serde_path_to_error::deserialize(deserializer)?; 211 Ok(Query(params)) 212 } 213 }