Skip to content

Commit 8c9ab9b

Browse files
pintaoz-awspintaoz
andauthored
Add validation in conda env name (#5430)
Co-authored-by: pintaoz <pintaoz@amazon.com>
1 parent 87c525f commit 8c9ab9b

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

sagemaker-core/src/sagemaker/core/remote_function/client.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,9 @@ def remote(
303303
"""
304304

305305
def _remote(func):
306+
307+
if job_conda_env:
308+
RemoteExecutor._validate_env_name(job_conda_env)
306309

307310
job_settings = _JobSettings(
308311
dependencies=dependencies,
@@ -774,6 +777,9 @@ def __init__(
774777
+ "without spark_config or use_torchrun or use_mpirun. "
775778
+ "Please provide instance_count = 1"
776779
)
780+
781+
if job_conda_env:
782+
self._validate_env_name(job_conda_env)
777783

778784
self.job_settings = _JobSettings(
779785
dependencies=dependencies,
@@ -951,6 +957,25 @@ def _validate_submit_args(func, *args, **kwargs):
951957
+ f"{'arguments' if len(missing_kwargs) > 1 else 'argument'}: "
952958
+ f"{missing_kwargs_string}"
953959
)
960+
961+
@staticmethod
962+
def _validate_env_name(env_name: str) -> None:
963+
"""Validate conda environment name to prevent command injection.
964+
965+
Args:
966+
env_name (str): The environment name to validate
967+
968+
Raises:
969+
ValueError: If the environment name contains invalid characters
970+
"""
971+
972+
# Allow only alphanumeric, underscore, and hyphen
973+
import re
974+
if not re.match(r'^[a-zA-Z0-9_-]+$', env_name):
975+
raise ValueError(
976+
f"Invalid environment name '{env_name}'. "
977+
"Only alphanumeric characters, underscores, and hyphens are allowed."
978+
)
954979

955980

956981
class Future(object):

sagemaker-core/tests/unit/remote_function/test_client.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,28 @@ def my_function(x):
6464

6565
with pytest.raises(TypeError):
6666
RemoteExecutor._validate_submit_args(my_function, 1, 2)
67-
67+
68+
def test_validate_env_names_valid(self):
69+
"""Test valid conda environment names"""
70+
valid_names = [
71+
"myenv",
72+
"base",
73+
"py39",
74+
"env123",
75+
]
76+
for name in valid_names:
77+
RemoteExecutor._validate_env_name(name)
78+
79+
def test_validate_env_names_invalid(self):
80+
"""Test invalid conda environment names"""
81+
invalid_names = [
82+
"env && echo PWNED",
83+
"env > /tmp/output.txt",
84+
"sagemaker-rce-env; echo PWNED_FROM_CONDA_ENV > /tmp/conda_rce.txt #",
85+
]
86+
for name in invalid_names:
87+
with pytest.raises(ValueError):
88+
RemoteExecutor._validate_env_name(name)
6889

6990
class TestWorkerFunctions:
7091
"""Test worker thread functions"""

0 commit comments

Comments
 (0)