notification.rs (5402B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2024, 2025, 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::{hash::Hash, sync::Arc}; 18 19 use dashmap::DashMap; 20 use tokio::sync::watch::{self, Receiver}; 21 22 pub mod de; 23 24 /// Listen for many postgres notification channels using a single connection 25 #[macro_export] 26 macro_rules! notification_listener { 27 ($pool: expr, $($channel:expr => ($($arg:ident: $type:ty),*) $lambda:block),*$(,)?) => {{ 28 let mut jitter = ::taler_common::ExpoBackoffDecorr::default(); 29 loop { 30 if let ::sqlx::Result::<(), ::sqlx::Error>::Err(e) = async { 31 let mut listener = ::sqlx::postgres::PgListener::connect_with($pool).await?; 32 listener.listen_all([$($channel,)*]).await?; 33 jitter.reset(); 34 loop { 35 let notification = listener.recv().await?; 36 tracing::debug!(target: "db-watcher", 37 "notif: {} - {}", 38 notification.channel(), 39 notification.payload() 40 ); 41 match notification.channel() { 42 $($channel => { 43 match $crate::notification::de::from_str::<($($type,)*)>(notification.payload()) { 44 Ok(($($arg,)*)) => { 45 $lambda 46 } 47 Err(e) => { 48 tracing::error!(target: "db-watcher", 49 "notif: {} {e} - {}", 50 notification.channel(), 51 notification.payload() 52 ); 53 } 54 } 55 }),* 56 unknown => unreachable!("{unknown}"), 57 } 58 } 59 }.await { 60 tokio::time::sleep(jitter.backoff()).await; 61 tracing::error!(target: "db-watcher", "{e}"); 62 }; 63 } 64 }} 65 } 66 67 pub use notification_listener; 68 69 #[derive(Default, Clone)] 70 pub struct NotificationChannel<K: Eq + Hash, V> { 71 map: Arc<DashMap<K, watch::Sender<V>>>, 72 } 73 74 impl<K: Eq + Hash, V> NotificationChannel<K, V> { 75 pub fn new() -> Self { 76 Self { 77 map: Arc::new(DashMap::new()), 78 } 79 } 80 } 81 82 impl<K: Eq + Hash + Clone, V: Default> NotificationChannel<K, V> { 83 /// Subscribe to events for a specific username. 84 /// Creates the channel lazily on first subscriber. 85 pub fn subscribe(&self, key: K) -> watch::Receiver<V> { 86 self.map 87 .entry(key) 88 .or_insert_with(|| { 89 let (sender, _) = watch::channel(V::default()); 90 sender 91 }) 92 .subscribe() 93 } 94 95 /// Dispatch a notification to the right user's channel. 96 pub fn dispatch(&self, key: &K, event: V) { 97 if let Some(tx) = self.map.get(key) { 98 tx.send(event).ok(); 99 /*// send_if_modified avoids waking receivers if nothing changed 100 let _ = tx.send_if_modified(|current| { 101 *current = Some(event.clone()); 102 true 103 });*/ 104 } 105 } 106 107 /// Prune channels where all receivers have been dropped. 108 pub fn prune(&self) { 109 self.map.retain(|_, tx| tx.receiver_count() > 0); 110 } 111 } 112 pub fn dummy_listen<T: Default>() -> Receiver<T> { 113 tokio::sync::watch::channel(T::default()).1 114 } 115 116 #[tokio::test] 117 async fn channel_gc() { 118 use std::time::Duration; 119 120 let channel = NotificationChannel::default(); 121 assert_eq!(0, channel.map.len()); 122 123 // Clean in future 124 let mut listener = channel.subscribe("test"); 125 assert_eq!(1, channel.map.len()); 126 tokio::time::timeout(Duration::from_millis(1), async move { 127 listener.wait_for(|it| it == 42).await.unwrap(); 128 }) 129 .await 130 .unwrap_err(); 131 channel.prune(); 132 assert_eq!(0, channel.map.len()); 133 134 // Clean on drop 135 let mut first = channel.subscribe("test"); 136 let second = channel.subscribe("test"); 137 assert_eq!(1, channel.map.len()); 138 tokio::time::timeout(Duration::from_millis(1), async move { 139 first.wait_for(|it| it == 42).await.unwrap(); 140 }) 141 .await 142 .unwrap_err(); 143 assert_eq!(1, channel.map.len()); 144 drop(second); 145 channel.prune(); 146 assert_eq!(0, channel.map.len()); 147 } 148 149 #[tokio::test] 150 async fn wake() { 151 let channel = NotificationChannel::default(); 152 let mut listener = channel.subscribe("test"); 153 let task = tokio::spawn(async move { 154 listener.wait_for(|it| *it == 42).await.unwrap(); 155 }); 156 channel.dispatch(&"test", 42); 157 task.await.unwrap(); 158 }