db.rs (12565B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2024, 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{str::FromStr, time::Duration}; 18 19 use jiff::{ 20 Timestamp, 21 civil::{Date, Time}, 22 tz::TimeZone, 23 }; 24 use sqlx::{ 25 Decode, Error, PgPool, Postgres, QueryBuilder, Row, Type, 26 error::BoxDynError, 27 postgres::PgRow, 28 query::{Query, QueryScalar}, 29 }; 30 use taler_common::{ 31 api::params::{History, Page, Pooling}, 32 types::{ 33 amount::{Amount, Currency, Decimal}, 34 iban::IBAN, 35 payto::PaytoURI, 36 timestamp::TalerTimestamp, 37 utils::date_to_utc_ts, 38 }, 39 }; 40 use tokio::sync::watch::{self}; 41 use url::Url; 42 43 pub type PgQueryBuilder<'b> = QueryBuilder<'b, Postgres>; 44 45 /* ------ Serialization ----- */ 46 47 pub trait PgError { 48 const PG_SERIALIZATION_FAILURE: &str = "40001"; 49 const PG_DEADLOCK_DETECTED: &str = "40P01"; 50 const PG_UNIQUE_VIOLATION: &str = "23505"; 51 const PG_FOREIGN_KEY_VIOLATION: &str = "23503"; 52 53 fn is_retryable_err(&self) -> bool; 54 fn is_unique_err(&self) -> bool; 55 fn is_fk_err(&self) -> bool; 56 } 57 58 impl PgError for sqlx::error::Error { 59 fn is_retryable_err(&self) -> bool { 60 if let sqlx::Error::Database(e) = self { 61 return matches!( 62 e.downcast_ref::<sqlx::postgres::PgDatabaseError>().code(), 63 Self::PG_SERIALIZATION_FAILURE | Self::PG_DEADLOCK_DETECTED 64 ); 65 } 66 false 67 } 68 69 fn is_unique_err(&self) -> bool { 70 if let sqlx::Error::Database(e) = self { 71 return e.downcast_ref::<sqlx::postgres::PgDatabaseError>().code() 72 == Self::PG_UNIQUE_VIOLATION; 73 } 74 false 75 } 76 77 fn is_fk_err(&self) -> bool { 78 if let sqlx::Error::Database(e) = self { 79 return e.downcast_ref::<sqlx::postgres::PgDatabaseError>().code() 80 == Self::PG_FOREIGN_KEY_VIOLATION; 81 } 82 false 83 } 84 } 85 86 #[macro_export] 87 macro_rules! serialized { 88 ($logic:expr) => {{ 89 use $crate::db::PgError; 90 let mut attempts = 0; 91 const MAX_RETRIES: u32 = 5; 92 93 loop { 94 let res: sqlx::Result<_, sqlx::Error> = $logic.await; 95 if let Err(e) = &res 96 && e.is_retryable_err() 97 && attempts < MAX_RETRIES 98 { 99 attempts += 1; 100 tokio::task::yield_now().await; 101 continue; 102 } 103 break res; 104 } 105 }}; 106 } 107 108 /* ----- Routines ------ */ 109 110 pub async fn page<'a, 'b, R: Send + Unpin>( 111 db: &PgPool, 112 params: &Page, 113 id_col: &str, 114 prepare: impl Fn() -> QueryBuilder<'a, Postgres> + Copy, 115 map: impl Fn(PgRow) -> Result<R, Error> + Send + Copy, 116 ) -> Result<Vec<R>, Error> { 117 serialized!(async { 118 let mut builder = prepare(); 119 if let Some(offset) = params.offset { 120 builder 121 .push(format_args!( 122 " {id_col} {}", 123 if params.backward() { '<' } else { '>' } 124 )) 125 .push_bind(offset); 126 } else { 127 builder.push("TRUE"); 128 } 129 builder.push(format_args!( 130 " ORDER BY {id_col} {} LIMIT ", 131 if params.backward() { "DESC" } else { "ASC" } 132 )); 133 builder 134 .push_bind(params.limit.abs()) 135 .build() 136 .try_map(map) 137 .fetch_all(db) 138 .await 139 }) 140 } 141 142 pub async fn pooling<R, N, F: Future<Output = sqlx::Result<R>>>( 143 params: &Pooling, 144 listen: impl FnOnce() -> watch::Receiver<N>, 145 filter: impl FnMut(&N) -> bool, 146 mut load: impl FnMut() -> F, 147 mut check: impl FnMut(&R) -> bool, 148 ) -> Result<R, Error> { 149 let timeout = params.timeout_ms.unwrap_or_default(); 150 if timeout > 0 { 151 let mut listener = listen(); 152 let init = load().await?; 153 // Long polling if we found no transactions 154 if !check(&init) { 155 let pooling = tokio::time::timeout(Duration::from_millis(timeout), async { 156 listener.wait_for(filter).await.ok(); 157 }) 158 .await; 159 match pooling { 160 Ok(_) => load().await, 161 Err(_) => Ok(init), 162 } 163 } else { 164 Ok(init) 165 } 166 } else { 167 load().await 168 } 169 } 170 171 pub async fn history<T: Send + Unpin>( 172 db: &PgPool, 173 id_col: &str, 174 params: &History, 175 listen: impl FnOnce() -> watch::Receiver<i64>, 176 prepare: impl Fn() -> QueryBuilder<'static, Postgres> + Copy, 177 map: impl Fn(PgRow) -> Result<T, Error> + Send + Copy, 178 ) -> Result<Vec<T>, Error> { 179 let load = async || page(db, ¶ms.page, id_col, prepare, map).await; 180 // When going backward there is always at least one transaction or none 181 let poll = if params.page.limit < 0 { 182 &Pooling::default() 183 } else { 184 ¶ms.pooling 185 }; 186 pooling( 187 poll, 188 listen, 189 |id| params.page.offset.is_none_or(|offset| *id > offset), 190 load, 191 |init| !init.is_empty(), 192 ) 193 .await 194 } 195 196 /* ----- Bind ----- */ 197 198 pub trait BindHelper { 199 fn bind_timestamp(self, timestamp: &Timestamp) -> Self; 200 fn bind_date(self, date: &Date) -> Self; 201 } 202 203 impl<'q> BindHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Arguments<'q>> { 204 fn bind_timestamp(self, timestamp: &Timestamp) -> Self { 205 self.bind(timestamp.as_microsecond()) 206 } 207 208 fn bind_date(self, date: &Date) -> Self { 209 self.bind_timestamp(&date_to_utc_ts(date)) 210 } 211 } 212 213 impl<'q, T> BindHelper 214 for QueryScalar<'q, Postgres, T, <Postgres as sqlx::Database>::Arguments<'q>> 215 { 216 fn bind_timestamp(self, timestamp: &Timestamp) -> Self { 217 self.bind(timestamp.as_microsecond()) 218 } 219 220 fn bind_date(self, date: &Date) -> Self { 221 self.bind_timestamp(&date_to_utc_ts(date)) 222 } 223 } 224 225 /* ----- Get ----- */ 226 227 pub trait TypeHelper { 228 fn try_get_map< 229 'r, 230 I: sqlx::ColumnIndex<Self>, 231 T: Decode<'r, Postgres> + Type<Postgres>, 232 E: Into<BoxDynError>, 233 R, 234 M: FnOnce(T) -> Result<R, E>, 235 >( 236 &'r self, 237 index: I, 238 map: M, 239 ) -> sqlx::Result<R>; 240 fn try_get_opt_map< 241 'r, 242 I: sqlx::ColumnIndex<Self>, 243 T: Decode<'r, Postgres> + Type<Postgres>, 244 E: Into<BoxDynError>, 245 R, 246 M: FnOnce(T) -> Result<R, E>, 247 >( 248 &'r self, 249 index: I, 250 map: M, 251 ) -> sqlx::Result<Option<R>> { 252 self.try_get_map(index, |it: Option<T>| it.map(map).transpose()) 253 } 254 fn try_get_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>( 255 &self, 256 index: I, 257 ) -> sqlx::Result<T> { 258 self.try_get_map(index, |s: &str| s.parse()) 259 } 260 fn try_get_opt_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>( 261 &self, 262 index: I, 263 ) -> sqlx::Result<Option<T>> { 264 self.try_get_map(index, |s: Option<&str>| s.map(|s| s.parse()).transpose()) 265 } 266 fn try_get_timestamp<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Timestamp> { 267 self.try_get_map(index, |micros| { 268 jiff::Timestamp::from_microsecond(micros) 269 .map_err(|e| format!("expected timestamp micros got overflowing {micros}: {e}")) 270 }) 271 } 272 fn try_get_opt_timestamp<I: sqlx::ColumnIndex<Self>>( 273 &self, 274 index: I, 275 ) -> sqlx::Result<Option<Timestamp>> { 276 self.try_get_map(index, |micros: Option<i64>| { 277 if let Some(micros) = micros { 278 Some(jiff::Timestamp::from_microsecond(micros).map_err(|e| { 279 format!("expected timestamp micros got overflowing {micros}: {e}") 280 })) 281 .transpose() 282 } else { 283 Ok(None) 284 } 285 }) 286 } 287 fn try_get_taler_timestamp<I: sqlx::ColumnIndex<Self>>( 288 &self, 289 index: I, 290 ) -> sqlx::Result<TalerTimestamp> { 291 self.try_get_map(index, |micros: Option<i64>| match micros { 292 Some(micros) => Ok::<_, String>(TalerTimestamp::Timestamp( 293 jiff::Timestamp::from_microsecond(micros).map_err(|e| { 294 format!("expected timestamp micros got overflowing {micros}: {e}") 295 })?, 296 )), 297 None => Ok(TalerTimestamp::Never), 298 }) 299 } 300 fn try_get_date<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Date> { 301 let timestamp = self.try_get_timestamp(index)?; 302 let zoned = timestamp.to_zoned(TimeZone::UTC); 303 assert_eq!(zoned.time(), Time::midnight()); 304 Ok(zoned.date()) 305 } 306 fn try_get_u16<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u16> { 307 self.try_get_map(index, |signed: i16| signed.try_into()) 308 } 309 fn try_get_opt_u16<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Option<u16>> { 310 self.try_get_opt_map(index, |signed: i16| signed.try_into()) 311 } 312 fn try_get_u32<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u32> { 313 self.try_get_map(index, |signed: i32| signed.try_into()) 314 } 315 fn try_get_opt_u32<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Option<u32>> { 316 self.try_get_opt_map(index, |signed: i32| signed.try_into()) 317 } 318 fn try_get_u64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u64> { 319 self.try_get_map(index, |signed: i64| signed.try_into()) 320 } 321 fn try_get_opt_u64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Option<u64>> { 322 self.try_get_opt_map(index, |signed: i64| signed.try_into()) 323 } 324 fn try_get_url<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Url> { 325 self.try_get_parse(index) 326 } 327 fn try_get_payto<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<PaytoURI> { 328 self.try_get_parse(index) 329 } 330 fn try_get_opt_payto<I: sqlx::ColumnIndex<Self>>( 331 &self, 332 index: I, 333 ) -> sqlx::Result<Option<PaytoURI>> { 334 self.try_get_opt_parse(index) 335 } 336 fn try_get_iban<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<IBAN> { 337 self.try_get_parse(index) 338 } 339 fn try_get_amount<I: sqlx::ColumnIndex<Self>>( 340 &self, 341 index: I, 342 currency: &Currency, 343 ) -> sqlx::Result<Amount>; 344 fn try_get_opt_amount<I: sqlx::ColumnIndex<Self>>( 345 &self, 346 index: I, 347 currency: &Currency, 348 ) -> sqlx::Result<Option<Amount>>; 349 350 /** Flag consider NULL and false to be the same */ 351 fn try_get_flag<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<bool>; 352 } 353 354 impl TypeHelper for PgRow { 355 fn try_get_map< 356 'r, 357 I: sqlx::ColumnIndex<Self>, 358 T: Decode<'r, Postgres> + Type<Postgres>, 359 E: Into<BoxDynError>, 360 R, 361 M: FnOnce(T) -> Result<R, E>, 362 >( 363 &'r self, 364 index: I, 365 map: M, 366 ) -> sqlx::Result<R> { 367 let primitive: T = self.try_get(&index)?; 368 map(primitive).map_err(|source| sqlx::Error::ColumnDecode { 369 index: format!("{index:?}"), 370 source: source.into(), 371 }) 372 } 373 374 fn try_get_amount<I: sqlx::ColumnIndex<Self>>( 375 &self, 376 index: I, 377 currency: &Currency, 378 ) -> sqlx::Result<Amount> { 379 let decimal: Decimal = self.try_get(index)?; 380 Ok(Amount::new_decimal(currency, decimal)) 381 } 382 383 fn try_get_opt_amount<I: sqlx::ColumnIndex<Self>>( 384 &self, 385 index: I, 386 currency: &Currency, 387 ) -> sqlx::Result<Option<Amount>> { 388 let decimal: Option<Decimal> = self.try_get(index)?; 389 Ok(decimal.map(|decimal| Amount::new_decimal(currency, decimal))) 390 } 391 392 fn try_get_flag<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<bool> { 393 let opt_bool: Option<bool> = self.try_get(index)?; 394 Ok(opt_bool.unwrap_or(false)) 395 } 396 }