amount.rs (15366B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2024, 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 //! Type for the Taler Amount <https://docs.taler.net/core/api-common.html#tsref-type-Amount> 18 19 use std::{ 20 fmt::{Debug, Display}, 21 num::ParseIntError, 22 str::FromStr, 23 }; 24 25 use super::utils::InlineStr; 26 27 /** Number of characters we use to represent currency names */ 28 // We use the same value than the exchange -1 because we use a byte for the len instead of 0 termination 29 pub const CURRENCY_LEN: usize = 11; 30 31 /** Maximum legal value for an amount, based on IEEE double */ 32 pub const MAX_VALUE: u64 = 2 << 52; 33 34 /** The number of digits in a fraction part of an amount */ 35 pub const FRAC_BASE_NB_DIGITS: u8 = 8; 36 37 /** The fraction part of an amount represents which fraction of the value */ 38 pub const FRAC_BASE: u32 = 10u32.pow(FRAC_BASE_NB_DIGITS as u32); 39 40 #[derive( 41 Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay, 42 )] 43 /// Inlined ISO 4217 currency string 44 pub struct Currency(InlineStr<CURRENCY_LEN>); 45 46 impl AsRef<str> for Currency { 47 fn as_ref(&self) -> &str { 48 self.0.as_ref() 49 } 50 } 51 52 #[derive(Debug, thiserror::Error)] 53 pub enum CurrencyErrorKind { 54 #[error("contains illegal characters (only A-Z allowed)")] 55 Invalid, 56 #[error("too long (max {CURRENCY_LEN} chars)")] 57 Big, 58 #[error("is empty")] 59 Empty, 60 } 61 62 #[derive(Debug, thiserror::Error)] 63 #[error("currency code name '{currency}' {kind}")] 64 pub struct ParseCurrencyError { 65 currency: String, 66 pub kind: CurrencyErrorKind, 67 } 68 69 impl FromStr for Currency { 70 type Err = ParseCurrencyError; 71 72 fn from_str(s: &str) -> Result<Self, Self::Err> { 73 let bytes = s.as_bytes(); 74 let len = bytes.len(); 75 if bytes.is_empty() { 76 Err(CurrencyErrorKind::Empty) 77 } else if len > CURRENCY_LEN { 78 Err(CurrencyErrorKind::Big) 79 } else if !bytes.iter().all(|c| c.is_ascii_uppercase()) { 80 Err(CurrencyErrorKind::Invalid) 81 } else { 82 Ok(Self(InlineStr::copy_from_slice(bytes))) 83 } 84 .map_err(|kind| ParseCurrencyError { 85 currency: s.to_owned(), 86 kind, 87 }) 88 } 89 } 90 91 impl Debug for Currency { 92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 93 Debug::fmt(&self.as_ref(), f) 94 } 95 } 96 97 impl Display for Currency { 98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 99 Display::fmt(&self.as_ref(), f) 100 } 101 } 102 103 #[derive(sqlx::Type)] 104 #[sqlx(type_name = "taler_amount")] 105 struct PgTalerAmount { 106 pub val: i64, 107 pub frac: i32, 108 } 109 110 #[derive( 111 Clone, 112 Copy, 113 PartialEq, 114 Eq, 115 PartialOrd, 116 Ord, 117 serde_with::DeserializeFromStr, 118 serde_with::SerializeDisplay, 119 )] 120 pub struct Decimal { 121 /** Integer part */ 122 pub val: u64, 123 /** Factional part, multiple of FRAC_BASE */ 124 pub frac: u32, 125 } 126 127 impl Decimal { 128 pub fn new(val: u64, frac: u32) -> Self { 129 Self { val, frac } 130 } 131 132 pub const fn max() -> Self { 133 Self { 134 val: MAX_VALUE, 135 frac: FRAC_BASE - 1, 136 } 137 } 138 139 pub const fn zero() -> Self { 140 Self { val: 0, frac: 0 } 141 } 142 143 fn normalize(mut self) -> Option<Self> { 144 self.val = self.val.checked_add((self.frac / FRAC_BASE) as u64)?; 145 self.frac %= FRAC_BASE; 146 if self.val > MAX_VALUE { 147 return None; 148 } 149 Some(self) 150 } 151 152 pub fn try_add(mut self, rhs: &Self) -> Option<Self> { 153 self.val = self.val.checked_add(rhs.val)?; 154 self.frac = self 155 .frac 156 .checked_add(rhs.frac) 157 .expect("amount fraction overflow should never happen with normalized amounts"); 158 self.normalize() 159 } 160 161 pub fn try_sub(mut self, rhs: &Self) -> Option<Self> { 162 if rhs.frac > self.frac { 163 self.val = self.val.checked_sub(1)?; 164 self.frac += FRAC_BASE; 165 } 166 self.val = self.val.checked_sub(rhs.val)?; 167 self.frac = self.frac.checked_sub(rhs.frac)?; 168 self.normalize() 169 } 170 171 pub fn to_amount(self, currency: &Currency) -> Amount { 172 Amount::new_decimal(currency, self) 173 } 174 } 175 176 #[derive(Debug, thiserror::Error)] 177 pub enum DecimalErrKind { 178 #[error("value overflow (must be <= {MAX_VALUE})")] 179 Overflow, 180 #[error("invalid value ({0})")] 181 InvalidValue(ParseIntError), 182 #[error("invalid fraction ({0})")] 183 InvalidFraction(ParseIntError), 184 #[error("fraction overflow (max {FRAC_BASE_NB_DIGITS} digits)")] 185 FractionOverflow, 186 } 187 188 #[derive(Debug, thiserror::Error)] 189 #[error("decimal '{decimal}' {kind}")] 190 pub struct ParseDecimalErr { 191 decimal: String, 192 pub kind: DecimalErrKind, 193 } 194 195 impl FromStr for Decimal { 196 type Err = ParseDecimalErr; 197 198 fn from_str(s: &str) -> Result<Self, Self::Err> { 199 let (value, fraction) = s.split_once('.').unwrap_or((s, "")); 200 201 // TODO use try block when stable 202 (|| { 203 let value: u64 = value.parse().map_err(DecimalErrKind::InvalidValue)?; 204 if value > MAX_VALUE { 205 return Err(DecimalErrKind::Overflow); 206 } 207 208 if fraction.len() > FRAC_BASE_NB_DIGITS as usize { 209 return Err(DecimalErrKind::FractionOverflow); 210 } 211 let fraction: u32 = if fraction.is_empty() { 212 0 213 } else { 214 fraction 215 .parse::<u32>() 216 .map_err(DecimalErrKind::InvalidFraction)? 217 * 10u32.pow(FRAC_BASE_NB_DIGITS as u32 - fraction.len() as u32) 218 }; 219 Ok(Self { 220 val: value, 221 frac: fraction, 222 }) 223 })() 224 .map_err(|kind| ParseDecimalErr { 225 decimal: s.to_owned(), 226 kind, 227 }) 228 } 229 } 230 231 impl Display for Decimal { 232 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 233 if self.frac == 0 { 234 f.write_fmt(format_args!("{}", self.val)) 235 } else { 236 let num = format!("{:08}", self.frac); 237 f.write_fmt(format_args!("{}.{}", self.val, num.trim_end_matches('0'))) 238 } 239 } 240 } 241 242 impl Debug for Decimal { 243 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 244 Display::fmt(&self, f) 245 } 246 } 247 248 impl sqlx::Type<sqlx::Postgres> for Decimal { 249 fn type_info() -> sqlx::postgres::PgTypeInfo { 250 PgTalerAmount::type_info() 251 } 252 } 253 254 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Decimal { 255 fn encode_by_ref( 256 &self, 257 buf: &mut sqlx::postgres::PgArgumentBuffer, 258 ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> { 259 PgTalerAmount { 260 val: self.val as i64, 261 frac: self.frac as i32, 262 } 263 .encode_by_ref(buf) 264 } 265 } 266 267 impl<'r> sqlx::Decode<'r, sqlx::Postgres> for Decimal { 268 fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> { 269 let pg = PgTalerAmount::decode(value)?; 270 Ok(Self { 271 val: pg.val as u64, 272 frac: pg.frac as u32, 273 }) 274 } 275 } 276 277 #[track_caller] 278 pub fn decimal(decimal: impl AsRef<str>) -> Decimal { 279 decimal.as_ref().parse().expect("Invalid decimal constant") 280 } 281 282 /// <https://docs.taler.net/core/api-common.html#tsref-type-Amount> 283 #[derive( 284 Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay, 285 )] 286 pub struct Amount { 287 pub currency: Currency, 288 pub val: u64, 289 pub frac: u32, 290 } 291 292 impl Amount { 293 pub fn new_decimal(currency: &Currency, decimal: Decimal) -> Self { 294 (*currency, decimal).into() 295 } 296 297 pub fn new(currency: &Currency, val: u64, frac: u32) -> Self { 298 Self::new_decimal(currency, Decimal { val, frac }) 299 } 300 301 pub fn max(currency: &Currency) -> Self { 302 Self::new_decimal(currency, Decimal::max()) 303 } 304 305 pub fn zero(currency: &Currency) -> Self { 306 Self::new_decimal(currency, Decimal::zero()) 307 } 308 309 pub fn is_zero(&self) -> bool { 310 self.decimal() == Decimal::zero() 311 } 312 313 pub const fn decimal(&self) -> Decimal { 314 Decimal { 315 val: self.val, 316 frac: self.frac, 317 } 318 } 319 320 pub fn normalize(self) -> Option<Self> { 321 let decimal = self.decimal().normalize()?; 322 Some((self.currency, decimal).into()) 323 } 324 325 pub fn try_add(self, rhs: &Self) -> Option<Self> { 326 assert_eq!(self.currency, rhs.currency); 327 let decimal = self.decimal().try_add(&rhs.decimal())?.normalize()?; 328 Some((self.currency, decimal).into()) 329 } 330 331 pub fn try_sub(self, rhs: &Self) -> Option<Self> { 332 assert_eq!(self.currency, rhs.currency); 333 let decimal = self.decimal().try_sub(&rhs.decimal())?.normalize()?; 334 Some((self.currency, decimal).into()) 335 } 336 } 337 338 impl From<(Currency, Decimal)> for Amount { 339 fn from((currency, decimal): (Currency, Decimal)) -> Self { 340 Self { 341 currency, 342 val: decimal.val, 343 frac: decimal.frac, 344 } 345 } 346 } 347 348 #[track_caller] 349 pub fn amount(amount: impl AsRef<str>) -> Amount { 350 amount.as_ref().parse().expect("Invalid amount constant") 351 } 352 353 #[derive(Debug, thiserror::Error)] 354 pub enum AmountErrKind { 355 #[error("invalid format")] 356 Format, 357 #[error("currency {0}")] 358 Currency(#[from] CurrencyErrorKind), 359 #[error(transparent)] 360 Decimal(#[from] DecimalErrKind), 361 } 362 363 #[derive(Debug, thiserror::Error)] 364 #[error("amount '{amount}' {kind}")] 365 pub struct ParseAmountErr { 366 amount: String, 367 pub kind: AmountErrKind, 368 } 369 370 impl FromStr for Amount { 371 type Err = ParseAmountErr; 372 373 fn from_str(s: &str) -> Result<Self, Self::Err> { 374 // TODO use try block when stable 375 (|| { 376 let (currency, amount) = s.trim().split_once(':').ok_or(AmountErrKind::Format)?; 377 let currency = currency.parse().map_err(|e: ParseCurrencyError| e.kind)?; 378 let decimal = amount.parse().map_err(|e: ParseDecimalErr| e.kind)?; 379 Ok((currency, decimal).into()) 380 })() 381 .map_err(|kind| ParseAmountErr { 382 amount: s.to_owned(), 383 kind, 384 }) 385 } 386 } 387 388 impl Display for Amount { 389 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 390 f.write_fmt(format_args!("{}:{}", self.currency, self.decimal())) 391 } 392 } 393 394 impl Debug for Amount { 395 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 396 Display::fmt(&self, f) 397 } 398 } 399 400 impl sqlx::Type<sqlx::Postgres> for Amount { 401 fn type_info() -> sqlx::postgres::PgTypeInfo { 402 PgTalerAmount::type_info() 403 } 404 } 405 406 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Amount { 407 fn encode_by_ref( 408 &self, 409 buf: &mut sqlx::postgres::PgArgumentBuffer, 410 ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> { 411 self.decimal().encode_by_ref(buf) 412 } 413 } 414 415 #[test] 416 fn test_amount_parse() { 417 const TALER_AMOUNT_FRAC_BASE: u32 = 100000000; 418 // https://git.taler.net/exchange.git/tree/src/util/test_amount.c 419 420 const INVALID_AMOUNTS: [&str; 6] = [ 421 "EUR:4a", // non-numeric, 422 "EUR:4.4a", // non-numeric 423 "EUR:4.a4", // non-numeric 424 ":4.a4", // no currency 425 "EUR:4.123456789", // precision to high 426 "EUR:1234567890123456789012345678901234567890123456789012345678901234567890", // value to big 427 ]; 428 429 for str in INVALID_AMOUNTS { 430 let amount = Amount::from_str(str); 431 assert!(amount.is_err(), "invalid {} got {:?}", str, &amount); 432 } 433 434 let eur: Currency = "EUR".parse().unwrap(); 435 let local: Currency = "LOCAL".parse().unwrap(); 436 let valid_amounts: Vec<(&str, &str, Amount)> = vec![ 437 ("EUR:4", "EUR:4", Amount::new(&eur, 4, 0)), // without fraction 438 ( 439 "EUR:0.02", 440 "EUR:0.02", 441 Amount::new(&eur, 0, TALER_AMOUNT_FRAC_BASE / 100 * 2), 442 ), // leading zero fraction 443 ( 444 " EUR:4.12", 445 "EUR:4.12", 446 Amount::new(&eur, 4, TALER_AMOUNT_FRAC_BASE / 100 * 12), 447 ), // leading space and fraction 448 ( 449 " LOCAL:4444.1000", 450 "LOCAL:4444.1", 451 Amount::new(&local, 4444, TALER_AMOUNT_FRAC_BASE / 10), 452 ), // local currency 453 ]; 454 for (raw, expected, goal) in valid_amounts { 455 let amount = Amount::from_str(raw); 456 assert!(amount.is_ok(), "Valid {} got {:?}", raw, amount); 457 assert_eq!( 458 *amount.as_ref().unwrap(), 459 goal, 460 "Expected {:?} got {:?} for {}", 461 goal, 462 amount, 463 raw 464 ); 465 let amount = amount.unwrap(); 466 let str = amount.to_string(); 467 assert_eq!(str, expected); 468 assert_eq!(amount, Amount::from_str(&str).unwrap(), "{str}"); 469 } 470 } 471 472 #[test] 473 fn test_amount_add() { 474 let eur: Currency = "EUR".parse().unwrap(); 475 assert_eq!( 476 Amount::max(&eur).try_add(&Amount::zero(&eur)), 477 Some(Amount::max(&eur)) 478 ); 479 assert_eq!( 480 Amount::zero(&eur).try_add(&Amount::zero(&eur)), 481 Some(Amount::zero(&eur)) 482 ); 483 assert_eq!( 484 amount("EUR:6.41").try_add(&amount("EUR:4.69")), 485 Some(amount("EUR:11.1")) 486 ); 487 assert_eq!( 488 amount(format!("EUR:{MAX_VALUE}")).try_add(&amount("EUR:0.99999999")), 489 Some(Amount::max(&eur)) 490 ); 491 492 assert_eq!( 493 amount(format!("EUR:{}", MAX_VALUE - 5)).try_add(&amount("EUR:6")), 494 None 495 ); 496 assert_eq!( 497 Amount::new(&eur, u64::MAX, 0).try_add(&amount("EUR:1")), 498 None 499 ); 500 assert_eq!( 501 amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRAC_BASE - 1)) 502 .try_add(&amount("EUR:5.00000002")), 503 None 504 ); 505 } 506 507 #[test] 508 fn test_amount_normalize() { 509 let eur: Currency = "EUR".parse().unwrap(); 510 assert_eq!( 511 Amount::new(&eur, 4, 2 * FRAC_BASE).normalize(), 512 Some(amount("EUR:6")) 513 ); 514 assert_eq!( 515 Amount::new(&eur, 4, 2 * FRAC_BASE + 1).normalize(), 516 Some(amount("EUR:6.00000001")) 517 ); 518 assert_eq!( 519 Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1).normalize(), 520 Some(Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1)) 521 ); 522 assert_eq!(Amount::new(&eur, u64::MAX, FRAC_BASE).normalize(), None); 523 assert_eq!(Amount::new(&eur, MAX_VALUE, FRAC_BASE).normalize(), None); 524 525 for amount in [Amount::max(&eur), Amount::zero(&eur)] { 526 assert_eq!(amount.normalize(), Some(amount)) 527 } 528 }