taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

db.rs (12388B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024, 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{str::FromStr, time::Duration};
     18 
     19 use jiff::{
     20     Timestamp,
     21     civil::{Date, Time},
     22     tz::TimeZone,
     23 };
     24 use sqlx::{
     25     Decode, Error, PgPool, Postgres, QueryBuilder, Row, Type,
     26     error::BoxDynError,
     27     postgres::PgRow,
     28     query::{Query, QueryScalar},
     29 };
     30 use taler_common::{
     31     api_common::SafeU64,
     32     api_params::{History, Page},
     33     types::{
     34         amount::{Amount, Currency, Decimal},
     35         iban::IBAN,
     36         payto::PaytoURI,
     37         timestamp::TalerTimestamp,
     38         utils::date_to_utc_ts,
     39     },
     40 };
     41 use tokio::sync::watch::Receiver;
     42 use url::Url;
     43 
     44 pub type PgQueryBuilder<'b> = QueryBuilder<'b, Postgres>;
     45 
     46 /* ------ Serialization ----- */
     47 
     48 pub trait PgError {
     49     const PG_SERIALIZATION_FAILURE: &str = "40001";
     50     const PG_DEADLOCK_DETECTED: &str = "40P01";
     51     const PG_UNIQUE_VIOLATION: &str = "23505";
     52     const PG_FOREIGN_KEY_VIOLATION: &str = "23503";
     53 
     54     fn is_retryable_err(&self) -> bool;
     55     fn is_unique_err(&self) -> bool;
     56     fn is_fk_err(&self) -> bool;
     57 }
     58 
     59 impl PgError for sqlx::error::Error {
     60     fn is_retryable_err(&self) -> bool {
     61         if let sqlx::Error::Database(e) = self {
     62             return matches!(
     63                 e.downcast_ref::<sqlx::postgres::PgDatabaseError>().code(),
     64                 Self::PG_SERIALIZATION_FAILURE | Self::PG_DEADLOCK_DETECTED
     65             );
     66         }
     67         false
     68     }
     69 
     70     fn is_unique_err(&self) -> bool {
     71         if let sqlx::Error::Database(e) = self {
     72             return e.downcast_ref::<sqlx::postgres::PgDatabaseError>().code()
     73                 == Self::PG_UNIQUE_VIOLATION;
     74         }
     75         false
     76     }
     77 
     78     fn is_fk_err(&self) -> bool {
     79         if let sqlx::Error::Database(e) = self {
     80             return e.downcast_ref::<sqlx::postgres::PgDatabaseError>().code()
     81                 == Self::PG_FOREIGN_KEY_VIOLATION;
     82         }
     83         false
     84     }
     85 }
     86 
     87 #[macro_export]
     88 macro_rules! serialized {
     89     ($logic:expr) => {{
     90         use $crate::db::PgError;
     91         let mut attempts = 0;
     92         const MAX_RETRIES: u32 = 5;
     93 
     94         loop {
     95             let res = $logic.await;
     96             if let sqlx::Result::Err(e) = &res
     97                 && e.is_retryable_err()
     98                 && attempts < MAX_RETRIES
     99             {
    100                 attempts += 1;
    101                 tokio::task::yield_now().await;
    102                 continue;
    103             }
    104             break res;
    105         }
    106     }};
    107 }
    108 
    109 /* ----- Routines ------ */
    110 
    111 pub async fn page<'a, 'b, R: Send + Unpin>(
    112     db: &PgPool,
    113     params: &Page,
    114     id_col: &str,
    115     prepare: impl Fn() -> QueryBuilder<'a, Postgres> + Copy,
    116     map: impl Fn(PgRow) -> Result<R, Error> + Send + Copy,
    117 ) -> Result<Vec<R>, Error> {
    118     serialized!(async {
    119         let mut builder = prepare();
    120         if let Some(offset) = params.offset {
    121             builder
    122                 .push(format_args!(
    123                     " {id_col} {}",
    124                     if params.backward() { '<' } else { '>' }
    125                 ))
    126                 .push_bind(offset);
    127         } else {
    128             builder.push("TRUE");
    129         }
    130         builder.push(format_args!(
    131             " ORDER BY {id_col} {} LIMIT ",
    132             if params.backward() { "DESC" } else { "ASC" }
    133         ));
    134         builder
    135             .push_bind(params.limit.abs())
    136             .build()
    137             .try_map(map)
    138             .fetch_all(db)
    139             .await
    140     })
    141 }
    142 
    143 pub async fn history<'a, 'b, R: Send + Unpin>(
    144     pool: &PgPool,
    145     id_col: &str,
    146     params: &History,
    147     listen: impl FnOnce() -> Receiver<i64>,
    148     prepare: impl Fn() -> QueryBuilder<'a, Postgres> + Copy,
    149     map: impl Fn(PgRow) -> Result<R, Error> + Send + Copy,
    150 ) -> Result<Vec<R>, Error> {
    151     let load = async || page(pool, &params.page, id_col, prepare, map).await;
    152 
    153     // When going backward there is always at least one transaction or none
    154     if params.page.limit >= 0 && params.timeout_ms.is_some_and(|it| it > 0) {
    155         let mut listener = listen();
    156         let init = load().await?;
    157         // Long polling if we found no transactions
    158         if init.is_empty() {
    159             let pooling = tokio::time::timeout(
    160                 Duration::from_millis(params.timeout_ms.unwrap_or(0)),
    161                 async {
    162                     listener
    163                         .wait_for(|id| params.page.offset.is_none_or(|offset| *id > offset))
    164                         .await
    165                         .ok();
    166                 },
    167             )
    168             .await;
    169             match pooling {
    170                 Ok(_) => load().await,
    171                 Err(_) => Ok(init),
    172             }
    173         } else {
    174             Ok(init)
    175         }
    176     } else {
    177         load().await
    178     }
    179 }
    180 
    181 /* ----- Bind ----- */
    182 
    183 pub trait BindHelper {
    184     fn bind_timestamp(self, timestamp: &Timestamp) -> Self;
    185     fn bind_date(self, date: &Date) -> Self;
    186 }
    187 
    188 impl<'q> BindHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Arguments<'q>> {
    189     fn bind_timestamp(self, timestamp: &Timestamp) -> Self {
    190         self.bind(timestamp.as_microsecond())
    191     }
    192 
    193     fn bind_date(self, date: &Date) -> Self {
    194         self.bind_timestamp(&date_to_utc_ts(date))
    195     }
    196 }
    197 
    198 impl<'q, T> BindHelper
    199     for QueryScalar<'q, Postgres, T, <Postgres as sqlx::Database>::Arguments<'q>>
    200 {
    201     fn bind_timestamp(self, timestamp: &Timestamp) -> Self {
    202         self.bind(timestamp.as_microsecond())
    203     }
    204 
    205     fn bind_date(self, date: &Date) -> Self {
    206         self.bind_timestamp(&date_to_utc_ts(date))
    207     }
    208 }
    209 
    210 /* ----- Get ----- */
    211 
    212 pub trait TypeHelper {
    213     fn try_get_map<
    214         'r,
    215         I: sqlx::ColumnIndex<Self>,
    216         T: Decode<'r, Postgres> + Type<Postgres>,
    217         E: Into<BoxDynError>,
    218         R,
    219         M: FnOnce(T) -> Result<R, E>,
    220     >(
    221         &'r self,
    222         index: I,
    223         map: M,
    224     ) -> sqlx::Result<R>;
    225     fn try_get_opt_map<
    226         'r,
    227         I: sqlx::ColumnIndex<Self>,
    228         T: Decode<'r, Postgres> + Type<Postgres>,
    229         E: Into<BoxDynError>,
    230         R,
    231         M: FnOnce(T) -> Result<R, E>,
    232     >(
    233         &'r self,
    234         index: I,
    235         map: M,
    236     ) -> sqlx::Result<Option<R>> {
    237         self.try_get_map(index, |it: Option<T>| it.map(map).transpose())
    238     }
    239     fn try_get_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    240         &self,
    241         index: I,
    242     ) -> sqlx::Result<T> {
    243         self.try_get_map(index, |s: &str| s.parse())
    244     }
    245     fn try_get_opt_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    246         &self,
    247         index: I,
    248     ) -> sqlx::Result<Option<T>> {
    249         self.try_get_map(index, |s: Option<&str>| s.map(|s| s.parse()).transpose())
    250     }
    251     fn try_get_timestamp<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Timestamp> {
    252         self.try_get_map(index, |micros| {
    253             jiff::Timestamp::from_microsecond(micros)
    254                 .map_err(|e| format!("expected timestamp micros got overflowing {micros}: {e}"))
    255         })
    256     }
    257     fn try_get_opt_timestamp<I: sqlx::ColumnIndex<Self>>(
    258         &self,
    259         index: I,
    260     ) -> sqlx::Result<Option<Timestamp>> {
    261         self.try_get_map(index, |micros: Option<i64>| {
    262             if let Some(micros) = micros {
    263                 Some(jiff::Timestamp::from_microsecond(micros).map_err(|e| {
    264                     format!("expected timestamp micros got overflowing {micros}: {e}")
    265                 }))
    266                 .transpose()
    267             } else {
    268                 Ok(None)
    269             }
    270         })
    271     }
    272     fn try_get_taler_timestamp<I: sqlx::ColumnIndex<Self>>(
    273         &self,
    274         index: I,
    275     ) -> sqlx::Result<TalerTimestamp> {
    276         self.try_get_map(index, |micros: Option<i64>| match micros {
    277             Some(micros) => Ok::<_, String>(TalerTimestamp::Timestamp(
    278                 jiff::Timestamp::from_microsecond(micros).map_err(|e| {
    279                     format!("expected timestamp micros got overflowing {micros}: {e}")
    280                 })?,
    281             )),
    282             None => Ok(TalerTimestamp::Never),
    283         })
    284     }
    285     fn try_get_date<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Date> {
    286         let timestamp = self.try_get_timestamp(index)?;
    287         let zoned = timestamp.to_zoned(TimeZone::UTC);
    288         assert_eq!(zoned.time(), Time::midnight());
    289         Ok(zoned.date())
    290     }
    291     fn try_get_u16<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u16> {
    292         self.try_get_map(index, |signed: i16| signed.try_into())
    293     }
    294     fn try_get_opt_u16<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Option<u16>> {
    295         self.try_get_opt_map(index, |signed: i16| signed.try_into())
    296     }
    297     fn try_get_u32<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u32> {
    298         self.try_get_map(index, |signed: i32| signed.try_into())
    299     }
    300     fn try_get_opt_u32<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Option<u32>> {
    301         self.try_get_opt_map(index, |signed: i32| signed.try_into())
    302     }
    303     fn try_get_u64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u64> {
    304         self.try_get_map(index, |signed: i64| signed.try_into())
    305     }
    306     fn try_get_opt_u64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Option<u64>> {
    307         self.try_get_opt_map(index, |signed: i64| signed.try_into())
    308     }
    309     fn try_get_safeu64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<SafeU64> {
    310         self.try_get_map(index, |signed: i64| SafeU64::try_from(signed))
    311     }
    312     fn try_get_url<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Url> {
    313         self.try_get_parse(index)
    314     }
    315     fn try_get_payto<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<PaytoURI> {
    316         self.try_get_parse(index)
    317     }
    318     fn try_get_opt_payto<I: sqlx::ColumnIndex<Self>>(
    319         &self,
    320         index: I,
    321     ) -> sqlx::Result<Option<PaytoURI>> {
    322         self.try_get_opt_parse(index)
    323     }
    324     fn try_get_iban<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<IBAN> {
    325         self.try_get_parse(index)
    326     }
    327     fn try_get_amount<I: sqlx::ColumnIndex<Self>>(
    328         &self,
    329         index: I,
    330         currency: &Currency,
    331     ) -> sqlx::Result<Amount>;
    332     fn try_get_opt_amount<I: sqlx::ColumnIndex<Self>>(
    333         &self,
    334         index: I,
    335         currency: &Currency,
    336     ) -> sqlx::Result<Option<Amount>>;
    337 
    338     /** Flag consider NULL and false to be the same */
    339     fn try_get_flag<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<bool>;
    340 }
    341 
    342 impl TypeHelper for PgRow {
    343     fn try_get_map<
    344         'r,
    345         I: sqlx::ColumnIndex<Self>,
    346         T: Decode<'r, Postgres> + Type<Postgres>,
    347         E: Into<BoxDynError>,
    348         R,
    349         M: FnOnce(T) -> Result<R, E>,
    350     >(
    351         &'r self,
    352         index: I,
    353         map: M,
    354     ) -> sqlx::Result<R> {
    355         let primitive: T = self.try_get(&index)?;
    356         map(primitive).map_err(|source| sqlx::Error::ColumnDecode {
    357             index: format!("{index:?}"),
    358             source: source.into(),
    359         })
    360     }
    361 
    362     fn try_get_amount<I: sqlx::ColumnIndex<Self>>(
    363         &self,
    364         index: I,
    365         currency: &Currency,
    366     ) -> sqlx::Result<Amount> {
    367         let decimal: Decimal = self.try_get(index)?;
    368         Ok(Amount::new_decimal(currency, decimal))
    369     }
    370 
    371     fn try_get_opt_amount<I: sqlx::ColumnIndex<Self>>(
    372         &self,
    373         index: I,
    374         currency: &Currency,
    375     ) -> sqlx::Result<Option<Amount>> {
    376         let decimal: Option<Decimal> = self.try_get(index)?;
    377         Ok(decimal.map(|decimal| Amount::new_decimal(currency, decimal)))
    378     }
    379 
    380     fn try_get_flag<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<bool> {
    381         let opt_bool: Option<bool> = self.try_get(index)?;
    382         Ok(opt_bool.unwrap_or(false))
    383     }
    384 }