taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

db.rs (12565B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024, 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{str::FromStr, time::Duration};
     18 
     19 use jiff::{
     20     Timestamp,
     21     civil::{Date, Time},
     22     tz::TimeZone,
     23 };
     24 use sqlx::{
     25     Decode, Error, PgPool, Postgres, QueryBuilder, Row, Type,
     26     error::BoxDynError,
     27     postgres::PgRow,
     28     query::{Query, QueryScalar},
     29 };
     30 use taler_common::{
     31     api::params::{History, Page, Pooling},
     32     types::{
     33         amount::{Amount, Currency, Decimal},
     34         iban::IBAN,
     35         payto::PaytoURI,
     36         timestamp::TalerTimestamp,
     37         utils::date_to_utc_ts,
     38     },
     39 };
     40 use tokio::sync::watch::{self};
     41 use url::Url;
     42 
     43 pub type PgQueryBuilder<'b> = QueryBuilder<'b, Postgres>;
     44 
     45 /* ------ Serialization ----- */
     46 
     47 pub trait PgError {
     48     const PG_SERIALIZATION_FAILURE: &str = "40001";
     49     const PG_DEADLOCK_DETECTED: &str = "40P01";
     50     const PG_UNIQUE_VIOLATION: &str = "23505";
     51     const PG_FOREIGN_KEY_VIOLATION: &str = "23503";
     52 
     53     fn is_retryable_err(&self) -> bool;
     54     fn is_unique_err(&self) -> bool;
     55     fn is_fk_err(&self) -> bool;
     56 }
     57 
     58 impl PgError for sqlx::error::Error {
     59     fn is_retryable_err(&self) -> bool {
     60         if let sqlx::Error::Database(e) = self {
     61             return matches!(
     62                 e.downcast_ref::<sqlx::postgres::PgDatabaseError>().code(),
     63                 Self::PG_SERIALIZATION_FAILURE | Self::PG_DEADLOCK_DETECTED
     64             );
     65         }
     66         false
     67     }
     68 
     69     fn is_unique_err(&self) -> bool {
     70         if let sqlx::Error::Database(e) = self {
     71             return e.downcast_ref::<sqlx::postgres::PgDatabaseError>().code()
     72                 == Self::PG_UNIQUE_VIOLATION;
     73         }
     74         false
     75     }
     76 
     77     fn is_fk_err(&self) -> bool {
     78         if let sqlx::Error::Database(e) = self {
     79             return e.downcast_ref::<sqlx::postgres::PgDatabaseError>().code()
     80                 == Self::PG_FOREIGN_KEY_VIOLATION;
     81         }
     82         false
     83     }
     84 }
     85 
     86 #[macro_export]
     87 macro_rules! serialized {
     88     ($logic:expr) => {{
     89         use $crate::db::PgError;
     90         let mut attempts = 0;
     91         const MAX_RETRIES: u32 = 5;
     92 
     93         loop {
     94             let res: sqlx::Result<_, sqlx::Error> = $logic.await;
     95             if let Err(e) = &res
     96                 && e.is_retryable_err()
     97                 && attempts < MAX_RETRIES
     98             {
     99                 attempts += 1;
    100                 tokio::task::yield_now().await;
    101                 continue;
    102             }
    103             break res;
    104         }
    105     }};
    106 }
    107 
    108 /* ----- Routines ------ */
    109 
    110 pub async fn page<'a, 'b, R: Send + Unpin>(
    111     db: &PgPool,
    112     params: &Page,
    113     id_col: &str,
    114     prepare: impl Fn() -> QueryBuilder<'a, Postgres> + Copy,
    115     map: impl Fn(PgRow) -> Result<R, Error> + Send + Copy,
    116 ) -> Result<Vec<R>, Error> {
    117     serialized!(async {
    118         let mut builder = prepare();
    119         if let Some(offset) = params.offset {
    120             builder
    121                 .push(format_args!(
    122                     " {id_col} {}",
    123                     if params.backward() { '<' } else { '>' }
    124                 ))
    125                 .push_bind(offset);
    126         } else {
    127             builder.push("TRUE");
    128         }
    129         builder.push(format_args!(
    130             " ORDER BY {id_col} {} LIMIT ",
    131             if params.backward() { "DESC" } else { "ASC" }
    132         ));
    133         builder
    134             .push_bind(params.limit.abs())
    135             .build()
    136             .try_map(map)
    137             .fetch_all(db)
    138             .await
    139     })
    140 }
    141 
    142 pub async fn pooling<R, N, F: Future<Output = sqlx::Result<R>>>(
    143     params: &Pooling,
    144     listen: impl FnOnce() -> watch::Receiver<N>,
    145     filter: impl FnMut(&N) -> bool,
    146     mut load: impl FnMut() -> F,
    147     mut check: impl FnMut(&R) -> bool,
    148 ) -> Result<R, Error> {
    149     let timeout = params.timeout_ms.unwrap_or_default();
    150     if timeout > 0 {
    151         let mut listener = listen();
    152         let init = load().await?;
    153         // Long polling if we found no transactions
    154         if !check(&init) {
    155             let pooling = tokio::time::timeout(Duration::from_millis(timeout), async {
    156                 listener.wait_for(filter).await.ok();
    157             })
    158             .await;
    159             match pooling {
    160                 Ok(_) => load().await,
    161                 Err(_) => Ok(init),
    162             }
    163         } else {
    164             Ok(init)
    165         }
    166     } else {
    167         load().await
    168     }
    169 }
    170 
    171 pub async fn history<T: Send + Unpin>(
    172     db: &PgPool,
    173     id_col: &str,
    174     params: &History,
    175     listen: impl FnOnce() -> watch::Receiver<i64>,
    176     prepare: impl Fn() -> QueryBuilder<'static, Postgres> + Copy,
    177     map: impl Fn(PgRow) -> Result<T, Error> + Send + Copy,
    178 ) -> Result<Vec<T>, Error> {
    179     let load = async || page(db, &params.page, id_col, prepare, map).await;
    180     // When going backward there is always at least one transaction or none
    181     let poll = if params.page.limit < 0 {
    182         &Pooling::default()
    183     } else {
    184         &params.pooling
    185     };
    186     pooling(
    187         poll,
    188         listen,
    189         |id| params.page.offset.is_none_or(|offset| *id > offset),
    190         load,
    191         |init| !init.is_empty(),
    192     )
    193     .await
    194 }
    195 
    196 /* ----- Bind ----- */
    197 
    198 pub trait BindHelper {
    199     fn bind_timestamp(self, timestamp: &Timestamp) -> Self;
    200     fn bind_date(self, date: &Date) -> Self;
    201 }
    202 
    203 impl<'q> BindHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Arguments<'q>> {
    204     fn bind_timestamp(self, timestamp: &Timestamp) -> Self {
    205         self.bind(timestamp.as_microsecond())
    206     }
    207 
    208     fn bind_date(self, date: &Date) -> Self {
    209         self.bind_timestamp(&date_to_utc_ts(date))
    210     }
    211 }
    212 
    213 impl<'q, T> BindHelper
    214     for QueryScalar<'q, Postgres, T, <Postgres as sqlx::Database>::Arguments<'q>>
    215 {
    216     fn bind_timestamp(self, timestamp: &Timestamp) -> Self {
    217         self.bind(timestamp.as_microsecond())
    218     }
    219 
    220     fn bind_date(self, date: &Date) -> Self {
    221         self.bind_timestamp(&date_to_utc_ts(date))
    222     }
    223 }
    224 
    225 /* ----- Get ----- */
    226 
    227 pub trait TypeHelper {
    228     fn try_get_map<
    229         'r,
    230         I: sqlx::ColumnIndex<Self>,
    231         T: Decode<'r, Postgres> + Type<Postgres>,
    232         E: Into<BoxDynError>,
    233         R,
    234         M: FnOnce(T) -> Result<R, E>,
    235     >(
    236         &'r self,
    237         index: I,
    238         map: M,
    239     ) -> sqlx::Result<R>;
    240     fn try_get_opt_map<
    241         'r,
    242         I: sqlx::ColumnIndex<Self>,
    243         T: Decode<'r, Postgres> + Type<Postgres>,
    244         E: Into<BoxDynError>,
    245         R,
    246         M: FnOnce(T) -> Result<R, E>,
    247     >(
    248         &'r self,
    249         index: I,
    250         map: M,
    251     ) -> sqlx::Result<Option<R>> {
    252         self.try_get_map(index, |it: Option<T>| it.map(map).transpose())
    253     }
    254     fn try_get_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    255         &self,
    256         index: I,
    257     ) -> sqlx::Result<T> {
    258         self.try_get_map(index, |s: &str| s.parse())
    259     }
    260     fn try_get_opt_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    261         &self,
    262         index: I,
    263     ) -> sqlx::Result<Option<T>> {
    264         self.try_get_map(index, |s: Option<&str>| s.map(|s| s.parse()).transpose())
    265     }
    266     fn try_get_timestamp<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Timestamp> {
    267         self.try_get_map(index, |micros| {
    268             jiff::Timestamp::from_microsecond(micros)
    269                 .map_err(|e| format!("expected timestamp micros got overflowing {micros}: {e}"))
    270         })
    271     }
    272     fn try_get_opt_timestamp<I: sqlx::ColumnIndex<Self>>(
    273         &self,
    274         index: I,
    275     ) -> sqlx::Result<Option<Timestamp>> {
    276         self.try_get_map(index, |micros: Option<i64>| {
    277             if let Some(micros) = micros {
    278                 Some(jiff::Timestamp::from_microsecond(micros).map_err(|e| {
    279                     format!("expected timestamp micros got overflowing {micros}: {e}")
    280                 }))
    281                 .transpose()
    282             } else {
    283                 Ok(None)
    284             }
    285         })
    286     }
    287     fn try_get_taler_timestamp<I: sqlx::ColumnIndex<Self>>(
    288         &self,
    289         index: I,
    290     ) -> sqlx::Result<TalerTimestamp> {
    291         self.try_get_map(index, |micros: Option<i64>| match micros {
    292             Some(micros) => Ok::<_, String>(TalerTimestamp::Timestamp(
    293                 jiff::Timestamp::from_microsecond(micros).map_err(|e| {
    294                     format!("expected timestamp micros got overflowing {micros}: {e}")
    295                 })?,
    296             )),
    297             None => Ok(TalerTimestamp::Never),
    298         })
    299     }
    300     fn try_get_date<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Date> {
    301         let timestamp = self.try_get_timestamp(index)?;
    302         let zoned = timestamp.to_zoned(TimeZone::UTC);
    303         assert_eq!(zoned.time(), Time::midnight());
    304         Ok(zoned.date())
    305     }
    306     fn try_get_u16<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u16> {
    307         self.try_get_map(index, |signed: i16| signed.try_into())
    308     }
    309     fn try_get_opt_u16<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Option<u16>> {
    310         self.try_get_opt_map(index, |signed: i16| signed.try_into())
    311     }
    312     fn try_get_u32<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u32> {
    313         self.try_get_map(index, |signed: i32| signed.try_into())
    314     }
    315     fn try_get_opt_u32<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Option<u32>> {
    316         self.try_get_opt_map(index, |signed: i32| signed.try_into())
    317     }
    318     fn try_get_u64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u64> {
    319         self.try_get_map(index, |signed: i64| signed.try_into())
    320     }
    321     fn try_get_opt_u64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Option<u64>> {
    322         self.try_get_opt_map(index, |signed: i64| signed.try_into())
    323     }
    324     fn try_get_url<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Url> {
    325         self.try_get_parse(index)
    326     }
    327     fn try_get_payto<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<PaytoURI> {
    328         self.try_get_parse(index)
    329     }
    330     fn try_get_opt_payto<I: sqlx::ColumnIndex<Self>>(
    331         &self,
    332         index: I,
    333     ) -> sqlx::Result<Option<PaytoURI>> {
    334         self.try_get_opt_parse(index)
    335     }
    336     fn try_get_iban<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<IBAN> {
    337         self.try_get_parse(index)
    338     }
    339     fn try_get_amount<I: sqlx::ColumnIndex<Self>>(
    340         &self,
    341         index: I,
    342         currency: &Currency,
    343     ) -> sqlx::Result<Amount>;
    344     fn try_get_opt_amount<I: sqlx::ColumnIndex<Self>>(
    345         &self,
    346         index: I,
    347         currency: &Currency,
    348     ) -> sqlx::Result<Option<Amount>>;
    349 
    350     /** Flag consider NULL and false to be the same */
    351     fn try_get_flag<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<bool>;
    352 }
    353 
    354 impl TypeHelper for PgRow {
    355     fn try_get_map<
    356         'r,
    357         I: sqlx::ColumnIndex<Self>,
    358         T: Decode<'r, Postgres> + Type<Postgres>,
    359         E: Into<BoxDynError>,
    360         R,
    361         M: FnOnce(T) -> Result<R, E>,
    362     >(
    363         &'r self,
    364         index: I,
    365         map: M,
    366     ) -> sqlx::Result<R> {
    367         let primitive: T = self.try_get(&index)?;
    368         map(primitive).map_err(|source| sqlx::Error::ColumnDecode {
    369             index: format!("{index:?}"),
    370             source: source.into(),
    371         })
    372     }
    373 
    374     fn try_get_amount<I: sqlx::ColumnIndex<Self>>(
    375         &self,
    376         index: I,
    377         currency: &Currency,
    378     ) -> sqlx::Result<Amount> {
    379         let decimal: Decimal = self.try_get(index)?;
    380         Ok(Amount::new_decimal(currency, decimal))
    381     }
    382 
    383     fn try_get_opt_amount<I: sqlx::ColumnIndex<Self>>(
    384         &self,
    385         index: I,
    386         currency: &Currency,
    387     ) -> sqlx::Result<Option<Amount>> {
    388         let decimal: Option<Decimal> = self.try_get(index)?;
    389         Ok(decimal.map(|decimal| Amount::new_decimal(currency, decimal)))
    390     }
    391 
    392     fn try_get_flag<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<bool> {
    393         let opt_bool: Option<bool> = self.try_get(index)?;
    394         Ok(opt_bool.unwrap_or(false))
    395     }
    396 }