taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

db.rs (9034B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024, 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{str::FromStr, time::Duration};
     18 
     19 use jiff::{
     20     Timestamp,
     21     civil::{Date, Time},
     22     tz::TimeZone,
     23 };
     24 use sqlx::{
     25     Decode, Error, PgPool, QueryBuilder, Type, error::BoxDynError, postgres::PgRow, query::Query,
     26 };
     27 use sqlx::{Postgres, Row};
     28 use taler_common::{
     29     api_common::SafeU64,
     30     api_params::{History, Page},
     31     types::{
     32         amount::{Amount, Currency, Decimal},
     33         iban::IBAN,
     34         payto::PaytoURI,
     35         utils::date_to_utc_ts,
     36     },
     37 };
     38 use tokio::sync::watch::Receiver;
     39 use url::Url;
     40 
     41 pub type PgQueryBuilder<'b> = QueryBuilder<'b, Postgres>;
     42 /* ------ Serialization ----- */
     43 
     44 #[macro_export]
     45 macro_rules! serialized {
     46     ($logic:expr) => {{
     47         let mut attempts = 0;
     48         const MAX_RETRIES: u32 = 5;
     49         /// PostgreSQL serialization failure error code (40001)
     50         const PG_SERIALIZATION_FAILURE: &str = "40001";
     51         /// PostgreSQL deadlock detected error code (40P01)
     52         const PG_DEADLOCK_DETECTED: &str = "40P01";
     53 
     54         loop {
     55             match $logic.await {
     56                 Ok(res) => break Ok(res),
     57                 Err(sqlx::Error::Database(e))
     58                     if matches!(
     59                         e.downcast_ref::<sqlx::postgres::PgDatabaseError>().code(),
     60                         PG_SERIALIZATION_FAILURE | PG_DEADLOCK_DETECTED
     61                     ) && attempts < MAX_RETRIES =>
     62                 {
     63                     attempts += 1;
     64                     tokio::task::yield_now().await;
     65                     continue;
     66                 }
     67                 Err(e) => break Err(e),
     68             }
     69         }
     70     }};
     71 }
     72 
     73 /* ----- Routines ------ */
     74 
     75 pub async fn page<'a, 'b, R: Send + Unpin>(
     76     db: &PgPool,
     77     id_col: &str,
     78     params: &Page,
     79     prepare: impl Fn() -> QueryBuilder<'a, Postgres> + Copy,
     80     map: impl Fn(PgRow) -> Result<R, Error> + Send + Copy,
     81 ) -> Result<Vec<R>, Error> {
     82     serialized!(async {
     83         let mut builder = prepare();
     84         if let Some(offset) = params.offset {
     85             builder
     86                 .push(format_args!(
     87                     " {id_col} {}",
     88                     if params.backward() { '<' } else { '>' }
     89                 ))
     90                 .push_bind(offset);
     91         } else {
     92             builder.push("TRUE");
     93         }
     94         builder.push(format_args!(
     95             " ORDER BY {id_col} {} LIMIT ",
     96             if params.backward() { "DESC" } else { "ASC" }
     97         ));
     98         builder
     99             .push_bind(params.limit.abs())
    100             .build()
    101             .try_map(map)
    102             .fetch_all(db)
    103             .await
    104     })
    105 }
    106 
    107 pub async fn history<'a, 'b, R: Send + Unpin>(
    108     pool: &PgPool,
    109     id_col: &str,
    110     params: &History,
    111     listen: impl FnOnce() -> Receiver<i64>,
    112     prepare: impl Fn() -> QueryBuilder<'a, Postgres> + Copy,
    113     map: impl Fn(PgRow) -> Result<R, Error> + Send + Copy,
    114 ) -> Result<Vec<R>, Error> {
    115     let load = async || page(pool, id_col, &params.page, prepare, map).await;
    116 
    117     // When going backward there is always at least one transaction or none
    118     if params.page.limit >= 0 && params.timeout_ms.is_some_and(|it| it > 0) {
    119         let mut listener = listen();
    120         let init = load().await?;
    121         // Long polling if we found no transactions
    122         if init.is_empty() {
    123             let pooling = tokio::time::timeout(
    124                 Duration::from_millis(params.timeout_ms.unwrap_or(0)),
    125                 async {
    126                     listener
    127                         .wait_for(|id| params.page.offset.is_none_or(|offset| *id > offset))
    128                         .await
    129                         .ok();
    130                 },
    131             )
    132             .await;
    133             match pooling {
    134                 Ok(_) => load().await,
    135                 Err(_) => Ok(init),
    136             }
    137         } else {
    138             Ok(init)
    139         }
    140     } else {
    141         load().await
    142     }
    143 }
    144 
    145 /* ----- Bind ----- */
    146 
    147 pub trait BindHelper {
    148     fn bind_timestamp(self, timestamp: &Timestamp) -> Self;
    149     fn bind_date(self, date: &Date) -> Self;
    150 }
    151 
    152 impl<'q> BindHelper for Query<'q, Postgres, <Postgres as sqlx::Database>::Arguments<'q>> {
    153     fn bind_timestamp(self, timestamp: &Timestamp) -> Self {
    154         self.bind(timestamp.as_microsecond())
    155     }
    156 
    157     fn bind_date(self, date: &Date) -> Self {
    158         self.bind_timestamp(&date_to_utc_ts(date))
    159     }
    160 }
    161 
    162 /* ----- Get ----- */
    163 
    164 pub trait TypeHelper {
    165     fn try_get_map<
    166         'r,
    167         I: sqlx::ColumnIndex<Self>,
    168         T: Decode<'r, Postgres> + Type<Postgres>,
    169         E: Into<BoxDynError>,
    170         R,
    171         M: FnOnce(T) -> Result<R, E>,
    172     >(
    173         &'r self,
    174         index: I,
    175         map: M,
    176     ) -> sqlx::Result<R>;
    177     fn try_get_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    178         &self,
    179         index: I,
    180     ) -> sqlx::Result<T>;
    181     fn try_get_opt_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    182         &self,
    183         index: I,
    184     ) -> sqlx::Result<Option<T>>;
    185     fn try_get_timestamp<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Timestamp> {
    186         self.try_get_map(index, |micros| {
    187             jiff::Timestamp::from_microsecond(micros)
    188                 .map_err(|e| format!("expected timestamp micros got overflowing {micros}: {e}"))
    189         })
    190     }
    191     fn try_get_date<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Date> {
    192         let timestamp = self.try_get_timestamp(index)?;
    193         let zoned = timestamp.to_zoned(TimeZone::UTC);
    194         assert_eq!(zoned.time(), Time::midnight());
    195         Ok(zoned.date())
    196     }
    197     fn try_get_u32<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u32> {
    198         self.try_get_map(index, |signed: i32| signed.try_into())
    199     }
    200     fn try_get_u64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<u64> {
    201         self.try_get_map(index, |signed: i64| signed.try_into())
    202     }
    203     fn try_get_safeu64<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<SafeU64> {
    204         self.try_get_map(index, |signed: i64| SafeU64::try_from(signed))
    205     }
    206     fn try_get_url<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<Url> {
    207         self.try_get_parse(index)
    208     }
    209     fn try_get_payto<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<PaytoURI> {
    210         self.try_get_parse(index)
    211     }
    212     fn try_get_opt_payto<I: sqlx::ColumnIndex<Self>>(
    213         &self,
    214         index: I,
    215     ) -> sqlx::Result<Option<PaytoURI>> {
    216         self.try_get_opt_parse(index)
    217     }
    218     fn try_get_iban<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<IBAN> {
    219         self.try_get_parse(index)
    220     }
    221     fn try_get_amount<I: sqlx::ColumnIndex<Self>>(
    222         &self,
    223         index: I,
    224         currency: &Currency,
    225     ) -> sqlx::Result<Amount>;
    226 
    227     /** Flag consider NULL and false to be the same */
    228     fn try_get_flag<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<bool>;
    229 }
    230 
    231 impl TypeHelper for PgRow {
    232     fn try_get_map<
    233         'r,
    234         I: sqlx::ColumnIndex<Self>,
    235         T: Decode<'r, Postgres> + Type<Postgres>,
    236         E: Into<BoxDynError>,
    237         R,
    238         M: FnOnce(T) -> Result<R, E>,
    239     >(
    240         &'r self,
    241         index: I,
    242         map: M,
    243     ) -> sqlx::Result<R> {
    244         let primitive: T = self.try_get(&index)?;
    245         map(primitive).map_err(|source| sqlx::Error::ColumnDecode {
    246             index: format!("{index:?}"),
    247             source: source.into(),
    248         })
    249     }
    250 
    251     fn try_get_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    252         &self,
    253         index: I,
    254     ) -> sqlx::Result<T> {
    255         self.try_get_map(index, |s: &str| s.parse())
    256     }
    257 
    258     fn try_get_opt_parse<I: sqlx::ColumnIndex<Self>, E: Into<BoxDynError>, T: FromStr<Err = E>>(
    259         &self,
    260         index: I,
    261     ) -> sqlx::Result<Option<T>> {
    262         self.try_get_map(index, |s: Option<&str>| s.map(|s| s.parse()).transpose())
    263     }
    264 
    265     fn try_get_amount<I: sqlx::ColumnIndex<Self>>(
    266         &self,
    267         index: I,
    268         currency: &Currency,
    269     ) -> sqlx::Result<Amount> {
    270         let decimal: Decimal = self.try_get(index)?;
    271         Ok(Amount::new_decimal(currency, decimal))
    272     }
    273 
    274     fn try_get_flag<I: sqlx::ColumnIndex<Self>>(&self, index: I) -> sqlx::Result<bool> {
    275         let opt_bool: Option<bool> = self.try_get(index)?;
    276         Ok(opt_bool.unwrap_or(false))
    277     }
    278 }