Skip to content

Commit c0dd405

Browse files
committed
feat: add early exit for concurrency
- Add OrphanedChildException (a BaseException subclass) to terminate orphaned children when the parent completes early.
- Modify the ThreadPoolExecutor usage to shut down without waiting (wait=False) when the completion criteria are met.
- Raise the exception when orphaned children attempt to checkpoint, preventing subsequent operations from executing.
- Update state.py to reject orphaned child checkpoints with an exception instead of a silent return.
- Add comprehensive tests for early exit behavior and orphaned child handling.

When the min_successful or error threshold is reached in parallel/map operations, the parent now returns immediately without waiting for the remaining branches to complete. Orphaned branches are terminated on their next checkpoint attempt, preventing wasted work and ensuring correct semantics for completion criteria.
1 parent dd086e0 commit c0dd405

File tree

6 files changed

+329
-46
lines changed

6 files changed

+329
-46
lines changed

src/aws_durable_execution_sdk_python/concurrency/executor.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from aws_durable_execution_sdk_python.config import ChildConfig
2424
from aws_durable_execution_sdk_python.exceptions import (
25+
OrphanedChildException,
2526
SuspendExecution,
2627
TimedSuspendExecution,
2728
)
@@ -198,42 +199,47 @@ def resubmitter(executable_with_state: ExecutableWithState) -> None:
198199
execution_state.create_checkpoint()
199200
submit_task(executable_with_state)
200201

201-
with (
202-
TimerScheduler(resubmitter) as scheduler,
203-
ThreadPoolExecutor(max_workers=max_workers) as thread_executor,
204-
):
205-
206-
def submit_task(executable_with_state: ExecutableWithState) -> Future:
207-
"""Submit task to the thread executor and mark its state as started."""
208-
future = thread_executor.submit(
209-
self._execute_item_in_child_context,
210-
executor_context,
211-
executable_with_state.executable,
212-
)
213-
executable_with_state.run(future)
202+
thread_executor = ThreadPoolExecutor(max_workers=max_workers)
203+
try:
204+
with TimerScheduler(resubmitter) as scheduler:
205+
206+
def submit_task(executable_with_state: ExecutableWithState) -> Future:
207+
"""Submit task to the thread executor and mark its state as started."""
208+
future = thread_executor.submit(
209+
self._execute_item_in_child_context,
210+
executor_context,
211+
executable_with_state.executable,
212+
)
213+
executable_with_state.run(future)
214214

215-
def on_done(future: Future) -> None:
216-
self._on_task_complete(executable_with_state, future, scheduler)
215+
def on_done(future: Future) -> None:
216+
self._on_task_complete(executable_with_state, future, scheduler)
217217

218-
future.add_done_callback(on_done)
219-
return future
218+
future.add_done_callback(on_done)
219+
return future
220220

221-
# Submit initial tasks
222-
futures = [
223-
submit_task(exe_state) for exe_state in self.executables_with_state
224-
]
221+
# Submit initial tasks
222+
futures = [
223+
submit_task(exe_state) for exe_state in self.executables_with_state
224+
]
225225

226-
# Wait for completion
227-
self._completion_event.wait()
226+
# Wait for completion
227+
self._completion_event.wait()
228228

229-
# Cancel remaining futures so
230-
# that we don't wait for them to join.
231-
for future in futures:
232-
future.cancel()
229+
# Cancel futures that haven't started yet
230+
for future in futures:
231+
future.cancel()
233232

234-
# Suspend execution if everything done and at least one of the tasks raised a suspend exception.
235-
if self._suspend_exception:
236-
raise self._suspend_exception
233+
# Suspend execution if everything done and at least one of the tasks raised a suspend exception.
234+
if self._suspend_exception:
235+
raise self._suspend_exception
236+
237+
finally:
238+
# Shutdown without waiting for running threads for early return when
239+
# completion criteria are met (e.g., min_successful).
240+
# Running threads will continue in background but they raise OrphanedChildException
241+
# on the next attempt to checkpoint.
242+
thread_executor.shutdown(wait=False, cancel_futures=True)
237243

238244
# Build final result
239245
return self._create_result()
@@ -291,6 +297,15 @@ def _on_task_complete(
291297
result = future.result()
292298
exe_state.complete(result)
293299
self.counters.complete_task()
300+
except OrphanedChildException:
301+
# Parent already completed and returned.
302+
# State is already RUNNING, which _create_result() marked as STARTED
303+
# Just log and exit - no state change needed
304+
logger.debug(
305+
"Terminating orphaned branch %s without error because parent has completed already",
306+
exe_state.index,
307+
)
308+
return
294309
except TimedSuspendExecution as tse:
295310
exe_state.suspend_with_timeout(tse.scheduled_timestamp)
296311
scheduler.schedule_resume(exe_state, tse.scheduled_timestamp)

src/aws_durable_execution_sdk_python/exceptions.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,32 @@ def __str__(self) -> str:
372372

373373
class SerDesError(DurableExecutionsError):
374374
"""Raised when serialization fails."""
375+
376+
377+
class OrphanedChildException(BaseException):
    """Raised when a child operation attempts to checkpoint after its parent context has completed.

    This exception inherits from BaseException (not Exception) so that user-space doesn't
    accidentally catch it with broad exception handlers like 'except Exception'.

    This exception will happen when a parallel branch or map item tries to create a checkpoint
    after its parent context (i.e., the parallel/map operation) has already completed due to
    meeting completion criteria (e.g., min_successful reached, failure tolerance exceeded).

    Although you cannot cancel running futures in user-space, this will at least terminate the
    child operation on the next checkpoint attempt, preventing subsequent operations in the
    child scope from executing.

    Attributes:
        operation_id: Operation ID of the orphaned child
    """

    def __init__(self, message: str, operation_id: str):
        """Initialize OrphanedChildException.

        Args:
            message: Human-readable error message
            operation_id: Operation ID of the orphaned child (required)
        """
        # Only the message goes into ``args`` so that ``str(exc)`` stays the plain message.
        super().__init__(message)
        self.operation_id = operation_id

    def __reduce__(self) -> tuple:
        """Return reconstruction data so the exception survives copy and pickle.

        The inherited ``BaseException.__reduce__`` rebuilds the instance from ``self.args``
        alone, which holds only the message; reconstruction would then fail because
        ``operation_id`` is a required constructor argument. Supplying both constructor
        arguments explicitly keeps ``copy.copy``/``copy.deepcopy``/``pickle`` working.

        Returns:
            A ``(callable, constructor_args, state)`` tuple as expected by the pickle protocol.
        """
        message = self.args[0] if self.args else ""
        # Pass the instance dict as state so any extra attributes are carried over as well.
        return (type(self), (message, self.operation_id), dict(vars(self)))

src/aws_durable_execution_sdk_python/state.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
BackgroundThreadError,
1717
CallableRuntimeError,
1818
DurableExecutionsError,
19+
OrphanedChildException,
1920
)
2021
from aws_durable_execution_sdk_python.lambda_service import (
2122
CheckpointOutput,
@@ -449,7 +450,13 @@ def create_checkpoint(
449450
"Rejecting checkpoint for operation %s - parent is done",
450451
operation_update.operation_id,
451452
)
452-
return
453+
error_msg = (
454+
"Parent context completed, child operation cannot checkpoint"
455+
)
456+
raise OrphanedChildException(
457+
error_msg,
458+
operation_id=operation_update.operation_id,
459+
)
453460

454461
# Check if background checkpointing has failed
455462
if self._checkpointing_failed.is_set():

tests/concurrency_test.py

Lines changed: 194 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
SuspendExecution,
3333
TimedSuspendExecution,
3434
)
35-
from aws_durable_execution_sdk_python.lambda_service import ErrorObject
35+
from aws_durable_execution_sdk_python.lambda_service import (
36+
ErrorObject,
37+
)
3638
from aws_durable_execution_sdk_python.operation.map import MapExecutor
3739

3840

@@ -2838,3 +2840,194 @@ def task_func(ctx, item, idx, items):
28382840
assert (
28392841
sum(1 for item in result.all if item.status == BatchItemStatus.SUCCEEDED) < 98
28402842
)
2843+
2844+
2845+
def test_executor_exits_early_with_min_successful():
    """Test that parallel exits immediately when min_successful is reached without waiting for other branches.

    Two branches are submitted with ``min_successful=1``: one returns at once and one
    sleeps for two seconds. The executor is expected to return well before the sleeping
    branch finishes and to report that branch as STARTED.
    """

    class TestExecutor(ConcurrentExecutor):
        def execute_item(self, child_context, executable):
            return executable.func()

    def fast_branch():
        return "fast_result"

    def slow_branch():
        time.sleep(2)  # Long sleep: waiting for it would trip the timing assertion below
        return "slow_result"

    executables = [
        Executable(0, fast_branch),
        Executable(1, slow_branch),
    ]

    completion_config = CompletionConfig(min_successful=1)

    executor = TestExecutor(
        executables=executables,
        max_concurrency=2,
        completion_config=completion_config,
        sub_type_top="TOP",
        sub_type_iteration="ITER",
        name_prefix="test_",
        serdes=None,
    )

    execution_state = Mock()
    execution_state.create_checkpoint = Mock()
    executor_context = Mock()
    executor_context._create_step_id_for_logical_step = lambda idx: f"step_{idx}"  # noqa: SLF001
    executor_context._parent_id = "parent"  # noqa: SLF001

    def create_child_context(op_id):
        child = Mock()
        child.state = execution_state
        return child

    executor_context.create_child_context = create_child_context

    # Use a monotonic clock: wall-clock time can jump (NTP, DST, manual changes) and
    # would make the elapsed-time assertion flaky.
    start_time = time.monotonic()
    result = executor.execute(execution_state, executor_context)
    elapsed_time = time.monotonic() - start_time

    # Should complete in less than 1.5 second (not wait for 2-second sleep)
    assert elapsed_time < 1.5, f"Took {elapsed_time}s, expected < 1.5s"

    # Result should show MIN_SUCCESSFUL_REACHED
    assert result.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

    # Fast branch should succeed
    assert result.all[0].status == BatchItemStatus.SUCCEEDED
    assert result.all[0].result == "fast_result"

    # Slow branch should be marked as STARTED (incomplete)
    assert result.all[1].status == BatchItemStatus.STARTED

    # Verify counts
    assert result.success_count == 1
    assert result.failure_count == 0
    assert result.started_count == 1
    assert result.total_count == 2
2916+
2917+
2918+
def test_executor_returns_with_incomplete_branches():
    """Check that the executor returns once min_successful is met, leaving the other branch unfinished.

    A shared Mock records which stages of each branch ran, so the test can assert
    that the slow branch never reached its completion marker before ``execute`` returned.
    """

    class TestExecutor(ConcurrentExecutor):
        def execute_item(self, child_context, executable):
            return executable.func()

    tracker = Mock()

    def fast_branch():
        tracker.fast_executed()
        return "fast_result"

    def slow_branch():
        tracker.slow_started()
        time.sleep(2)  # Long sleep
        tracker.slow_completed()
        return "slow_result"

    executor = TestExecutor(
        executables=[Executable(0, fast_branch), Executable(1, slow_branch)],
        max_concurrency=2,
        completion_config=CompletionConfig(min_successful=1),
        sub_type_top="TOP",
        sub_type_iteration="ITER",
        name_prefix="test_",
        serdes=None,
    )

    state = Mock()
    state.create_checkpoint = Mock()

    def make_step_id(idx):
        return f"step_{idx}"

    def make_child_context(op_id):
        return Mock(state=state)

    context = Mock()
    context._create_step_id_for_logical_step = make_step_id  # noqa: SLF001
    context._parent_id = "parent"  # noqa: SLF001
    context.create_child_context = make_child_context

    outcome = executor.execute(state, context)

    # The fast branch ran exactly once.
    assert tracker.fast_executed.call_count == 1

    # Thread scheduling decides whether the slow branch started at all,
    # but it must not have run to completion by now.
    assert (
        tracker.slow_completed.call_count == 0
    ), "Executor should return before slow branch completes"

    # Completion was triggered by reaching min_successful.
    assert outcome.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

    # One success, no failures, one branch still in flight.
    assert outcome.success_count == 1
    assert outcome.failure_count == 0
    assert outcome.started_count == 1
    assert outcome.total_count == 2
2980+
2981+
2982+
def test_executor_returns_before_slow_branch_completes():
    """Check that reaching min_successful makes the executor return without waiting on slow branches.

    The slow branch calls a Mock only after its sleep, so an uncalled Mock after
    ``execute`` returns shows the executor did not block on that branch.
    """

    class TestExecutor(ConcurrentExecutor):
        def execute_item(self, child_context, executable):
            return executable.func()

    slow_marker = Mock()

    def fast_func():
        return "fast"

    def slow_func():
        time.sleep(3)  # Sleep
        slow_marker.completed()  # Should not be called before executor returns
        return "slow"

    executor = TestExecutor(
        executables=[Executable(0, fast_func), Executable(1, slow_func)],
        max_concurrency=2,
        completion_config=CompletionConfig(min_successful=1),
        sub_type_top="TOP",
        sub_type_iteration="ITER",
        name_prefix="test_",
        serdes=None,
    )

    state = Mock()
    state.create_checkpoint = Mock()

    def make_step_id(idx):
        return f"step_{idx}"

    def make_child_context(op_id):
        return Mock(state=state)

    context = Mock()
    context._create_step_id_for_logical_step = make_step_id  # noqa: SLF001
    context._parent_id = "parent"  # noqa: SLF001
    context.create_child_context = make_child_context

    outcome = executor.execute(state, context)

    # The slow branch's completion marker must still be untouched.
    assert (
        not slow_marker.completed.called
    ), "Executor should return before slow branch completes"

    # Completion was triggered by reaching min_successful.
    assert outcome.completion_reason == CompletionReason.MIN_SUCCESSFUL_REACHED

    # One success, no failures, one branch still in flight.
    assert outcome.success_count == 1
    assert outcome.failure_count == 0
    assert outcome.started_count == 1
    assert outcome.total_count == 2

0 commit comments

Comments
 (0)