taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

amount.rs (15366B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024, 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
//! Types for the Taler Amount (currency, fixed-point decimal value, and the amount itself) <https://docs.taler.net/core/api-common.html#tsref-type-Amount>
     18 
     19 use std::{
     20     fmt::{Debug, Display},
     21     num::ParseIntError,
     22     str::FromStr,
     23 };
     24 
     25 use super::utils::InlineStr;
     26 
     27 /** Number of characters we use to represent currency names */
     28 // We use the same value than the exchange -1 because we use a byte for the len instead of 0 termination
     29 pub const CURRENCY_LEN: usize = 11;
     30 
     31 /** Maximum legal value for an amount, based on IEEE double */
     32 pub const MAX_VALUE: u64 = 2 << 52;
     33 
     34 /** The number of digits in a fraction part of an amount */
     35 pub const FRAC_BASE_NB_DIGITS: u8 = 8;
     36 
     37 /** The fraction part of an amount represents which fraction of the value */
     38 pub const FRAC_BASE: u32 = 10u32.pow(FRAC_BASE_NB_DIGITS as u32);
     39 
#[derive(
    Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
)]
/// Inlined currency code, held in a fixed-capacity `InlineStr<CURRENCY_LEN>`.
///
/// The `FromStr` impl accepts 1 to `CURRENCY_LEN` ASCII uppercase letters. This covers
/// ISO 4217 codes and also non-ISO codes such as `LOCAL`.
/// Serde (de)serializes it through its `FromStr` / `Display` string form.
pub struct Currency(InlineStr<CURRENCY_LEN>);
     45 
     46 impl AsRef<str> for Currency {
     47     fn as_ref(&self) -> &str {
     48         self.0.as_ref()
     49     }
     50 }
     51 
/// Reason why a string was rejected as a [`Currency`] code.
#[derive(Debug, thiserror::Error)]
pub enum CurrencyErrorKind {
    /// A byte outside the ASCII range `A`-`Z` was found.
    #[error("contains illegal characters (only A-Z allowed)")]
    Invalid,
    /// The code is longer than [`CURRENCY_LEN`] bytes.
    #[error("too long (max {CURRENCY_LEN} chars)")]
    Big,
    /// The input was the empty string.
    #[error("is empty")]
    Empty,
}
     61 
/// Error returned by `Currency::from_str`, carrying the rejected input.
#[derive(Debug, thiserror::Error)]
#[error("currency code name '{currency}' {kind}")]
pub struct ParseCurrencyError {
    // The original input string that failed to parse.
    currency: String,
    /// Why the input was rejected.
    pub kind: CurrencyErrorKind,
}
     68 
     69 impl FromStr for Currency {
     70     type Err = ParseCurrencyError;
     71 
     72     fn from_str(s: &str) -> Result<Self, Self::Err> {
     73         let bytes = s.as_bytes();
     74         let len = bytes.len();
     75         if bytes.is_empty() {
     76             Err(CurrencyErrorKind::Empty)
     77         } else if len > CURRENCY_LEN {
     78             Err(CurrencyErrorKind::Big)
     79         } else if !bytes.iter().all(|c| c.is_ascii_uppercase()) {
     80             Err(CurrencyErrorKind::Invalid)
     81         } else {
     82             Ok(Self(InlineStr::copy_from_slice(bytes)))
     83         }
     84         .map_err(|kind| ParseCurrencyError {
     85             currency: s.to_owned(),
     86             kind,
     87         })
     88     }
     89 }
     90 
     91 impl Debug for Currency {
     92     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     93         Debug::fmt(&self.as_ref(), f)
     94     }
     95 }
     96 
     97 impl Display for Currency {
     98     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
     99         Display::fmt(&self.as_ref(), f)
    100     }
    101 }
    102 
/// Database mirror of the Postgres composite type `taler_amount`.
///
/// The fields are signed (`i64` / `i32`) to match the database representation.
/// The sqlx `Encode` / `Decode` impls for `Decimal` convert to and from its
/// unsigned fields.
#[derive(sqlx::Type)]
#[sqlx(type_name = "taler_amount")]
struct PgTalerAmount {
    pub val: i64,
    pub frac: i32,
}
    109 
/// Fixed-point decimal whose numeric value is `val + frac / FRAC_BASE`.
///
/// Values produced by parsing or by the checked arithmetic are normalized
/// (`frac < FRAC_BASE`, `val <= MAX_VALUE`). The public fields and `Decimal::new`
/// do not enforce this. The derived ordering compares `val` first and then `frac`,
/// which matches numeric order only for normalized values.
#[derive(
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    serde_with::DeserializeFromStr,
    serde_with::SerializeDisplay,
)]
pub struct Decimal {
    /** Integer part */
    pub val: u64,
    /** Fractional part, in units of 1 / FRAC_BASE */
    pub frac: u32,
}
    126 
    127 impl Decimal {
    128     pub fn new(val: u64, frac: u32) -> Self {
    129         Self { val, frac }
    130     }
    131 
    132     pub const fn max() -> Self {
    133         Self {
    134             val: MAX_VALUE,
    135             frac: FRAC_BASE - 1,
    136         }
    137     }
    138 
    139     pub const fn zero() -> Self {
    140         Self { val: 0, frac: 0 }
    141     }
    142 
    143     fn normalize(mut self) -> Option<Self> {
    144         self.val = self.val.checked_add((self.frac / FRAC_BASE) as u64)?;
    145         self.frac %= FRAC_BASE;
    146         if self.val > MAX_VALUE {
    147             return None;
    148         }
    149         Some(self)
    150     }
    151 
    152     pub fn try_add(mut self, rhs: &Self) -> Option<Self> {
    153         self.val = self.val.checked_add(rhs.val)?;
    154         self.frac = self
    155             .frac
    156             .checked_add(rhs.frac)
    157             .expect("amount fraction overflow should never happen with normalized amounts");
    158         self.normalize()
    159     }
    160 
    161     pub fn try_sub(mut self, rhs: &Self) -> Option<Self> {
    162         if rhs.frac > self.frac {
    163             self.val = self.val.checked_sub(1)?;
    164             self.frac += FRAC_BASE;
    165         }
    166         self.val = self.val.checked_sub(rhs.val)?;
    167         self.frac = self.frac.checked_sub(rhs.frac)?;
    168         self.normalize()
    169     }
    170 
    171     pub fn to_amount(self, currency: &Currency) -> Amount {
    172         Amount::new_decimal(currency, self)
    173     }
    174 }
    175 
/// Reason why a string was rejected as a [`Decimal`].
#[derive(Debug, thiserror::Error)]
pub enum DecimalErrKind {
    /// The integer part is greater than [`MAX_VALUE`].
    #[error("value overflow (must be <= {MAX_VALUE})")]
    Overflow,
    /// The integer part is not a valid unsigned integer.
    #[error("invalid value ({0})")]
    InvalidValue(ParseIntError),
    /// The fractional part is not a valid unsigned integer.
    #[error("invalid fraction ({0})")]
    InvalidFraction(ParseIntError),
    /// The fractional part has more than [`FRAC_BASE_NB_DIGITS`] digits.
    #[error("fraction overflow (max {FRAC_BASE_NB_DIGITS} digits)")]
    FractionOverflow,
}
    187 
/// Error returned by `Decimal::from_str`, carrying the rejected input.
#[derive(Debug, thiserror::Error)]
#[error("decimal '{decimal}' {kind}")]
pub struct ParseDecimalErr {
    // The original input string that failed to parse.
    decimal: String,
    /// Why the input was rejected.
    pub kind: DecimalErrKind,
}
    194 
    195 impl FromStr for Decimal {
    196     type Err = ParseDecimalErr;
    197 
    198     fn from_str(s: &str) -> Result<Self, Self::Err> {
    199         let (value, fraction) = s.split_once('.').unwrap_or((s, ""));
    200 
    201         // TODO use try block when stable
    202         (|| {
    203             let value: u64 = value.parse().map_err(DecimalErrKind::InvalidValue)?;
    204             if value > MAX_VALUE {
    205                 return Err(DecimalErrKind::Overflow);
    206             }
    207 
    208             if fraction.len() > FRAC_BASE_NB_DIGITS as usize {
    209                 return Err(DecimalErrKind::FractionOverflow);
    210             }
    211             let fraction: u32 = if fraction.is_empty() {
    212                 0
    213             } else {
    214                 fraction
    215                     .parse::<u32>()
    216                     .map_err(DecimalErrKind::InvalidFraction)?
    217                     * 10u32.pow(FRAC_BASE_NB_DIGITS as u32 - fraction.len() as u32)
    218             };
    219             Ok(Self {
    220                 val: value,
    221                 frac: fraction,
    222             })
    223         })()
    224         .map_err(|kind| ParseDecimalErr {
    225             decimal: s.to_owned(),
    226             kind,
    227         })
    228     }
    229 }
    230 
    231 impl Display for Decimal {
    232     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    233         if self.frac == 0 {
    234             f.write_fmt(format_args!("{}", self.val))
    235         } else {
    236             let num = format!("{:08}", self.frac);
    237             f.write_fmt(format_args!("{}.{}", self.val, num.trim_end_matches('0')))
    238         }
    239     }
    240 }
    241 
    242 impl Debug for Decimal {
    243     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    244         Display::fmt(&self, f)
    245     }
    246 }
    247 
    248 impl sqlx::Type<sqlx::Postgres> for Decimal {
    249     fn type_info() -> sqlx::postgres::PgTypeInfo {
    250         PgTalerAmount::type_info()
    251     }
    252 }
    253 
    254 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Decimal {
    255     fn encode_by_ref(
    256         &self,
    257         buf: &mut sqlx::postgres::PgArgumentBuffer,
    258     ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
    259         PgTalerAmount {
    260             val: self.val as i64,
    261             frac: self.frac as i32,
    262         }
    263         .encode_by_ref(buf)
    264     }
    265 }
    266 
    267 impl<'r> sqlx::Decode<'r, sqlx::Postgres> for Decimal {
    268     fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> {
    269         let pg = PgTalerAmount::decode(value)?;
    270         Ok(Self {
    271             val: pg.val as u64,
    272             frac: pg.frac as u32,
    273         })
    274     }
    275 }
    276 
    277 #[track_caller]
    278 pub fn decimal(decimal: impl AsRef<str>) -> Decimal {
    279     decimal.as_ref().parse().expect("Invalid decimal constant")
    280 }
    281 
/// <https://docs.taler.net/core/api-common.html#tsref-type-Amount>
///
/// A currency plus a fixed-point value `val + frac / FRAC_BASE`.
/// Constructors do not normalize, while the checked arithmetic does.
/// Serde (de)serializes it through its string form `CURRENCY:VALUE[.FRACTION]`.
#[derive(
    Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay,
)]
pub struct Amount {
    pub currency: Currency,
    /** Integer part */
    pub val: u64,
    /** Fractional part, in units of 1 / FRAC_BASE */
    pub frac: u32,
}
    291 
    292 impl Amount {
    293     pub fn new_decimal(currency: &Currency, decimal: Decimal) -> Self {
    294         (*currency, decimal).into()
    295     }
    296 
    297     pub fn new(currency: &Currency, val: u64, frac: u32) -> Self {
    298         Self::new_decimal(currency, Decimal { val, frac })
    299     }
    300 
    301     pub fn max(currency: &Currency) -> Self {
    302         Self::new_decimal(currency, Decimal::max())
    303     }
    304 
    305     pub fn zero(currency: &Currency) -> Self {
    306         Self::new_decimal(currency, Decimal::zero())
    307     }
    308 
    309     pub fn is_zero(&self) -> bool {
    310         self.decimal() == Decimal::zero()
    311     }
    312 
    313     pub const fn decimal(&self) -> Decimal {
    314         Decimal {
    315             val: self.val,
    316             frac: self.frac,
    317         }
    318     }
    319 
    320     pub fn normalize(self) -> Option<Self> {
    321         let decimal = self.decimal().normalize()?;
    322         Some((self.currency, decimal).into())
    323     }
    324 
    325     pub fn try_add(self, rhs: &Self) -> Option<Self> {
    326         assert_eq!(self.currency, rhs.currency);
    327         let decimal = self.decimal().try_add(&rhs.decimal())?.normalize()?;
    328         Some((self.currency, decimal).into())
    329     }
    330 
    331     pub fn try_sub(self, rhs: &Self) -> Option<Self> {
    332         assert_eq!(self.currency, rhs.currency);
    333         let decimal = self.decimal().try_sub(&rhs.decimal())?.normalize()?;
    334         Some((self.currency, decimal).into())
    335     }
    336 }
    337 
    338 impl From<(Currency, Decimal)> for Amount {
    339     fn from((currency, decimal): (Currency, Decimal)) -> Self {
    340         Self {
    341             currency,
    342             val: decimal.val,
    343             frac: decimal.frac,
    344         }
    345     }
    346 }
    347 
    348 #[track_caller]
    349 pub fn amount(amount: impl AsRef<str>) -> Amount {
    350     amount.as_ref().parse().expect("Invalid amount constant")
    351 }
    352 
/// Reason why a string was rejected as an [`Amount`].
#[derive(Debug, thiserror::Error)]
pub enum AmountErrKind {
    /// No `:` separates the currency from the value.
    #[error("invalid format")]
    Format,
    /// The currency part (before `:`) is invalid.
    #[error("currency {0}")]
    Currency(#[from] CurrencyErrorKind),
    /// The value part (after `:`) is invalid.
    #[error(transparent)]
    Decimal(#[from] DecimalErrKind),
}
    362 
/// Error returned by `Amount::from_str`, carrying the rejected input.
#[derive(Debug, thiserror::Error)]
#[error("amount '{amount}' {kind}")]
pub struct ParseAmountErr {
    // The original input string (untrimmed) that failed to parse.
    amount: String,
    /// Why the input was rejected.
    pub kind: AmountErrKind,
}
    369 
    370 impl FromStr for Amount {
    371     type Err = ParseAmountErr;
    372 
    373     fn from_str(s: &str) -> Result<Self, Self::Err> {
    374         // TODO use try block when stable
    375         (|| {
    376             let (currency, amount) = s.trim().split_once(':').ok_or(AmountErrKind::Format)?;
    377             let currency = currency.parse().map_err(|e: ParseCurrencyError| e.kind)?;
    378             let decimal = amount.parse().map_err(|e: ParseDecimalErr| e.kind)?;
    379             Ok((currency, decimal).into())
    380         })()
    381         .map_err(|kind| ParseAmountErr {
    382             amount: s.to_owned(),
    383             kind,
    384         })
    385     }
    386 }
    387 
    388 impl Display for Amount {
    389     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    390         f.write_fmt(format_args!("{}:{}", self.currency, self.decimal()))
    391     }
    392 }
    393 
    394 impl Debug for Amount {
    395     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    396         Display::fmt(&self, f)
    397     }
    398 }
    399 
    400 impl sqlx::Type<sqlx::Postgres> for Amount {
    401     fn type_info() -> sqlx::postgres::PgTypeInfo {
    402         PgTalerAmount::type_info()
    403     }
    404 }
    405 
    406 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Amount {
    407     fn encode_by_ref(
    408         &self,
    409         buf: &mut sqlx::postgres::PgArgumentBuffer,
    410     ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> {
    411         self.decimal().encode_by_ref(buf)
    412     }
    413 }
    414 
    415 #[test]
    416 fn test_amount_parse() {
    417     const TALER_AMOUNT_FRAC_BASE: u32 = 100000000;
    418     // https://git.taler.net/exchange.git/tree/src/util/test_amount.c
    419 
    420     const INVALID_AMOUNTS: [&str; 6] = [
    421         "EUR:4a",                                                                     // non-numeric,
    422         "EUR:4.4a",                                                                   // non-numeric
    423         "EUR:4.a4",                                                                   // non-numeric
    424         ":4.a4",                                                                      // no currency
    425         "EUR:4.123456789", // precision to high
    426         "EUR:1234567890123456789012345678901234567890123456789012345678901234567890", // value to big
    427     ];
    428 
    429     for str in INVALID_AMOUNTS {
    430         let amount = Amount::from_str(str);
    431         assert!(amount.is_err(), "invalid {} got {:?}", str, &amount);
    432     }
    433 
    434     let eur: Currency = "EUR".parse().unwrap();
    435     let local: Currency = "LOCAL".parse().unwrap();
    436     let valid_amounts: Vec<(&str, &str, Amount)> = vec![
    437         ("EUR:4", "EUR:4", Amount::new(&eur, 4, 0)), // without fraction
    438         (
    439             "EUR:0.02",
    440             "EUR:0.02",
    441             Amount::new(&eur, 0, TALER_AMOUNT_FRAC_BASE / 100 * 2),
    442         ), // leading zero fraction
    443         (
    444             " EUR:4.12",
    445             "EUR:4.12",
    446             Amount::new(&eur, 4, TALER_AMOUNT_FRAC_BASE / 100 * 12),
    447         ), // leading space and fraction
    448         (
    449             " LOCAL:4444.1000",
    450             "LOCAL:4444.1",
    451             Amount::new(&local, 4444, TALER_AMOUNT_FRAC_BASE / 10),
    452         ), // local currency
    453     ];
    454     for (raw, expected, goal) in valid_amounts {
    455         let amount = Amount::from_str(raw);
    456         assert!(amount.is_ok(), "Valid {} got {:?}", raw, amount);
    457         assert_eq!(
    458             *amount.as_ref().unwrap(),
    459             goal,
    460             "Expected {:?} got {:?} for {}",
    461             goal,
    462             amount,
    463             raw
    464         );
    465         let amount = amount.unwrap();
    466         let str = amount.to_string();
    467         assert_eq!(str, expected);
    468         assert_eq!(amount, Amount::from_str(&str).unwrap(), "{str}");
    469     }
    470 }
    471 
    472 #[test]
    473 fn test_amount_add() {
    474     let eur: Currency = "EUR".parse().unwrap();
    475     assert_eq!(
    476         Amount::max(&eur).try_add(&Amount::zero(&eur)),
    477         Some(Amount::max(&eur))
    478     );
    479     assert_eq!(
    480         Amount::zero(&eur).try_add(&Amount::zero(&eur)),
    481         Some(Amount::zero(&eur))
    482     );
    483     assert_eq!(
    484         amount("EUR:6.41").try_add(&amount("EUR:4.69")),
    485         Some(amount("EUR:11.1"))
    486     );
    487     assert_eq!(
    488         amount(format!("EUR:{MAX_VALUE}")).try_add(&amount("EUR:0.99999999")),
    489         Some(Amount::max(&eur))
    490     );
    491 
    492     assert_eq!(
    493         amount(format!("EUR:{}", MAX_VALUE - 5)).try_add(&amount("EUR:6")),
    494         None
    495     );
    496     assert_eq!(
    497         Amount::new(&eur, u64::MAX, 0).try_add(&amount("EUR:1")),
    498         None
    499     );
    500     assert_eq!(
    501         amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRAC_BASE - 1))
    502             .try_add(&amount("EUR:5.00000002")),
    503         None
    504     );
    505 }
    506 
    507 #[test]
    508 fn test_amount_normalize() {
    509     let eur: Currency = "EUR".parse().unwrap();
    510     assert_eq!(
    511         Amount::new(&eur, 4, 2 * FRAC_BASE).normalize(),
    512         Some(amount("EUR:6"))
    513     );
    514     assert_eq!(
    515         Amount::new(&eur, 4, 2 * FRAC_BASE + 1).normalize(),
    516         Some(amount("EUR:6.00000001"))
    517     );
    518     assert_eq!(
    519         Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1).normalize(),
    520         Some(Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1))
    521     );
    522     assert_eq!(Amount::new(&eur, u64::MAX, FRAC_BASE).normalize(), None);
    523     assert_eq!(Amount::new(&eur, MAX_VALUE, FRAC_BASE).normalize(), None);
    524 
    525     for amount in [Amount::max(&eur), Amount::zero(&eur)] {
    526         assert_eq!(amount.normalize(), Some(amount))
    527     }
    528 }