amount.rs (14971B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2024, 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 //! Type for the Taler Amount <https://docs.taler.net/core/api-common.html#tsref-type-Amount> 18 19 use std::{ 20 fmt::{Debug, Display}, 21 num::ParseIntError, 22 str::FromStr, 23 }; 24 25 use super::utils::InlineStr; 26 27 /** Number of characters we use to represent currency names */ 28 // We use the same value than the exchange -1 because we use a byte for the len instead of 0 termination 29 pub const CURRENCY_LEN: usize = 11; 30 31 /** Maximum legal value for an amount, based on IEEE double */ 32 pub const MAX_VALUE: u64 = 2 << 52; 33 34 /** The number of digits in a fraction part of an amount */ 35 pub const FRAC_BASE_NB_DIGITS: u8 = 8; 36 37 /** The fraction part of an amount represents which fraction of the value */ 38 pub const FRAC_BASE: u32 = 10u32.pow(FRAC_BASE_NB_DIGITS as u32); 39 40 #[derive(Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay)] 41 /// Inlined ISO 4217 currency string 42 pub struct Currency(InlineStr<CURRENCY_LEN>); 43 44 impl AsRef<str> for Currency { 45 fn as_ref(&self) -> &str { 46 self.0.as_ref() 47 } 48 } 49 50 #[derive(Debug, thiserror::Error)] 51 pub enum CurrencyErrorKind { 52 #[error("contains illegal characters (only A-Z allowed)")] 53 Invalid, 54 #[error("too long (max {CURRENCY_LEN} chars)")] 55 Big, 56 #[error("is empty")] 57 Empty, 58 } 59 60 #[derive(Debug, thiserror::Error)] 61 #[error("currency code name '{currency}' {kind}")] 62 pub struct ParseCurrencyError { 63 currency: String, 64 pub kind: CurrencyErrorKind, 65 } 66 67 impl FromStr for Currency { 68 type Err = ParseCurrencyError; 69 70 fn from_str(s: &str) -> Result<Self, Self::Err> { 71 let bytes = s.as_bytes(); 72 let len = bytes.len(); 73 if bytes.is_empty() { 74 Err(CurrencyErrorKind::Empty) 75 } else if len > CURRENCY_LEN { 76 Err(CurrencyErrorKind::Big) 77 } else if !bytes.iter().all(|c| c.is_ascii_uppercase()) { 78 Err(CurrencyErrorKind::Invalid) 79 } else { 80 Ok(Self(InlineStr::copy_from_slice(bytes))) 81 } 82 .map_err(|kind| ParseCurrencyError { 83 currency: s.to_owned(), 84 kind, 85 }) 86 } 87 } 88 89 impl Debug for Currency { 90 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 91 Debug::fmt(&self.as_ref(), f) 92 } 93 } 94 95 impl Display for Currency { 96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 97 Display::fmt(&self.as_ref(), f) 98 } 99 } 100 101 #[derive(sqlx::Type)] 102 #[sqlx(type_name = "taler_amount")] 103 struct PgTalerAmount { 104 pub val: i64, 105 pub frac: i32, 106 } 107 108 #[derive( 109 Debug, Clone, Copy, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay, 110 )] 111 pub struct Decimal { 112 /** Integer part */ 113 pub val: u64, 114 /** Factional part, multiple of FRAC_BASE */ 115 pub frac: u32, 116 } 117 118 impl Decimal { 119 pub fn new(val: u64, frac: u32) -> Self { 120 Self { val, frac } 121 } 122 123 pub const fn max() -> Self { 124 Self { 125 val: MAX_VALUE, 126 frac: FRAC_BASE - 1, 127 } 128 } 129 130 pub const fn zero() -> Self { 131 Self { val: 0, frac: 0 } 132 } 133 134 fn normalize(mut self) -> Option<Self> { 135 self.val = self.val.checked_add((self.frac / FRAC_BASE) as u64)?; 136 self.frac %= FRAC_BASE; 137 if self.val > MAX_VALUE { 138 return None; 139 } 140 Some(self) 141 } 142 143 pub fn try_add(mut self, rhs: &Self) -> Option<Self> { 144 self.val = self.val.checked_add(rhs.val)?; 145 self.frac = self 146 .frac 147 .checked_add(rhs.frac) 148 .expect("amount fraction overflow should never happen with normalized amounts"); 149 self.normalize() 150 } 151 152 pub fn try_sub(mut self, rhs: &Self) -> Option<Self> { 153 if rhs.frac > self.frac { 154 self.val = self.val.checked_sub(1)?; 155 self.frac += FRAC_BASE; 156 } 157 self.val = self.val.checked_sub(rhs.val)?; 158 self.frac = self.frac.checked_sub(rhs.frac)?; 159 self.normalize() 160 } 161 162 pub fn to_amount(self, currency: &Currency) -> Amount { 163 Amount::new_decimal(currency, self) 164 } 165 } 166 167 #[derive(Debug, thiserror::Error)] 168 pub enum DecimalErrKind { 169 #[error("value overflow (must be <= {MAX_VALUE})")] 170 Overflow, 171 #[error("invalid value ({0})")] 172 InvalidValue(ParseIntError), 173 #[error("invalid fraction ({0})")] 174 InvalidFraction(ParseIntError), 175 #[error("fraction overflow (max {FRAC_BASE_NB_DIGITS} digits)")] 176 FractionOverflow, 177 } 178 179 #[derive(Debug, thiserror::Error)] 180 #[error("decimal '{decimal}' {kind}")] 181 pub struct ParseDecimalErr { 182 decimal: String, 183 pub kind: DecimalErrKind, 184 } 185 186 impl FromStr for Decimal { 187 type Err = ParseDecimalErr; 188 189 fn from_str(s: &str) -> Result<Self, Self::Err> { 190 let (value, fraction) = s.split_once('.').unwrap_or((s, "")); 191 192 // TODO use try block when stable 193 (|| { 194 let value: u64 = value.parse().map_err(DecimalErrKind::InvalidValue)?; 195 if value > MAX_VALUE { 196 return Err(DecimalErrKind::Overflow); 197 } 198 199 if fraction.len() > FRAC_BASE_NB_DIGITS as usize { 200 return Err(DecimalErrKind::FractionOverflow); 201 } 202 let fraction: u32 = if fraction.is_empty() { 203 0 204 } else { 205 fraction 206 .parse::<u32>() 207 .map_err(DecimalErrKind::InvalidFraction)? 208 * 10u32.pow(FRAC_BASE_NB_DIGITS as u32 - fraction.len() as u32) 209 }; 210 Ok(Self { 211 val: value, 212 frac: fraction, 213 }) 214 })() 215 .map_err(|kind| ParseDecimalErr { 216 decimal: s.to_owned(), 217 kind, 218 }) 219 } 220 } 221 222 impl Display for Decimal { 223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 224 if self.frac == 0 { 225 f.write_fmt(format_args!("{}", self.val)) 226 } else { 227 let num = format!("{:08}", self.frac); 228 f.write_fmt(format_args!("{}.{}", self.val, num.trim_end_matches('0'))) 229 } 230 } 231 } 232 233 impl sqlx::Type<sqlx::Postgres> for Decimal { 234 fn type_info() -> sqlx::postgres::PgTypeInfo { 235 PgTalerAmount::type_info() 236 } 237 } 238 239 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Decimal { 240 fn encode_by_ref( 241 &self, 242 buf: &mut sqlx::postgres::PgArgumentBuffer, 243 ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> { 244 PgTalerAmount { 245 val: self.val as i64, 246 frac: self.frac as i32, 247 } 248 .encode_by_ref(buf) 249 } 250 } 251 252 impl<'r> sqlx::Decode<'r, sqlx::Postgres> for Decimal { 253 fn decode(value: sqlx::postgres::PgValueRef<'r>) -> Result<Self, sqlx::error::BoxDynError> { 254 let pg = PgTalerAmount::decode(value)?; 255 Ok(Self { 256 val: pg.val as u64, 257 frac: pg.frac as u32, 258 }) 259 } 260 } 261 262 #[track_caller] 263 pub fn decimal(decimal: impl AsRef<str>) -> Decimal { 264 decimal.as_ref().parse().expect("Invalid decimal constant") 265 } 266 267 /// <https://docs.taler.net/core/api-common.html#tsref-type-Amount> 268 #[derive( 269 Debug, Clone, PartialEq, Eq, serde_with::DeserializeFromStr, serde_with::SerializeDisplay, 270 )] 271 pub struct Amount { 272 pub currency: Currency, 273 pub val: u64, 274 pub frac: u32, 275 } 276 277 impl Amount { 278 pub fn new_decimal(currency: &Currency, decimal: Decimal) -> Self { 279 (currency.clone(), decimal).into() 280 } 281 282 pub fn new(currency: &Currency, val: u64, frac: u32) -> Self { 283 Self::new_decimal(currency, Decimal { val, frac }) 284 } 285 286 pub fn max(currency: &Currency) -> Self { 287 Self::new_decimal(currency, Decimal::max()) 288 } 289 290 pub fn zero(currency: &Currency) -> Self { 291 Self::new_decimal(currency, Decimal::zero()) 292 } 293 294 pub const fn decimal(&self) -> Decimal { 295 Decimal { 296 val: self.val, 297 frac: self.frac, 298 } 299 } 300 301 pub fn normalize(self) -> Option<Self> { 302 let decimal = self.decimal().normalize()?; 303 Some((self.currency, decimal).into()) 304 } 305 306 pub fn try_add(self, rhs: &Self) -> Option<Self> { 307 assert_eq!(self.currency, rhs.currency); 308 let decimal = self.decimal().try_add(&rhs.decimal())?.normalize()?; 309 Some((self.currency, decimal).into()) 310 } 311 312 pub fn try_sub(self, rhs: &Self) -> Option<Self> { 313 assert_eq!(self.currency, rhs.currency); 314 let decimal = self.decimal().try_sub(&rhs.decimal())?.normalize()?; 315 Some((self.currency, decimal).into()) 316 } 317 } 318 319 impl From<(Currency, Decimal)> for Amount { 320 fn from((currency, decimal): (Currency, Decimal)) -> Self { 321 Self { 322 currency, 323 val: decimal.val, 324 frac: decimal.frac, 325 } 326 } 327 } 328 329 #[track_caller] 330 pub fn amount(amount: impl AsRef<str>) -> Amount { 331 amount.as_ref().parse().expect("Invalid amount constant") 332 } 333 334 #[derive(Debug, thiserror::Error)] 335 pub enum AmountErrKind { 336 #[error("invalid format")] 337 Format, 338 #[error("currency {0}")] 339 Currency(#[from] CurrencyErrorKind), 340 #[error(transparent)] 341 Decimal(#[from] DecimalErrKind), 342 } 343 344 #[derive(Debug, thiserror::Error)] 345 #[error("amount '{amount}' {kind}")] 346 pub struct ParseAmountErr { 347 amount: String, 348 pub kind: AmountErrKind, 349 } 350 351 impl FromStr for Amount { 352 type Err = ParseAmountErr; 353 354 fn from_str(s: &str) -> Result<Self, Self::Err> { 355 // TODO use try block when stable 356 (|| { 357 let (currency, amount) = s.trim().split_once(':').ok_or(AmountErrKind::Format)?; 358 let currency = currency.parse().map_err(|e: ParseCurrencyError| e.kind)?; 359 let decimal = amount.parse().map_err(|e: ParseDecimalErr| e.kind)?; 360 Ok((currency, decimal).into()) 361 })() 362 .map_err(|kind| ParseAmountErr { 363 amount: s.to_owned(), 364 kind, 365 }) 366 } 367 } 368 369 impl Display for Amount { 370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 371 f.write_fmt(format_args!("{}:{}", self.currency, self.decimal())) 372 } 373 } 374 375 impl sqlx::Type<sqlx::Postgres> for Amount { 376 fn type_info() -> sqlx::postgres::PgTypeInfo { 377 PgTalerAmount::type_info() 378 } 379 } 380 381 impl<'q> sqlx::Encode<'q, sqlx::Postgres> for Amount { 382 fn encode_by_ref( 383 &self, 384 buf: &mut sqlx::postgres::PgArgumentBuffer, 385 ) -> Result<sqlx::encode::IsNull, sqlx::error::BoxDynError> { 386 self.decimal().encode_by_ref(buf) 387 } 388 } 389 390 #[test] 391 fn test_amount_parse() { 392 const TALER_AMOUNT_FRAC_BASE: u32 = 100000000; 393 // https://git.taler.net/exchange.git/tree/src/util/test_amount.c 394 395 const INVALID_AMOUNTS: [&str; 6] = [ 396 "EUR:4a", // non-numeric, 397 "EUR:4.4a", // non-numeric 398 "EUR:4.a4", // non-numeric 399 ":4.a4", // no currency 400 "EUR:4.123456789", // precision to high 401 "EUR:1234567890123456789012345678901234567890123456789012345678901234567890", // value to big 402 ]; 403 404 for str in INVALID_AMOUNTS { 405 let amount = Amount::from_str(str); 406 assert!(amount.is_err(), "invalid {} got {:?}", str, &amount); 407 } 408 409 let eur: Currency = "EUR".parse().unwrap(); 410 let local: Currency = "LOCAL".parse().unwrap(); 411 let valid_amounts: Vec<(&str, &str, Amount)> = vec![ 412 ("EUR:4", "EUR:4", Amount::new(&eur, 4, 0)), // without fraction 413 ( 414 "EUR:0.02", 415 "EUR:0.02", 416 Amount::new(&eur, 0, TALER_AMOUNT_FRAC_BASE / 100 * 2), 417 ), // leading zero fraction 418 ( 419 " EUR:4.12", 420 "EUR:4.12", 421 Amount::new(&eur, 4, TALER_AMOUNT_FRAC_BASE / 100 * 12), 422 ), // leading space and fraction 423 ( 424 " LOCAL:4444.1000", 425 "LOCAL:4444.1", 426 Amount::new(&local, 4444, TALER_AMOUNT_FRAC_BASE / 10), 427 ), // local currency 428 ]; 429 for (raw, expected, goal) in valid_amounts { 430 let amount = Amount::from_str(raw); 431 assert!(amount.is_ok(), "Valid {} got {:?}", raw, amount); 432 assert_eq!( 433 *amount.as_ref().unwrap(), 434 goal, 435 "Expected {:?} got {:?} for {}", 436 goal, 437 amount, 438 raw 439 ); 440 let amount = amount.unwrap(); 441 let str = amount.to_string(); 442 assert_eq!(str, expected); 443 assert_eq!(amount, Amount::from_str(&str).unwrap(), "{str}"); 444 } 445 } 446 447 #[test] 448 fn test_amount_add() { 449 let eur: Currency = "EUR".parse().unwrap(); 450 assert_eq!( 451 Amount::max(&eur).try_add(&Amount::zero(&eur)), 452 Some(Amount::max(&eur)) 453 ); 454 assert_eq!( 455 Amount::zero(&eur).try_add(&Amount::zero(&eur)), 456 Some(Amount::zero(&eur)) 457 ); 458 assert_eq!( 459 amount("EUR:6.41").try_add(&amount("EUR:4.69")), 460 Some(amount("EUR:11.1")) 461 ); 462 assert_eq!( 463 amount(format!("EUR:{MAX_VALUE}")).try_add(&amount("EUR:0.99999999")), 464 Some(Amount::max(&eur)) 465 ); 466 467 assert_eq!( 468 amount(format!("EUR:{}", MAX_VALUE - 5)).try_add(&amount("EUR:6")), 469 None 470 ); 471 assert_eq!( 472 Amount::new(&eur, u64::MAX, 0).try_add(&amount("EUR:1")), 473 None 474 ); 475 assert_eq!( 476 amount(format!("EUR:{}.{}", MAX_VALUE - 5, FRAC_BASE - 1)) 477 .try_add(&amount("EUR:5.00000002")), 478 None 479 ); 480 } 481 482 #[test] 483 fn test_amount_normalize() { 484 let eur: Currency = "EUR".parse().unwrap(); 485 assert_eq!( 486 Amount::new(&eur, 4, 2 * FRAC_BASE).normalize(), 487 Some(amount("EUR:6")) 488 ); 489 assert_eq!( 490 Amount::new(&eur, 4, 2 * FRAC_BASE + 1).normalize(), 491 Some(amount("EUR:6.00000001")) 492 ); 493 assert_eq!( 494 Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1).normalize(), 495 Some(Amount::new(&eur, MAX_VALUE, FRAC_BASE - 1)) 496 ); 497 assert_eq!(Amount::new(&eur, u64::MAX, FRAC_BASE).normalize(), None); 498 assert_eq!(Amount::new(&eur, MAX_VALUE, FRAC_BASE).normalize(), None); 499 500 for amount in [Amount::max(&eur), Amount::zero(&eur)] { 501 assert_eq!(amount.clone().normalize(), Some(amount)) 502 } 503 }