30 changes: 30 additions & 0 deletions hyperactor_telemetry/src/lib.rs
@@ -67,6 +67,7 @@ use std::io::Write;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::OnceLock;
use std::time::Instant;

use lazy_static::lazy_static;
@@ -102,6 +103,35 @@ use crate::config::USE_UNIFIED_LAYER;
use crate::recorder::Recorder;
use crate::sqlite::get_reloadable_sqlite_layer;

static SHUTDOWN_HOOKS: OnceLock<Mutex<Vec<Box<dyn FnOnce() + Send>>>> = OnceLock::new();

fn get_shutdown_hooks() -> &'static Mutex<Vec<Box<dyn FnOnce() + Send>>> {
SHUTDOWN_HOOKS.get_or_init(|| Mutex::new(Vec::new()))
}

/// Register a callback to be invoked during telemetry shutdown.
/// This is useful for components that need to flush buffers or clean up
/// before the process exits.
pub(crate) fn register_shutdown_hook<F>(hook: F)
where
F: FnOnce() + Send + 'static,
{
if let Ok(mut hooks) = get_shutdown_hooks().lock() {
hooks.push(Box::new(hook));
}
}

/// Shutdown the telemetry system, invoking all registered shutdown hooks.
/// This should be called before process exit to ensure all telemetry is flushed.
/// Safe to call multiple times; hooks are only invoked once.
pub fn shutdown_telemetry() {
if let Ok(mut hooks) = get_shutdown_hooks().lock() {
for hook in hooks.drain(..) {
hook();
}
}
}

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct TelemetrySample {
fields: Vec<(String, String)>,
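A minimal usage sketch of the new hook API (not part of this diff; the component name and buffer are hypothetical). A crate-internal caller registers a flush callback via register_shutdown_hook, and the process invokes shutdown_telemetry() once before exit:

// Hypothetical crate-internal component; register_shutdown_hook is pub(crate),
// so this would live inside hyperactor_telemetry.
fn install_flushing_component(buffer: std::sync::Arc<std::sync::Mutex<Vec<String>>>) {
    crate::register_shutdown_hook(move || {
        // Runs at most once: shutdown_telemetry() drains the hook list.
        if let Ok(mut buf) = buffer.lock() {
            for line in buf.drain(..) {
                eprintln!("[telemetry] flushed buffered line: {}", line);
            }
        }
    });
}

// Before process exit (e.g. from the Python binding changed further down):
// hyperactor_telemetry::shutdown_telemetry();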
199 changes: 105 additions & 94 deletions hyperactor_telemetry/src/trace_dispatcher.rs
@@ -11,10 +11,10 @@
//! thread.

use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::sync::mpsc;
use std::thread::JoinHandle;
use std::time::Duration;
use std::time::SystemTime;

@@ -26,7 +26,10 @@ use tracing_subscriber::layer::Context;
use tracing_subscriber::layer::Layer;
use tracing_subscriber::registry::LookupSpan;

use crate::register_shutdown_hook;

const QUEUE_CAPACITY: usize = 100_000;
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);

/// Unified representation of a trace event captured from the tracing layer.
/// This is captured once on the application thread, then sent to the background
@@ -122,18 +125,13 @@ pub(crate) trait TraceEventSink: Send + 'static {
/// The trace event dispatcher that captures events once and dispatches to multiple sinks
/// on a background thread.
pub struct TraceEventDispatcher {
sender: Option<mpsc::SyncSender<TraceEvent>>,
sender: mpsc::SyncSender<TraceEvent>,
/// Separate channel so we are always notified of when the main queue is full and events are being dropped.
dropped_sender: Option<mpsc::Sender<TraceEvent>>,
_worker_handle: WorkerHandle,
dropped_sender: mpsc::Sender<TraceEvent>,
max_level: Option<tracing::level_filters::LevelFilter>,
dropped_events: Arc<AtomicU64>,
}

struct WorkerHandle {
join_handle: Option<JoinHandle<()>>,
}

impl TraceEventDispatcher {
/// Create a new trace event dispatcher with the given sinks.
/// Uses a bounded channel (capacity QUEUE_CAPACITY) to ensure telemetry never blocks
@@ -153,105 +151,120 @@ impl TraceEventDispatcher {
let dropped_events = Arc::new(AtomicU64::new(0));
let dropped_events_worker = Arc::clone(&dropped_events);

let shutdown_flag = Arc::new(AtomicBool::new(false));
let shutdown_flag_worker = Arc::clone(&shutdown_flag);

let (done_tx, done_rx) = mpsc::sync_channel::<()>(1);

let worker_handle = std::thread::Builder::new()
.name("telemetry-worker".into())
.spawn(move || {
worker_loop(receiver, dropped_receiver, sinks, dropped_events_worker);
worker_loop(
receiver,
dropped_receiver,
sinks,
dropped_events_worker,
shutdown_flag_worker,
);
let _ = done_tx.send(());
})
.expect("failed to spawn telemetry worker thread");

register_shutdown_hook(move || {
shutdown_flag.store(true, Ordering::Release);

if done_rx.recv_timeout(SHUTDOWN_TIMEOUT).is_err() {
eprintln!(
"[telemetry] WARNING: worker thread did not exit within {:?}, continuing shutdown",
SHUTDOWN_TIMEOUT
);
return;
}

// This join should be instant since done_rx was received
if let Err(e) = worker_handle.join() {
eprintln!(
"[telemetry] worker thread panicked during shutdown: {:?}",
e
);
}
});

Self {
sender: Some(sender),
dropped_sender: Some(dropped_sender),
_worker_handle: WorkerHandle {
join_handle: Some(worker_handle),
},
sender,
dropped_sender,
max_level,
dropped_events,
}
}

fn send_event(&self, event: TraceEvent) {
if let Some(sender) = &self.sender {
if let Err(mpsc::TrySendError::Full(_)) = sender.try_send(event) {
let dropped = self.dropped_events.fetch_add(1, Ordering::Relaxed) + 1;
if let Err(mpsc::TrySendError::Full(_)) = self.sender.try_send(event) {
let dropped = self.dropped_events.fetch_add(1, Ordering::Relaxed) + 1;

if dropped == 1 || dropped.is_multiple_of(1000) {
eprintln!(
"[telemetry]: {} events and log lines dropped que to full queue (capacity: {})",
dropped, QUEUE_CAPACITY
);
self.send_drop_event(dropped);
}
if dropped == 1 || dropped.is_multiple_of(1000) {
eprintln!(
"[telemetry]: {} events and log lines dropped due to full queue (capacity: {})",
dropped, QUEUE_CAPACITY
);
self.send_drop_event(dropped);
}
}
}

fn send_drop_event(&self, total_dropped: u64) {
if let Some(dropped_sender) = &self.dropped_sender {
#[cfg(target_os = "linux")]
let thread_id_num = {
// SAFETY: syscall(SYS_gettid) is always safe to call
unsafe { libc::syscall(libc::SYS_gettid) as u64 }
};
#[cfg(not(target_os = "linux"))]
let thread_id_num = {
let tid = std::thread::current().id();
// SAFETY: ThreadId transmute for non-Linux platforms
unsafe { std::mem::transmute::<std::thread::ThreadId, u64>(tid) }
};

let mut fields = IndexMap::new();
fields.insert(
"message".to_string(),
FieldValue::Str(format!(
"Telemetry events and log lines dropped due to full queue (capacity: {}). Worker may be falling behind.",
QUEUE_CAPACITY
)),
#[cfg(target_os = "linux")]
let thread_id_num = {
// SAFETY: syscall(SYS_gettid) is always safe to call
unsafe { libc::syscall(libc::SYS_gettid) as u64 }
};
#[cfg(not(target_os = "linux"))]
let thread_id_num = {
let tid = std::thread::current().id();
// SAFETY: ThreadId transmute for non-Linux platforms
unsafe { std::mem::transmute::<std::thread::ThreadId, u64>(tid) }
};

let mut fields = IndexMap::new();
fields.insert(
"message".to_string(),
FieldValue::Str(format!(
"Telemetry events and log lines dropped due to full queue (capacity: {}). Worker may be falling behind.",
QUEUE_CAPACITY
)),
);
fields.insert("dropped_count".to_string(), FieldValue::U64(total_dropped));

// We want to just directly construct and send a `TraceEvent::Event` here so we don't need to
// reason very hard about whether or not we are creating a DoS loop
let drop_event = TraceEvent::Event {
name: "dropped events",
target: module_path!(),
level: tracing::Level::ERROR,
fields,
timestamp: SystemTime::now(),
parent_span: None,
thread_id: thread_id_num.to_string(),
thread_name: std::thread::current()
.name()
.unwrap_or_default()
.to_string(),
module_path: Some(module_path!()),
file: Some(file!()),
line: Some(line!()),
};

if self.dropped_sender.send(drop_event).is_err() {
// Last resort
eprintln!(
"[telemetry] CRITICAL: {} events and log lines dropped and unable to log to telemetry \
(worker thread may have died). Telemetry system offline.",
total_dropped
);
fields.insert("dropped_count".to_string(), FieldValue::U64(total_dropped));

// We want to just directly construct and send a `TraceEvent::Event` here so we don't need to
// reason very hard about whether or not we are creating a DoS loop
let drop_event = TraceEvent::Event {
name: "dropped events",
target: module_path!(),
level: tracing::Level::ERROR,
fields,
timestamp: SystemTime::now(),
parent_span: None,
thread_id: thread_id_num.to_string(),
thread_name: std::thread::current()
.name()
.unwrap_or_default()
.to_string(),
module_path: Some(module_path!()),
file: Some(file!()),
line: Some(line!()),
};

if dropped_sender.send(drop_event).is_err() {
// Last resort
eprintln!(
"[telemetry] CRITICAL: {} events and log lines dropped and unable to log to telemetry \
(worker thread may have died). Telemetry system offline.",
total_dropped
);
}
}
}
}

impl Drop for TraceEventDispatcher {
fn drop(&mut self) {
// Explicitly drop both senders to close the channels.
// The next field to be dropped is `worker_handle` which
// will run its own drop impl to join the thread and flush
drop(self.sender.take());
drop(self.dropped_sender.take());
}
}

impl<S> Layer<S> for TraceEventDispatcher
where
S: Subscriber + for<'a> LookupSpan<'a>,
@@ -409,12 +422,13 @@ impl<'a> tracing::field::Visit for FieldVisitor<'a> {

/// Background worker loop that receives events from both regular and priority channels,
/// and dispatches them to sinks. Priority events are processed first.
/// Runs until both senders are dropped.
/// Runs until shutdown_flag is set or senders are dropped.
fn worker_loop(
receiver: mpsc::Receiver<TraceEvent>,
dropped_receiver: mpsc::Receiver<TraceEvent>,
mut sinks: Vec<Box<dyn TraceEventSink>>,
dropped_events: Arc<AtomicU64>,
shutdown_flag: Arc<AtomicBool>,
) {
const FLUSH_INTERVAL: Duration = Duration::from_millis(100);
const FLUSH_EVENT_COUNT: usize = 1000;
@@ -451,6 +465,10 @@ fn worker_loop(
}

loop {
if shutdown_flag.load(Ordering::Acquire) {
break;
}

while let Ok(event) = dropped_receiver.try_recv() {
dispatch_to_sinks(&mut sinks, event);
events_since_flush += 1;
@@ -469,6 +487,9 @@ }
}
}
Err(mpsc::RecvTimeoutError::Timeout) => {
if shutdown_flag.load(Ordering::Acquire) {
break;
}
flush_sinks(&mut sinks);
last_flush = std::time::Instant::now();
events_since_flush = 0;
@@ -496,13 +517,3 @@ fn worker_loop(
);
}
}

impl Drop for WorkerHandle {
fn drop(&mut self) {
if let Some(handle) = self.join_handle.take() {
if let Err(e) = handle.join() {
eprintln!("[telemetry] worker thread panicked: {:?}", e);
}
}
}
}
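The refactor above drops the joining WorkerHandle in favor of a shutdown hook: the hook sets an atomic flag, waits on a bounded "done" channel with a timeout, and only then joins the worker, so a stuck worker can no longer hang process exit. A stripped-down, self-contained sketch of that handshake (hypothetical names; the event queues, sinks, and flushing of the real worker are omitted):

use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::mpsc;
use std::time::Duration;

fn main() {
    let stop = Arc::new(AtomicBool::new(false));
    let stop_worker = Arc::clone(&stop);
    let (done_tx, done_rx) = mpsc::sync_channel::<()>(1);

    let worker = std::thread::spawn(move || {
        // Stand-in for the event-processing loop; the real loop also drains
        // its channels and flushes sinks on a timer.
        while !stop_worker.load(Ordering::Acquire) {
            std::thread::sleep(Duration::from_millis(10));
        }
        let _ = done_tx.send(()); // signal a clean exit
    });

    // Shutdown path: request stop, then wait a bounded amount of time.
    stop.store(true, Ordering::Release);
    if done_rx.recv_timeout(Duration::from_secs(5)).is_ok() {
        let _ = worker.join(); // effectively instant, the worker already signalled
    } else {
        eprintln!("worker did not exit in time; continuing shutdown");
    }
}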
2 changes: 2 additions & 0 deletions monarch_hyperactor/src/runtime.rs
@@ -68,6 +68,8 @@ pub fn get_tokio_runtime<'l>() -> std::sync::MappedRwLockReadGuard<'l, tokio::ru

#[pyfunction]
pub fn shutdown_tokio_runtime() {
hyperactor_telemetry::shutdown_telemetry();

// It is important to not hold the GIL while calling this function.
// Other runtime threads may be waiting to acquire it and we will never get to shutdown.
if let Some(x) = INSTANCE.write().unwrap().take() {