Skip to content

Commit 2933d9a

Browse files
author
Jaya Kasiraj
committed
feat: add support for mlflow to pipelines
1 parent 8695cca commit 2933d9a

File tree

5 files changed

+492
-10
lines changed

5 files changed

+492
-10
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""MLflow config for SageMaker pipeline."""
14+
from __future__ import absolute_import
15+
16+
from typing import Dict, Any
17+
18+
19+
class MlflowConfig:
    """MLflow configuration for SageMaker pipeline.

    Attributes:
        mlflow_resource_arn: The ARN of the MLflow tracking server resource,
            stored exactly as passed to the constructor.
        mlflow_experiment_name: The name of the MLflow experiment to be used
            for tracking, stored exactly as passed to the constructor.
    """

    def __init__(
        self,
        mlflow_resource_arn: str,
        mlflow_experiment_name: str,
    ):
        """Create an MLflow configuration for SageMaker Pipeline.

        Examples:
            Basic MLflow configuration::

                mlflow_config = MlflowConfig(
                    mlflow_resource_arn="arn:aws:sagemaker:us-west-2:123456789012:mlflow-tracking-server/my-server",
                    mlflow_experiment_name="my-experiment"
                )

                pipeline = Pipeline(
                    name="MyPipeline",
                    steps=[...],
                    mlflow_config=mlflow_config
                )

            Runtime override of experiment name::

                # Override experiment name for a specific execution
                execution = pipeline.start(mlflow_experiment_name="custom-experiment")

        Args:
            mlflow_resource_arn (str): The ARN of the MLflow tracking server resource.
            mlflow_experiment_name (str): The name of the MLflow experiment to be used for tracking.
        """
        # Values are stored as-is; no validation or normalization is applied here.
        self.mlflow_resource_arn = mlflow_resource_arn
        self.mlflow_experiment_name = mlflow_experiment_name

    def to_request(self) -> Dict[str, Any]:
        """Return the request structure for this MLflow configuration.

        Returns:
            Dict[str, Any]: A dictionary with the keys ``MlflowResourceArn`` and
            ``MlflowExperimentName`` mapped to the corresponding attribute values.
        """
        return {
            "MlflowResourceArn": self.mlflow_resource_arn,
            "MlflowExperimentName": self.mlflow_experiment_name,
        }

sagemaker-mlops/src/sagemaker/mlops/workflow/pipeline.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@
3434
from sagemaker.core.remote_function.job import JOBS_CONTAINER_ENTRYPOINT
3535
from sagemaker.core.s3 import s3_path_join
3636
from sagemaker.core.helper.session_helper import Session
37-
from sagemaker.core.common_utils import resolve_value_from_config, retry_with_backoff, format_tags, Tags
37+
from sagemaker.core.common_utils import (
38+
resolve_value_from_config,
39+
retry_with_backoff,
40+
format_tags,
41+
Tags,
42+
)
43+
3844
# Orchestration imports (now in mlops)
3945
from sagemaker.mlops.workflow.callback_step import CallbackOutput, CallbackStep
4046
from sagemaker.mlops.workflow._event_bridge_client_helper import (
@@ -44,19 +50,24 @@
4450
EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT,
4551
)
4652
from sagemaker.mlops.workflow.lambda_step import LambdaOutput, LambdaStep
53+
from sagemaker.mlops.workflow.mlflow_config import MlflowConfig
4754
from sagemaker.core.helper.pipeline_variable import (
4855
RequestType,
4956
PipelineVariable,
5057
)
58+
5159
# Primitive imports (stay in core)
5260
from sagemaker.core.workflow.execution_variables import ExecutionVariables
5361
from sagemaker.core.workflow.parameters import Parameter
62+
5463
# Orchestration imports (now in mlops)
5564
from sagemaker.core.workflow.pipeline_definition_config import PipelineDefinitionConfig
5665
from sagemaker.mlops.workflow.pipeline_experiment_config import PipelineExperimentConfig
5766
from sagemaker.mlops.workflow.parallelism_config import ParallelismConfiguration
67+
5868
# Primitive imports (stay in core)
5969
from sagemaker.core.workflow.properties import Properties
70+
6071
# Orchestration imports (now in mlops)
6172
from sagemaker.mlops.workflow.selective_execution_config import SelectiveExecutionConfig
6273
from sagemaker.core.workflow.step_outputs import StepOutput
@@ -87,6 +98,7 @@ def __init__(
8798
name: str = "",
8899
parameters: Optional[Sequence[Parameter]] = None,
89100
pipeline_experiment_config: Optional[PipelineExperimentConfig] = _DEFAULT_EXPERIMENT_CFG,
101+
mlflow_config: Optional[MlflowConfig] = None,
90102
steps: Optional[Sequence[Union[Step, StepOutput]]] = None,
91103
sagemaker_session: Optional[Session] = None,
92104
pipeline_definition_config: Optional[PipelineDefinitionConfig] = _DEFAULT_DEFINITION_CFG,
@@ -102,6 +114,8 @@ def __init__(
102114
the same name already exists. By default, pipeline name is used as
103115
experiment name and execution id is used as the trial name.
104116
If set to None, no experiment or trial will be created automatically.
117+
mlflow_config (Optional[MlflowConfig]): If set, the pipeline will be configured
118+
with MLflow tracking for experiment tracking and model versioning.
105119
steps (Sequence[Union[Step, StepOutput]]): The list of the
106120
non-conditional steps associated with the pipeline. Any steps that are within the
107121
`if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a
@@ -118,6 +132,7 @@ def __init__(
118132
self.name = name
119133
self.parameters = parameters if parameters else []
120134
self.pipeline_experiment_config = pipeline_experiment_config
135+
self.mlflow_config = mlflow_config
121136
self.steps = steps if steps else []
122137
self.sagemaker_session = sagemaker_session if sagemaker_session else Session()
123138
self.pipeline_definition_config = pipeline_definition_config
@@ -337,6 +352,7 @@ def start(
337352
execution_description: str = None,
338353
parallelism_config: ParallelismConfiguration = None,
339354
selective_execution_config: SelectiveExecutionConfig = None,
355+
mlflow_experiment_name: str = None,
340356
):
341357
"""Starts a Pipeline execution in the Workflow service.
342358
@@ -350,6 +366,10 @@ def start(
350366
over the parallelism configuration of the parent pipeline.
351367
selective_execution_config (Optional[SelectiveExecutionConfig]): The configuration for
352368
selective step execution.
369+
mlflow_experiment_name (str): Optional MLflow experiment name to override
370+
the experiment name specified in the pipeline's mlflow_config.
371+
If provided, this will override the experiment name for this specific
372+
pipeline execution only, without modifying the pipeline definition.
353373
354374
Returns:
355375
A `_PipelineExecution` instance, if successful.
@@ -371,6 +391,7 @@ def start(
371391
PipelineExecutionDisplayName=execution_display_name,
372392
ParallelismConfiguration=parallelism_config,
373393
SelectiveExecutionConfig=selective_execution_config,
394+
MlflowExperimentName=mlflow_experiment_name,
374395
)
375396
if self.sagemaker_session.local_mode:
376397
update_args(kwargs, PipelineParameters=parameters)
@@ -409,14 +430,25 @@ def definition(self) -> str:
409430
if self.pipeline_experiment_config is not None
410431
else None
411432
),
433+
"MlflowConfig": (
434+
self.mlflow_config.to_request() if self.mlflow_config is not None else None
435+
),
412436
"Steps": list_to_request(compiled_steps),
413437
}
414-
415-
request_dict["PipelineExperimentConfig"] = interpolate(
416-
request_dict["PipelineExperimentConfig"], {}, {}, pipeline_name=self.name
417-
)
418438
callback_output_to_step_map = _map_callback_outputs(self.steps)
419439
lambda_output_to_step_name = _map_lambda_outputs(self.steps)
440+
request_dict["PipelineExperimentConfig"] = interpolate(
441+
request_dict["PipelineExperimentConfig"],
442+
callback_output_to_step_map=callback_output_to_step_map,
443+
lambda_output_to_step_map=lambda_output_to_step_name,
444+
pipeline_name=self.name,
445+
)
446+
request_dict["MlflowConfig"] = interpolate(
447+
request_dict["MlflowConfig"],
448+
callback_output_to_step_map=callback_output_to_step_map,
449+
lambda_output_to_step_map=lambda_output_to_step_name,
450+
pipeline_name=self.name,
451+
)
420452
request_dict["Steps"] = interpolate(
421453
request_dict["Steps"],
422454
callback_output_to_step_map=callback_output_to_step_map,
@@ -1081,7 +1113,6 @@ def _initialize_adjacency_list(self) -> Dict[str, List[str]]:
10811113
if isinstance(child_step, Step):
10821114
dependency_list[child_step.name].add(step.name)
10831115

1084-
10851116
adjacency_list = {}
10861117
for step in dependency_list:
10871118
for step_dependency in dependency_list[step]:
@@ -1119,9 +1150,7 @@ def is_cyclic_helper(current_step):
11191150
return True
11201151
return False
11211152

1122-
def get_steps_in_sub_dag(
1123-
self, current_step: Step, sub_dag_steps: Set[str] = None
1124-
) -> Set[str]:
1153+
def get_steps_in_sub_dag(self, current_step: Step, sub_dag_steps: Set[str] = None) -> Set[str]:
11251154
"""Get names of all steps (including current step) in the sub dag of current step.
11261155
11271156
Returns a set of step names in the sub dag.
@@ -1161,4 +1190,4 @@ def __next__(self) -> Step:
11611190

11621191
while self.stack:
11631192
return self.step_map.get(self.stack.pop())
1164-
raise StopIteration
1193+
raise StopIteration

0 commit comments

Comments
 (0)