diff --git a/hyperactor_telemetry/src/lib.rs b/hyperactor_telemetry/src/lib.rs
index 399ff46c8..60adaeeba 100644
--- a/hyperactor_telemetry/src/lib.rs
+++ b/hyperactor_telemetry/src/lib.rs
@@ -67,6 +67,7 @@ use std::io::Write;
 use std::str::FromStr;
 use std::sync::Arc;
 use std::sync::Mutex;
+use std::sync::OnceLock;
 use std::time::Instant;
 
 use lazy_static::lazy_static;
@@ -102,6 +103,35 @@ use crate::config::USE_UNIFIED_LAYER;
 use crate::recorder::Recorder;
 use crate::sqlite::get_reloadable_sqlite_layer;
 
+static SHUTDOWN_HOOKS: OnceLock<Mutex<Vec<Box<dyn FnOnce() + Send>>>> = OnceLock::new();
+
+fn get_shutdown_hooks() -> &'static Mutex<Vec<Box<dyn FnOnce() + Send>>> {
+    SHUTDOWN_HOOKS.get_or_init(|| Mutex::new(Vec::new()))
+}
+
+/// Register a callback to be invoked during telemetry shutdown.
+/// This is useful for components that need to flush buffers or clean up
+/// before the process exits.
+pub(crate) fn register_shutdown_hook<F>(hook: F)
+where
+    F: FnOnce() + Send + 'static,
+{
+    if let Ok(mut hooks) = get_shutdown_hooks().lock() {
+        hooks.push(Box::new(hook));
+    }
+}
+
+/// Shut down the telemetry system, invoking all registered shutdown hooks.
+/// This should be called before process exit to ensure all telemetry is flushed.
+/// Safe to call multiple times; hooks are only invoked once.
+pub fn shutdown_telemetry() {
+    if let Ok(mut hooks) = get_shutdown_hooks().lock() {
+        for hook in hooks.drain(..) {
+            hook();
+        }
+    }
+}
+
 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
 pub struct TelemetrySample {
     fields: Vec<(String, String)>,
diff --git a/hyperactor_telemetry/src/trace_dispatcher.rs b/hyperactor_telemetry/src/trace_dispatcher.rs
index 75c512e50..e9aab2286 100644
--- a/hyperactor_telemetry/src/trace_dispatcher.rs
+++ b/hyperactor_telemetry/src/trace_dispatcher.rs
@@ -11,10 +11,10 @@
 //! thread.
 
 use std::sync::Arc;
+use std::sync::atomic::AtomicBool;
 use std::sync::atomic::AtomicU64;
 use std::sync::atomic::Ordering;
 use std::sync::mpsc;
-use std::thread::JoinHandle;
 use std::time::Duration;
 use std::time::SystemTime;
 
@@ -26,7 +26,10 @@ use tracing_subscriber::layer::Context;
 use tracing_subscriber::layer::Layer;
 use tracing_subscriber::registry::LookupSpan;
 
+use crate::register_shutdown_hook;
+
 const QUEUE_CAPACITY: usize = 100_000;
+const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
 
 /// Unified representation of a trace event captured from the tracing layer.
 /// This is captured once on the application thread, then sent to the background
@@ -122,18 +125,13 @@ pub(crate) trait TraceEventSink: Send + 'static {
 /// The trace event dispatcher that captures events once and dispatches to multiple sinks
 /// on a background thread.
 pub struct TraceEventDispatcher {
-    sender: Option<mpsc::SyncSender<TraceEvent>>,
+    sender: mpsc::SyncSender<TraceEvent>,
     /// Separate channel so we are always notified of when the main queue is full and events are being dropped.
-    dropped_sender: Option<mpsc::Sender<TraceEvent>>,
-    _worker_handle: WorkerHandle,
+    dropped_sender: mpsc::Sender<TraceEvent>,
     max_level: Option,
     dropped_events: Arc<AtomicU64>,
 }
 
-struct WorkerHandle {
-    join_handle: Option<JoinHandle<()>>,
-}
-
 impl TraceEventDispatcher {
     /// Create a new trace event dispatcher with the given sinks.
     /// Uses a bounded channel (capacity QUEUE_CAPACITY) to ensure telemetry never blocks
@@ -153,105 +151,120 @@ impl TraceEventDispatcher {
         let dropped_events = Arc::new(AtomicU64::new(0));
         let dropped_events_worker = Arc::clone(&dropped_events);
 
+        let shutdown_flag = Arc::new(AtomicBool::new(false));
+        let shutdown_flag_worker = Arc::clone(&shutdown_flag);
+
+        let (done_tx, done_rx) = mpsc::sync_channel::<()>(1);
+
         let worker_handle = std::thread::Builder::new()
             .name("telemetry-worker".into())
             .spawn(move || {
-                worker_loop(receiver, dropped_receiver, sinks, dropped_events_worker);
+                worker_loop(
+                    receiver,
+                    dropped_receiver,
+                    sinks,
+                    dropped_events_worker,
+                    shutdown_flag_worker,
+                );
+                let _ = done_tx.send(());
             })
             .expect("failed to spawn telemetry worker thread");
 
+        register_shutdown_hook(move || {
+            shutdown_flag.store(true, Ordering::Release);
+
+            if done_rx.recv_timeout(SHUTDOWN_TIMEOUT).is_err() {
+                eprintln!(
+                    "[telemetry] WARNING: worker thread did not exit within {:?}, continuing shutdown",
+                    SHUTDOWN_TIMEOUT
+                );
+                return;
+            }
+
+            // This join should be instant since done_rx was received
+            if let Err(e) = worker_handle.join() {
+                eprintln!(
+                    "[telemetry] worker thread panicked during shutdown: {:?}",
+                    e
+                );
+            }
+        });
+
         Self {
-            sender: Some(sender),
-            dropped_sender: Some(dropped_sender),
-            _worker_handle: WorkerHandle {
-                join_handle: Some(worker_handle),
-            },
+            sender,
+            dropped_sender,
             max_level,
             dropped_events,
         }
     }
 
     fn send_event(&self, event: TraceEvent) {
-        if let Some(sender) = &self.sender {
-            if let Err(mpsc::TrySendError::Full(_)) = sender.try_send(event) {
-                let dropped = self.dropped_events.fetch_add(1, Ordering::Relaxed) + 1;
+        if let Err(mpsc::TrySendError::Full(_)) = self.sender.try_send(event) {
+            let dropped = self.dropped_events.fetch_add(1, Ordering::Relaxed) + 1;
 
-                if dropped == 1 || dropped.is_multiple_of(1000) {
-                    eprintln!(
-                        "[telemetry]: {} events and log lines dropped que to full queue (capacity: {})",
-                        dropped, QUEUE_CAPACITY
-                    );
-                    self.send_drop_event(dropped);
-                }
+            if dropped == 1 || dropped.is_multiple_of(1000) {
+                eprintln!(
+                    "[telemetry]: {} events and log lines dropped due to full queue (capacity: {})",
+                    dropped, QUEUE_CAPACITY
+                );
+                self.send_drop_event(dropped);
             }
         }
     }
 
     fn send_drop_event(&self, total_dropped: u64) {
-        if let Some(dropped_sender) = &self.dropped_sender {
-            #[cfg(target_os = "linux")]
-            let thread_id_num = {
-                // SAFETY: syscall(SYS_gettid) is always safe to call
-                unsafe { libc::syscall(libc::SYS_gettid) as u64 }
-            };
-            #[cfg(not(target_os = "linux"))]
-            let thread_id_num = {
-                let tid = std::thread::current().id();
-                // SAFETY: ThreadId transmute for non-Linux platforms
-                unsafe { std::mem::transmute::<std::thread::ThreadId, u64>(tid) }
-            };
-
-            let mut fields = IndexMap::new();
-            fields.insert(
-                "message".to_string(),
-                FieldValue::Str(format!(
-                    "Telemetry events and log lines dropped due to full queue (capacity: {}). Worker may be falling behind.",
-                    QUEUE_CAPACITY
-                )),
+        #[cfg(target_os = "linux")]
+        let thread_id_num = {
+            // SAFETY: syscall(SYS_gettid) is always safe to call
+            unsafe { libc::syscall(libc::SYS_gettid) as u64 }
+        };
+        #[cfg(not(target_os = "linux"))]
+        let thread_id_num = {
+            let tid = std::thread::current().id();
+            // SAFETY: ThreadId transmute for non-Linux platforms
+            unsafe { std::mem::transmute::<std::thread::ThreadId, u64>(tid) }
+        };
+
+        let mut fields = IndexMap::new();
+        fields.insert(
+            "message".to_string(),
+            FieldValue::Str(format!(
+                "Telemetry events and log lines dropped due to full queue (capacity: {}). Worker may be falling behind.",
+                QUEUE_CAPACITY
+            )),
+        );
+        fields.insert("dropped_count".to_string(), FieldValue::U64(total_dropped));
+
+        // We want to just directly construct and send a `TraceEvent::Event` here so we don't need to
+        // reason very hard about whether or not we are creating a DoS loop
+        let drop_event = TraceEvent::Event {
+            name: "dropped events",
+            target: module_path!(),
+            level: tracing::Level::ERROR,
+            fields,
+            timestamp: SystemTime::now(),
+            parent_span: None,
+            thread_id: thread_id_num.to_string(),
+            thread_name: std::thread::current()
+                .name()
+                .unwrap_or_default()
+                .to_string(),
+            module_path: Some(module_path!()),
+            file: Some(file!()),
+            line: Some(line!()),
+        };
+
+        if self.dropped_sender.send(drop_event).is_err() {
+            // Last resort
+            eprintln!(
+                "[telemetry] CRITICAL: {} events and log lines dropped and unable to log to telemetry \
+                (worker thread may have died). Telemetry system offline.",
+                total_dropped
             );
-            fields.insert("dropped_count".to_string(), FieldValue::U64(total_dropped));
-
-            // We want to just directly construct and send a `TraceEvent::Event` here so we don't need to
-            // reason very hard about whether or not we are creating a DoS loop
-            let drop_event = TraceEvent::Event {
-                name: "dropped events",
-                target: module_path!(),
-                level: tracing::Level::ERROR,
-                fields,
-                timestamp: SystemTime::now(),
-                parent_span: None,
-                thread_id: thread_id_num.to_string(),
-                thread_name: std::thread::current()
-                    .name()
-                    .unwrap_or_default()
-                    .to_string(),
-                module_path: Some(module_path!()),
-                file: Some(file!()),
-                line: Some(line!()),
-            };
-
-            if dropped_sender.send(drop_event).is_err() {
-                // Last resort
-                eprintln!(
-                    "[telemetry] CRITICAL: {} events and log lines dropped and unable to log to telemetry \
-                    (worker thread may have died). Telemetry system offline.",
-                    total_dropped
-                );
-            }
         }
     }
 }
 
-impl Drop for TraceEventDispatcher {
-    fn drop(&mut self) {
-        // Explicitly drop both senders to close the channels.
-        // The next field to be dropped is `worker_handle` which
-        // will run its own drop impl to join the thread and flush
-        drop(self.sender.take());
-        drop(self.dropped_sender.take());
-    }
-}
-
 impl<S> Layer<S> for TraceEventDispatcher
 where
     S: Subscriber + for<'a> LookupSpan<'a>,
@@ -409,12 +422,13 @@ impl<'a> tracing::field::Visit for FieldVisitor<'a> {
 
 /// Background worker loop that receives events from both regular and priority channels,
 /// and dispatches them to sinks. Priority events are processed first.
-/// Runs until both senders are dropped.
+/// Runs until shutdown_flag is set or senders are dropped.
 fn worker_loop(
     receiver: mpsc::Receiver<TraceEvent>,
     dropped_receiver: mpsc::Receiver<TraceEvent>,
     mut sinks: Vec<Box<dyn TraceEventSink>>,
     dropped_events: Arc<AtomicU64>,
+    shutdown_flag: Arc<AtomicBool>,
 ) {
     const FLUSH_INTERVAL: Duration = Duration::from_millis(100);
     const FLUSH_EVENT_COUNT: usize = 1000;
@@ -451,6 +465,10 @@ fn worker_loop(
     }
 
     loop {
+        if shutdown_flag.load(Ordering::Acquire) {
+            break;
+        }
+
         while let Ok(event) = dropped_receiver.try_recv() {
             dispatch_to_sinks(&mut sinks, event);
             events_since_flush += 1;
@@ -469,6 +487,9 @@ fn worker_loop(
                 }
             }
             Err(mpsc::RecvTimeoutError::Timeout) => {
+                if shutdown_flag.load(Ordering::Acquire) {
+                    break;
+                }
                 flush_sinks(&mut sinks);
                 last_flush = std::time::Instant::now();
                 events_since_flush = 0;
@@ -496,13 +517,3 @@ fn worker_loop(
         );
     }
 }
-
-impl Drop for WorkerHandle {
-    fn drop(&mut self) {
-        if let Some(handle) = self.join_handle.take() {
-            if let Err(e) = handle.join() {
-                eprintln!("[telemetry] worker thread panicked: {:?}", e);
-            }
-        }
-    }
-}
diff --git a/monarch_hyperactor/src/runtime.rs b/monarch_hyperactor/src/runtime.rs
index a6b966eb8..cbe7393ed 100644
--- a/monarch_hyperactor/src/runtime.rs
+++ b/monarch_hyperactor/src/runtime.rs
@@ -68,6 +68,8 @@ pub fn get_tokio_runtime<'l>() -> std::sync::MappedRwLockReadGuard<'l, tokio::ru
 
 #[pyfunction]
 pub fn shutdown_tokio_runtime() {
+    hyperactor_telemetry::shutdown_telemetry();
+
     // It is important to not hold the GIL while calling this function.
     // Other runtime threads may be waiting to acquire it and we will never get to shutdown.
    if let Some(x) = INSTANCE.write().unwrap().take() {
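
Usage sketch (not part of the patch above): how the new shutdown path is expected to be wired together, assuming the crate layout shown in the diff. Because register_shutdown_hook is pub(crate), the first function only applies to code inside hyperactor_telemetry (for example, a sink that buffers output before writing); the function and sink names here are hypothetical. An embedder calls the public shutdown_telemetry() once before exit, mirroring the shutdown_tokio_runtime change.

    // Inside hyperactor_telemetry (hypothetical buffering sink): flush on shutdown.
    fn install_buffering_sink() {
        use std::sync::{Arc, Mutex};

        let buffer: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let buffer_for_hook = Arc::clone(&buffer);

        // The hook runs at most once, when shutdown_telemetry() drains the hook list.
        crate::register_shutdown_hook(move || {
            if let Ok(mut buf) = buffer_for_hook.lock() {
                for line in buf.drain(..) {
                    eprintln!("[example-sink] flushing: {}", line);
                }
            }
        });
    }

    // In the embedding process (as monarch_hyperactor's shutdown_tokio_runtime now does):
    fn on_process_exit() {
        // Invokes every registered hook, including the dispatcher's hook that
        // signals the telemetry worker thread and joins it within SHUTDOWN_TIMEOUT.
        hyperactor_telemetry::shutdown_telemetry();
    }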