Skip to content

Commit 850bd07

Browse files
wangyb-A and Alex Wang authored
feat: Adding replay mode for child execution (#10)
- During child execution, if the result exceeds the checkpoint API size limit, set replay mode to True.
- On replay, if replay mode is true, re-execute func instead of returning the result from the payload.

---------

Co-authored-by: Alex Wang <wangyb@amazon.com>
1 parent 44785df commit 850bd07

File tree

5 files changed

+140
-4
lines changed

5 files changed

+140
-4
lines changed

src/aws_durable_execution_sdk_python/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,13 @@ class CheckpointMode(Enum):
104104

105105

106106
@dataclass(frozen=True)
107-
class ChildConfig:
107+
class ChildConfig(Generic[T]):
108108
"""Options when running inside a child context."""
109109

110110
# checkpoint_mode: CheckpointMode = CheckpointMode.CHECKPOINT_AT_START_AND_FINISH
111111
serdes: SerDes | None = None
112112
sub_type: OperationSubType | None = None
113+
summary_generator: Callable[[T], str] | None = None
113114

114115

115116
class ItemsPerBatchUnit(Enum):

src/aws_durable_execution_sdk_python/lambda_service.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,11 @@ def create_context_start(
344344

345345
@classmethod
346346
def create_context_succeed(
347-
cls, identifier: OperationIdentifier, payload: str, sub_type: OperationSubType
347+
cls,
348+
identifier: OperationIdentifier,
349+
payload: str,
350+
sub_type: OperationSubType,
351+
context_options: ContextOptions | None = None,
348352
) -> OperationUpdate:
349353
"""Create an instance of OperationUpdate for type: CONTEXT, action: SUCCEED."""
350354
return cls(
@@ -355,6 +359,7 @@ def create_context_succeed(
355359
action=OperationAction.SUCCEED,
356360
name=identifier.name,
357361
payload=payload,
362+
context_options=context_options,
358363
)
359364

360365
@classmethod

src/aws_durable_execution_sdk_python/operation/child.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from aws_durable_execution_sdk_python.config import ChildConfig
99
from aws_durable_execution_sdk_python.exceptions import FatalError, SuspendExecution
1010
from aws_durable_execution_sdk_python.lambda_service import (
11+
ContextOptions,
1112
ErrorObject,
1213
OperationSubType,
1314
OperationUpdate,
@@ -24,6 +25,9 @@
2425

2526
T = TypeVar("T")
2627

28+
# Checkpoint size limit in bytes (256KB)
29+
CHECKPOINT_SIZE_LIMIT = 256 * 1024
30+
2731

2832
def child_handler(
2933
func: Callable[[], T],
@@ -40,9 +44,11 @@ def child_handler(
4044
if not config:
4145
config = ChildConfig()
4246

43-
# TODO: ReplayChildren
4447
checkpointed_result = state.get_checkpoint_result(operation_identifier.operation_id)
45-
if checkpointed_result.is_succeeded():
48+
if (
49+
checkpointed_result.is_succeeded()
50+
and not checkpointed_result.is_replay_children()
51+
):
4652
logger.debug(
4753
"Child context already completed, skipping execution for id: %s, name: %s",
4854
operation_identifier.operation_id,
@@ -71,17 +77,36 @@ def child_handler(
7177

7278
try:
7379
raw_result: T = func()
80+
if checkpointed_result.is_replay_children():
81+
logger.debug(
82+
"ReplayChildren mode: Executed child context again on replay due to large payload. Exiting child context without creating another checkpoint. id: %s, name: %s",
83+
operation_identifier.operation_id,
84+
operation_identifier.name,
85+
)
86+
return raw_result
7487
serialized_result: str = serialize(
7588
serdes=config.serdes,
7689
value=raw_result,
7790
operation_id=operation_identifier.operation_id,
7891
durable_execution_arn=state.durable_execution_arn,
7992
)
93+
replay_children: bool = False
94+
if len(serialized_result) > CHECKPOINT_SIZE_LIMIT:
95+
logger.debug(
96+
"Large payload detected, using ReplayChildren mode: id: %s, name: %s",
97+
operation_identifier.operation_id,
98+
operation_identifier.name,
99+
)
100+
replay_children = True
101+
serialized_result = (
102+
config.summary_generator(raw_result) if config.summary_generator else ""
103+
)
80104

81105
success_operation = OperationUpdate.create_context_succeed(
82106
identifier=operation_identifier,
83107
payload=serialized_result,
84108
sub_type=sub_type,
109+
context_options=ContextOptions(replay_children=replay_children),
85110
)
86111
state.create_checkpoint(operation_update=success_operation)
87112

src/aws_durable_execution_sdk_python/state.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ def is_timed_out(self) -> bool:
113113
return False
114114
return op.status is OperationStatus.TIMED_OUT
115115

116+
def is_replay_children(self) -> bool:
117+
op = self.operation
118+
if not op:
119+
return False
120+
return op.context_details.replay_children if op.context_details else False
121+
116122
def raise_callable_error(self) -> None:
117123
if self.error is None:
118124
msg: str = "Attempted to throw exception, but no ErrorObject exists on the Checkpoint Operation."

tests/operation/child_test.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def test_child_handler_not_started(
4141
mock_result.is_succeeded.return_value = False
4242
mock_result.is_failed.return_value = False
4343
mock_result.is_started.return_value = False
44+
mock_result.is_replay_children.return_value = False
45+
mock_result.is_replay_children.return_value = False
4446
mock_state.get_checkpoint_result.return_value = mock_result
4547
mock_callable = Mock(return_value="fresh_result")
4648

@@ -80,6 +82,7 @@ def test_child_handler_already_succeeded():
8082
mock_state.durable_execution_arn = "test_arn"
8183
mock_result = Mock()
8284
mock_result.is_succeeded.return_value = True
85+
mock_result.is_replay_children.return_value = False
8386
mock_result.result = json.dumps("cached_result")
8487
mock_state.get_checkpoint_result.return_value = mock_result
8588
mock_callable = Mock()
@@ -99,6 +102,7 @@ def test_child_handler_already_succeeded_none_result():
99102
mock_state.durable_execution_arn = "test_arn"
100103
mock_result = Mock()
101104
mock_result.is_succeeded.return_value = True
105+
mock_result.is_replay_children.return_value = False
102106
mock_result.result = None
103107
mock_state.get_checkpoint_result.return_value = mock_result
104108
mock_callable = Mock()
@@ -155,6 +159,7 @@ def test_child_handler_already_started(
155159
mock_result.is_succeeded.return_value = False
156160
mock_result.is_failed.return_value = False
157161
mock_result.is_started.return_value = True
162+
mock_result.is_replay_children.return_value = False
158163
mock_state.get_checkpoint_result.return_value = mock_result
159164
mock_callable = Mock(return_value="started_result")
160165

@@ -281,6 +286,7 @@ def test_child_handler_default_serialization():
281286
mock_result.is_succeeded.return_value = False
282287
mock_result.is_failed.return_value = False
283288
mock_result.is_started.return_value = False
289+
mock_result.is_replay_children.return_value = False
284290
mock_state.get_checkpoint_result.return_value = mock_result
285291
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
286292
mock_callable = Mock(return_value=complex_result)
@@ -306,6 +312,7 @@ def test_child_handler_custom_serdes_not_start():
306312
mock_result.is_succeeded.return_value = False
307313
mock_result.is_failed.return_value = False
308314
mock_result.is_started.return_value = False
315+
mock_result.is_replay_children.return_value = False
309316
mock_state.get_checkpoint_result.return_value = mock_result
310317
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
311318
mock_callable = Mock(return_value=complex_result)
@@ -334,6 +341,7 @@ def test_child_handler_custom_serdes_already_succeeded():
334341
mock_result.is_succeeded.return_value = True
335342
mock_result.is_failed.return_value = False
336343
mock_result.is_started.return_value = False
344+
mock_result.is_replay_children.return_value = False
337345
mock_result.result = '{"key": "VALUE", "number": "84", "list": [1, 2, 3]}'
338346
mock_state.get_checkpoint_result.return_value = mock_result
339347
mock_callable = Mock()
@@ -352,3 +360,94 @@ def test_child_handler_custom_serdes_already_succeeded():
352360

353361

354362
# endregion child_handler
363+
364+
365+
# large payload with summary generator
366+
def test_child_handler_large_payload_with_summary_generator():
367+
"""Test child_handler with large payload and summary generator."""
368+
mock_state = Mock(spec=ExecutionState)
369+
mock_state.durable_execution_arn = "test_arn"
370+
mock_result = Mock()
371+
mock_result.is_succeeded.return_value = False
372+
mock_result.is_failed.return_value = False
373+
mock_result.is_started.return_value = False
374+
mock_result.is_replay_children.return_value = False
375+
mock_state.get_checkpoint_result.return_value = mock_result
376+
large_result = "large" * 256 * 1024
377+
mock_callable = Mock(return_value=large_result)
378+
379+
def my_summary(result: str) -> str:
380+
return "summary"
381+
382+
child_config: ChildConfig = ChildConfig[str](summary_generator=my_summary)
383+
384+
actual_result = child_handler(
385+
mock_callable,
386+
mock_state,
387+
OperationIdentifier("op9", None, "test_name"),
388+
child_config,
389+
)
390+
391+
assert large_result == actual_result
392+
success_call = mock_state.create_checkpoint.call_args_list[1]
393+
success_operation = success_call[1]["operation_update"]
394+
assert success_operation.context_options.replay_children
395+
expected_checkpoointed_result = "summary"
396+
assert success_operation.payload == expected_checkpoointed_result
397+
398+
399+
# large payload without summary generator
400+
def test_child_handler_large_payload_without_summary_generator():
401+
"""Test child_handler with large payload and no summary generator."""
402+
mock_state = Mock(spec=ExecutionState)
403+
mock_state.durable_execution_arn = "test_arn"
404+
mock_result = Mock()
405+
mock_result.is_succeeded.return_value = False
406+
mock_result.is_failed.return_value = False
407+
mock_result.is_started.return_value = False
408+
mock_result.is_replay_children.return_value = False
409+
mock_state.get_checkpoint_result.return_value = mock_result
410+
large_result = "large" * 256 * 1024
411+
mock_callable = Mock(return_value=large_result)
412+
child_config: ChildConfig = ChildConfig()
413+
414+
actual_result = child_handler(
415+
mock_callable,
416+
mock_state,
417+
OperationIdentifier("op9", None, "test_name"),
418+
child_config,
419+
)
420+
421+
assert large_result == actual_result
422+
success_call = mock_state.create_checkpoint.call_args_list[1]
423+
success_operation = success_call[1]["operation_update"]
424+
assert success_operation.context_options.replay_children
425+
expected_checkpoointed_result = ""
426+
assert success_operation.payload == expected_checkpoointed_result
427+
428+
429+
# mocked children replay mode execute the function again
430+
def test_child_handler_replay_children_mode():
431+
"""Test child_handler in ReplayChildren mode."""
432+
mock_state = Mock(spec=ExecutionState)
433+
mock_state.durable_execution_arn = "test_arn"
434+
mock_result = Mock()
435+
mock_result.is_succeeded.return_value = True
436+
mock_result.is_failed.return_value = False
437+
mock_result.is_started.return_value = True
438+
mock_result.is_replay_children.return_value = True
439+
mock_state.get_checkpoint_result.return_value = mock_result
440+
complex_result = {"key": "value", "number": 42, "list": [1, 2, 3]}
441+
mock_callable = Mock(return_value=complex_result)
442+
child_config: ChildConfig = ChildConfig()
443+
444+
actual_result = child_handler(
445+
mock_callable,
446+
mock_state,
447+
OperationIdentifier("op9", None, "test_name"),
448+
child_config,
449+
)
450+
451+
assert actual_result == complex_result
452+
453+
mock_state.create_checkpoint.assert_not_called()

0 commit comments

Comments
 (0)