Skip to content

Commit 1e60904

Browse files
committed
feat: add durable_wait_for_callback decorator
Add a new decorator that enables cleaner API for wait_for_callback operations by allowing submitter functions to accept additional parameters that are bound at call time. The decorator wraps callables that take callback_id, context, and additional parameters, returning a function that binds those extra parameters and produces a submitter compatible with wait_for_callback. Changes: * Add durable_wait_for_callback decorator to context module * Export decorator from package __init__ for public API * Add integration test validating parameter binding and name propagation * Verify operation names are correctly set on STEP and CALLBACK operations Example usage: @durable_wait_for_callback def submit_task(callback_id, context, task_name, priority): external_api.submit(task_name, priority, callback_id) result = context.wait_for_callback(submit_task(my_task, priority=5))
1 parent f75878e commit 1e60904

File tree

3 files changed

+169
-2
lines changed

3 files changed

+169
-2
lines changed

src/aws_durable_execution_sdk_python/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from aws_durable_execution_sdk_python.context import (
66
DurableContext,
77
durable_step,
8+
durable_wait_for_callback,
89
durable_with_child_context,
910
)
1011

@@ -30,5 +31,6 @@
3031
"ValidationError",
3132
"durable_execution",
3233
"durable_step",
34+
"durable_wait_for_callback",
3335
"durable_with_child_context",
3436
]

src/aws_durable_execution_sdk_python/context.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,52 @@ def function_with_arguments(child_context: DurableContext):
104104
return wrapper
105105

106106

107+
def durable_wait_for_callback(
108+
func: Callable[Concatenate[str, WaitForCallbackContext, Params], T],
109+
) -> Callable[Params, Callable[[str, WaitForCallbackContext], T]]:
110+
"""Wrap your callable into a wait_for_callback submitter function.
111+
112+
This decorator allows you to define a submitter function with additional
113+
parameters that will be bound when called.
114+
115+
Args:
116+
func: A callable that takes callback_id, context, and additional parameters
117+
118+
Returns:
119+
A wrapper function that binds the additional parameters and returns
120+
a submitter function compatible with wait_for_callback
121+
122+
Example:
123+
@durable_wait_for_callback
124+
def submit_to_external_system(
125+
callback_id: str,
126+
context: WaitForCallbackContext,
127+
task_name: str,
128+
priority: int
129+
):
130+
context.logger.info(f"Submitting {task_name} with callback {callback_id}")
131+
external_api.submit_task(
132+
task_name=task_name,
133+
priority=priority,
134+
callback_id=callback_id
135+
)
136+
137+
# Usage in durable handler:
138+
result = context.wait_for_callback(
139+
submit_to_external_system("my_task", priority=5)
140+
)
141+
"""
142+
143+
def wrapper(*args, **kwargs):
144+
def submitter_with_arguments(callback_id: str, context: WaitForCallbackContext):
145+
return func(callback_id, context, *args, **kwargs)
146+
147+
submitter_with_arguments._original_name = func.__name__ # noqa: SLF001
148+
return submitter_with_arguments
149+
150+
return wrapper
151+
152+
107153
class Callback(Generic[T], CallbackProtocol[T]): # noqa: PYI059
108154
"""A future that will block on result() until callback_id returns."""
109155

tests/e2e/execution_int_test.py

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,20 @@
1111
from aws_durable_execution_sdk_python.context import (
1212
DurableContext,
1313
durable_step,
14+
durable_wait_for_callback,
1415
durable_with_child_context,
1516
)
1617
from aws_durable_execution_sdk_python.execution import (
1718
InvocationStatus,
1819
durable_execution,
1920
)
20-
21-
# LambdaContext no longer needed - using duck typing
2221
from aws_durable_execution_sdk_python.lambda_service import (
22+
CallbackDetails,
2323
CheckpointOutput,
2424
CheckpointUpdatedExecutionState,
25+
Operation,
2526
OperationAction,
27+
OperationStatus,
2628
OperationType,
2729
)
2830
from aws_durable_execution_sdk_python.logger import LoggerInterface
@@ -487,3 +489,120 @@ def mock_checkpoint(
487489
assert checkpoint.action is OperationAction.START
488490
assert checkpoint.operation_id == next(operation_ids)
489491
assert checkpoint.wait_options.wait_seconds == 1
492+
493+
494+
def test_durable_wait_for_callback_decorator():
495+
"""Test the durable_wait_for_callback decorator with additional parameters."""
496+
497+
mock_submitter = Mock()
498+
499+
@durable_wait_for_callback
500+
def submit_to_external_system(callback_id, context, task_name, priority):
501+
mock_submitter(callback_id, task_name, priority)
502+
context.logger.info("Submitting %s with callback %s", task_name, callback_id)
503+
504+
@durable_execution
505+
def my_handler(event, context):
506+
context.wait_for_callback(submit_to_external_system("my_task", priority=5))
507+
508+
with patch(
509+
"aws_durable_execution_sdk_python.execution.LambdaClient"
510+
) as mock_client_class:
511+
mock_client = Mock()
512+
mock_client_class.initialize_from_env.return_value = mock_client
513+
514+
checkpoint_calls = []
515+
516+
def mock_checkpoint(
517+
durable_execution_arn,
518+
checkpoint_token,
519+
updates,
520+
client_token="token", # noqa: S107
521+
):
522+
checkpoint_calls.append(updates)
523+
524+
# For CALLBACK operations, return the operation with callback details
525+
operations = [
526+
Operation(
527+
operation_id=update.operation_id,
528+
operation_type=OperationType.CALLBACK,
529+
status=OperationStatus.STARTED,
530+
callback_details=CallbackDetails(
531+
callback_id=f"callback-{update.operation_id[:8]}"
532+
),
533+
)
534+
for update in updates
535+
if update.operation_type == OperationType.CALLBACK
536+
]
537+
538+
return CheckpointOutput(
539+
checkpoint_token="new_token", # noqa: S106
540+
new_execution_state=CheckpointUpdatedExecutionState(
541+
operations=operations, next_marker=None
542+
),
543+
)
544+
545+
mock_client.checkpoint = mock_checkpoint
546+
547+
event = {
548+
"DurableExecutionArn": "test-arn",
549+
"CheckpointToken": "test-token",
550+
"InitialExecutionState": {
551+
"Operations": [
552+
{
553+
"Id": "execution-1",
554+
"Type": "EXECUTION",
555+
"Status": "STARTED",
556+
"ExecutionDetails": {"InputPayload": "{}"},
557+
}
558+
],
559+
"NextMarker": "",
560+
},
561+
"LocalRunner": True,
562+
}
563+
564+
lambda_context = Mock()
565+
lambda_context.aws_request_id = "test-request-id"
566+
lambda_context.client_context = None
567+
lambda_context.identity = None
568+
lambda_context._epoch_deadline_time_in_ms = 0 # noqa: SLF001
569+
lambda_context.invoked_function_arn = "test-arn"
570+
lambda_context.tenant_id = None
571+
572+
result = my_handler(event, lambda_context)
573+
574+
assert result["Status"] == InvocationStatus.PENDING.value
575+
576+
all_operations = [op for batch in checkpoint_calls for op in batch]
577+
assert len(all_operations) == 4
578+
579+
# First: CONTEXT START
580+
first_checkpoint = all_operations[0]
581+
assert first_checkpoint.operation_type is OperationType.CONTEXT
582+
assert first_checkpoint.action is OperationAction.START
583+
assert first_checkpoint.name == "submit_to_external_system"
584+
585+
# Second: CALLBACK START
586+
second_checkpoint = all_operations[1]
587+
assert second_checkpoint.operation_type is OperationType.CALLBACK
588+
assert second_checkpoint.action is OperationAction.START
589+
assert second_checkpoint.parent_id == first_checkpoint.operation_id
590+
assert second_checkpoint.name == "submit_to_external_system create callback id"
591+
592+
# Third: STEP START
593+
third_checkpoint = all_operations[2]
594+
assert third_checkpoint.operation_type is OperationType.STEP
595+
assert third_checkpoint.action is OperationAction.START
596+
assert third_checkpoint.parent_id == first_checkpoint.operation_id
597+
assert third_checkpoint.name == "submit_to_external_system submitter"
598+
599+
# Fourth: STEP SUCCEED
600+
fourth_checkpoint = all_operations[3]
601+
assert fourth_checkpoint.operation_type is OperationType.STEP
602+
assert fourth_checkpoint.action is OperationAction.SUCCEED
603+
assert fourth_checkpoint.operation_id == third_checkpoint.operation_id
604+
605+
mock_submitter.assert_called_once()
606+
call_args = mock_submitter.call_args[0]
607+
assert call_args[1] == "my_task"
608+
assert call_args[2] == 5

0 commit comments

Comments
 (0)