diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 276bd0905..68066d1a0 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -307,7 +307,7 @@ impl PythonMessage { id, unflatten_args, } => { - let broker = BrokerId::new(local_state_broker).resolve(cx).unwrap(); + let broker = BrokerId::new(local_state_broker).resolve(cx).await; let (send, recv) = cx.open_once_port(); broker.send(LocalStateBrokerMessage::Get(id, send))?; let state = recv.recv().await?; diff --git a/monarch_hyperactor/src/local_state_broker.rs b/monarch_hyperactor/src/local_state_broker.rs index 5c287f182..98b10b75b 100644 --- a/monarch_hyperactor/src/local_state_broker.rs +++ b/monarch_hyperactor/src/local_state_broker.rs @@ -75,9 +75,34 @@ impl BrokerId { pub fn new(broker_id: (String, usize)) -> Self { BrokerId(broker_id.0, broker_id.1) } - pub fn resolve(self, cx: &Context) -> Option> { + + /// Resolve the broker with exponential backoff retry. + /// Broker creation can race with messages that will use the broker, + /// so we retry with exponential backoff before panicking. + /// A better solution would be to figure out some way to get the real broker reference threaded to the client, but + /// that is more difficult to figure out right now. + pub async fn resolve( + self, + cx: &Context<'_, A>, + ) -> ActorHandle { + use std::time::Duration; + + let broker_name = format!("{:?}", self); let actor_id = ActorId(cx.proc().proc_id().clone(), self.0, self.1); let actor_ref: ActorRef = ActorRef::attest(actor_id); - actor_ref.downcast_handle(cx) + + let mut delay_ms = 1; + loop { + if let Some(handle) = actor_ref.downcast_handle(cx) { + return handle; + } + + if delay_ms > 8192 { + panic!("Failed to resolve broker {} after retries", broker_name); + } + + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + delay_ms *= 2; + } } } diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index f25cda1c5..d59c933c6 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -978,7 +978,7 @@ impl StreamActor { let x: u64 = params.seq.into(); let message = LocalStateBrokerMessage::Set(x as usize, state); - let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap(); + let broker = BrokerId::new(params.broker_id).resolve(cx).await; broker .send(message) .map_err(|e| CallFunctionError::Error(e.into()))?;