Skip to content

Commit dced541

Browse files
committed
Add tests.
1 parent 6d55882 commit dced541

File tree

1 file changed

+197
-0
lines changed

1 file changed

+197
-0
lines changed
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
from unittest import TestCase, main, mock, skipUnless
2+
3+
from oci import Response
4+
5+
from ads.jobs import ContainerRuntime, DataScienceJob, Job, PyTorchDistributedRuntime
6+
from ads.jobs.builders.infrastructure.dsc_job_runtime import MULTI_NODE_JOB_SUPPORT
7+
8+
test_cases = {"torchrun": "torchrun test_torch_distributed.py"}

# Placeholder OCI resource identifiers; the OCI client calls are mocked in the
# tests, so these only need to round-trip through the request payloads.
LOG_GROUP_ID = "ocid1.loggroup.oc1.iad.aaa"
LOG_ID = "ocid1.log.oc1.iad.aaa"
SUBNET_ID = "ocid1.subnet.oc1.iad.aaa"

# Compute shape and service conda environment slug used to configure the jobs.
SHAPE_NAME = "VM.GPU.A10.2"
CONDA_NAME = "pytorch24_p310_gpu_x86_64_v1"

# Environment variables expected in the node group configuration
# when the job uses the conda-based PyTorch distributed runtime.
CONDA_ENV_VARS = {
    "CONDA_ENV_SLUG": CONDA_NAME,
    "CONDA_ENV_TYPE": "service",
    "JOB_RUN_ENTRYPOINT": "driver_pytorch.py",
    "NODE_COUNT": "2",
    "OCI_LOG_LEVEL": "DEBUG",
    "OCI__LAUNCH_CMD": "torchrun artifact.py",
}

# Environment variables expected in the node group configuration
# when the job uses the container runtime.
CONTAINER_ENV_VARS = {
    "NODE_COUNT": "2",
    "OCI_LOG_LEVEL": "DEBUG",
}
30+
31+
32+
@skipUnless(
    MULTI_NODE_JOB_SUPPORT,
    "Multi-Node Job is not supported by the OCI Python SDK installed.",
)
class MultiNodeJobTest(TestCase):
    """Checks the request payloads built for multi-node jobs.

    The ``DataScienceClient.create_job`` and ``create_job_run`` calls are
    patched, so each test inspects only the details object passed to the
    mocked client method.
    """

    # Values configured on the job and expected back in the create-job payload.
    BLOCK_STORAGE_SIZE = 256
    REPLICAS = 2

    @staticmethod
    def _ok_response():
        """Return an empty status-200 ``Response`` for a mocked client call."""
        return Response(status=200, headers=None, request=None, data=None)

    @staticmethod
    def _first_node_group(create_job_details):
        """Return the first node group configuration of a create-job payload."""
        node_config = create_job_details.job_node_configuration_details
        return node_config.job_node_group_configuration_details_list[0]

    def init_job_infra(self):
        """Build the ``DataScienceJob`` infrastructure shared by the tests.

        No subnet is set here; tests that need one add it afterwards.
        """
        return (
            DataScienceJob()
            .with_compartment_id("ocid1.compartment.oc1..aaa")
            .with_project_id("ocid1.datascienceproject.oc1.iad.aaa")
            .with_log_group_id(LOG_GROUP_ID)
            .with_log_id(LOG_ID)
            .with_shape_name(SHAPE_NAME)
            .with_block_storage_size(self.BLOCK_STORAGE_SIZE)
        )

    def assert_create_job_details(self, create_job_details, envs):
        """Assert that a create-job payload describes the expected multi-node job.

        Parameters
        ----------
        create_job_details
            The details object passed to the mocked ``create_job`` call.
        envs : dict
            Environment variables expected on the node group's job configuration.
        """
        # Check log config
        log_config = create_job_details.job_log_configuration_details
        self.assertEqual(log_config.log_id, LOG_ID)
        self.assertEqual(log_config.log_group_id, LOG_GROUP_ID)

        # Check top level configs: these are expected to be unset, with the
        # configuration carried by the node configuration details instead.
        self.assertIsNone(create_job_details.job_configuration_details)
        self.assertIsNone(create_job_details.job_environment_configuration_details)
        self.assertIsNone(create_job_details.job_infrastructure_configuration_details)

        job_node_configuration_details = (
            create_job_details.job_node_configuration_details
        )
        self.assertIsNotNone(job_node_configuration_details)

        # Check network config
        self.assertEqual(
            job_node_configuration_details.job_network_configuration.job_network_type,
            "DEFAULT_NETWORK",
        )

        # Check node group config
        node_groups = (
            job_node_configuration_details.job_node_group_configuration_details_list
        )
        self.assertEqual(len(node_groups), 1)
        node_group_config = node_groups[0]
        self.assertEqual(
            node_group_config.job_configuration_details.environment_variables,
            envs,
        )
        self.assertEqual(node_group_config.replicas, self.REPLICAS)

        # Check infra config against the values set in init_job_infra().
        infra_config = node_group_config.job_infrastructure_configuration_details
        self.assertEqual(infra_config.shape_name, SHAPE_NAME)
        self.assertEqual(
            infra_config.block_storage_size_in_gbs, self.BLOCK_STORAGE_SIZE
        )
        self.assertEqual(infra_config.job_infrastructure_type, "MULTI_NODE")

    def assert_create_job_run_details(self, create_job_run_details):
        """Assert that a create-job-run payload carries no override details."""
        self.assertIsNone(create_job_run_details.job_configuration_override_details)
        self.assertIsNone(
            create_job_run_details.job_infrastructure_configuration_override_details
        )
        self.assertIsNone(create_job_run_details.job_log_configuration_override_details)
        self.assertIsNone(
            create_job_run_details.job_node_configuration_override_details
        )

    @mock.patch(
        "ads.jobs.builders.runtimes.pytorch_runtime.PyTorchDistributedArtifact.build"
    )
    @mock.patch("ads.jobs.builders.infrastructure.dsc_job.DSCJob.upload_artifact")
    @mock.patch("oci.data_science.DataScienceClient.create_job_run")
    @mock.patch("oci.data_science.DataScienceClient.create_job")
    def test_create_multi_node_job_with_conda(self, patched_create, patched_run, *args):
        """Create and run a PyTorch distributed job using a service conda env.

        Artifact build and upload are patched out in addition to the client calls.
        """
        patched_create.return_value = self._ok_response()

        infra = self.init_job_infra()
        runtime = (
            PyTorchDistributedRuntime()
            # Specify the service conda environment by slug name.
            .with_service_conda(CONDA_NAME)
            .with_command("torchrun artifact.py")
            .with_environment_variable(OCI_LOG_LEVEL="DEBUG")
            .with_replica(self.REPLICAS)
        )
        job = Job(name="DT Test").with_infrastructure(infra).with_runtime(runtime)
        job.create()
        create_job_details = patched_create.call_args.args[0]

        self.assert_create_job_details(
            create_job_details=create_job_details,
            envs=CONDA_ENV_VARS,
        )
        # A conda runtime should not produce a container environment config.
        node_group_config = self._first_node_group(create_job_details)
        self.assertIsNone(node_group_config.job_environment_configuration_details)

        # Create Job with subnet_id
        patched_create.reset_mock()
        infra.with_subnet_id(SUBNET_ID)
        job = Job(name="DT Test").with_infrastructure(infra).with_runtime(runtime)
        job.create()
        create_job_details = patched_create.call_args.args[0]
        job_node_configuration_details = (
            create_job_details.job_node_configuration_details
        )
        self.assertEqual(
            job_node_configuration_details.job_network_configuration.subnet_id,
            SUBNET_ID,
        )
        patched_run.return_value = self._ok_response()

        # Check the payload for creating a job run
        job.run()
        create_job_run_details = patched_run.call_args.args[0]
        self.assert_create_job_run_details(create_job_run_details)

    @mock.patch("oci.data_science.DataScienceClient.create_job_run")
    @mock.patch("oci.data_science.DataScienceClient.create_job")
    def test_create_multi_node_job_with_container(
        self, patched_create, patched_run, *args
    ):
        """Create and run a multi-node job using a container runtime."""
        patched_create.return_value = self._ok_response()

        infra = self.init_job_infra()
        runtime = (
            ContainerRuntime()
            # Specify the container image to run.
            .with_image("container_image")
            .with_environment_variable(OCI_LOG_LEVEL="DEBUG")
            .with_replica(self.REPLICAS)
        )
        job = Job(name="DT Test").with_infrastructure(infra).with_runtime(runtime)
        job.create()
        create_job_details = patched_create.call_args.args[0]
        self.assert_create_job_details(
            create_job_details=create_job_details,
            envs=CONTAINER_ENV_VARS,
        )
        # The container image is expected on the node group's environment config.
        node_group_config = self._first_node_group(create_job_details)
        container_config = node_group_config.job_environment_configuration_details
        self.assertEqual(container_config.job_environment_type, "OCIR_CONTAINER")
        self.assertEqual(container_config.image, "container_image")

        patched_run.return_value = self._ok_response()

        # Check the payload for creating a job run
        job.run()
        create_job_run_details = patched_run.call_args.args[0]
        self.assert_create_job_run_details(create_job_run_details)
194+
195+
196+
# Run the unittest command-line runner when this file is executed as a script.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)