taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

notification.rs (5402B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2024, 2025, 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::{hash::Hash, sync::Arc};
     18 
     19 use dashmap::DashMap;
     20 use tokio::sync::watch::{self, Receiver};
     21 
     22 pub mod de;
     23 
     24 /// Listen for many postgres notification channels using a single connection
     25 #[macro_export]
     26 macro_rules! notification_listener {
     27     ($pool: expr, $($channel:expr => ($($arg:ident: $type:ty),*) $lambda:block),*$(,)?) => {{
     28         let mut jitter = ::taler_common::ExpoBackoffDecorr::default();
     29         loop {
     30             if let ::sqlx::Result::<(), ::sqlx::Error>::Err(e) = async {
     31                 let mut listener = ::sqlx::postgres::PgListener::connect_with($pool).await?;
     32                 listener.listen_all([$($channel,)*]).await?;
     33                 jitter.reset();
     34                 loop {
     35                     let notification = listener.recv().await?;
     36                     tracing::debug!(target: "db-watcher",
     37                         "notif: {} - {}",
     38                         notification.channel(),
     39                         notification.payload()
     40                     );
     41                     match notification.channel() {
     42                         $($channel => {
     43                             match $crate::notification::de::from_str::<($($type,)*)>(notification.payload()) {
     44                                 Ok(($($arg,)*)) => {
     45                                     $lambda
     46                                 }
     47                                 Err(e) => {
     48                                     tracing::error!(target: "db-watcher",
     49                                         "notif: {} {e} - {}",
     50                                         notification.channel(),
     51                                         notification.payload()
     52                                     );
     53                                 }
     54                             }
     55                         }),*
     56                         unknown => unreachable!("{unknown}"),
     57                     }
     58                 }
     59             }.await {
     60                 tokio::time::sleep(jitter.backoff()).await;
     61                 tracing::error!(target: "db-watcher", "{e}");
     62             };
     63         }
     64     }}
     65 }
     66 
     67 pub use notification_listener;
     68 
     69 #[derive(Default, Clone)]
     70 pub struct NotificationChannel<K: Eq + Hash, V> {
     71     map: Arc<DashMap<K, watch::Sender<V>>>,
     72 }
     73 
     74 impl<K: Eq + Hash, V> NotificationChannel<K, V> {
     75     pub fn new() -> Self {
     76         Self {
     77             map: Arc::new(DashMap::new()),
     78         }
     79     }
     80 }
     81 
     82 impl<K: Eq + Hash + Clone, V: Default> NotificationChannel<K, V> {
     83     /// Subscribe to events for a specific username.
     84     /// Creates the channel lazily on first subscriber.
     85     pub fn subscribe(&self, key: K) -> watch::Receiver<V> {
     86         self.map
     87             .entry(key)
     88             .or_insert_with(|| {
     89                 let (sender, _) = watch::channel(V::default());
     90                 sender
     91             })
     92             .subscribe()
     93     }
     94 
     95     /// Dispatch a notification to the right user's channel.
     96     pub fn dispatch(&self, key: &K, event: V) {
     97         if let Some(tx) = self.map.get(key) {
     98             tx.send(event).ok();
     99             /*// send_if_modified avoids waking receivers if nothing changed
    100             let _ = tx.send_if_modified(|current| {
    101                 *current = Some(event.clone());
    102                 true
    103             });*/
    104         }
    105     }
    106 
    107     /// Prune channels where all receivers have been dropped.
    108     pub fn prune(&self) {
    109         self.map.retain(|_, tx| tx.receiver_count() > 0);
    110     }
    111 }
    112 pub fn dummy_listen<T: Default>() -> Receiver<T> {
    113     tokio::sync::watch::channel(T::default()).1
    114 }
    115 
    116 #[tokio::test]
    117 async fn channel_gc() {
    118     use std::time::Duration;
    119 
    120     let channel = NotificationChannel::default();
    121     assert_eq!(0, channel.map.len());
    122 
    123     // Clean in future
    124     let mut listener = channel.subscribe("test");
    125     assert_eq!(1, channel.map.len());
    126     tokio::time::timeout(Duration::from_millis(1), async move {
    127         listener.wait_for(|it| it == 42).await.unwrap();
    128     })
    129     .await
    130     .unwrap_err();
    131     channel.prune();
    132     assert_eq!(0, channel.map.len());
    133 
    134     // Clean on drop
    135     let mut first = channel.subscribe("test");
    136     let second = channel.subscribe("test");
    137     assert_eq!(1, channel.map.len());
    138     tokio::time::timeout(Duration::from_millis(1), async move {
    139         first.wait_for(|it| it == 42).await.unwrap();
    140     })
    141     .await
    142     .unwrap_err();
    143     assert_eq!(1, channel.map.len());
    144     drop(second);
    145     channel.prune();
    146     assert_eq!(0, channel.map.len());
    147 }
    148 
    149 #[tokio::test]
    150 async fn wake() {
    151     let channel = NotificationChannel::default();
    152     let mut listener = channel.subscribe("test");
    153     let task = tokio::spawn(async move {
    154         listener.wait_for(|it| *it == 42).await.unwrap();
    155     });
    156     channel.dispatch(&"test", 42);
    157     task.await.unwrap();
    158 }