Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions src/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,16 +281,38 @@ def _format_job_name(
show_deployment_num: bool,
show_replica: bool,
show_job: bool,
group_index: Optional[int] = None,
last_shown_group_index: Optional[int] = None,
) -> str:
name_parts = []
prefix = ""
if show_replica:
name_parts.append(f"replica={job.job_spec.replica_num}")
# Show group information if replica groups are used
if group_index is not None:
# Show group=X replica=Y when group changes, or just replica=Y when same group
if group_index != last_shown_group_index:
# First job in group: use 3 spaces indent
prefix = " "
name_parts.append(f"group={group_index} replica={job.job_spec.replica_num}")
else:
# Subsequent job in same group: align "replica=" with first job's "replica="
# Calculate padding: width of " group={last_shown_group_index} "
padding_width = 3 + len(f"group={last_shown_group_index}") + 1
prefix = " " * padding_width
name_parts.append(f"replica={job.job_spec.replica_num}")
else:
# Legacy behavior: no replica groups
prefix = " "
name_parts.append(f"replica={job.job_spec.replica_num}")
else:
prefix = " "

if show_job:
name_parts.append(f"job={job.job_spec.job_num}")
name_suffix = (
f" deployment={latest_job_submission.deployment_num}" if show_deployment_num else ""
)
name_value = " " + (" ".join(name_parts) if name_parts else "")
name_value = prefix + (" ".join(name_parts) if name_parts else "")
name_value += name_suffix
return name_value

Expand Down Expand Up @@ -359,6 +381,14 @@ def get_runs_table(
)
merge_job_rows = len(run.jobs) == 1 and not show_deployment_num

# Replica Group Changes: Build mapping from replica group names to indices
group_name_to_index: Dict[str, int] = {}
# Replica Group Changes: Check if replica_groups attribute exists (only available for ServiceConfiguration)
replica_groups = getattr(run.run_spec.configuration, "replica_groups", None)
if replica_groups:
for idx, group in enumerate(replica_groups):
group_name_to_index[group.name] = idx

run_row: Dict[Union[str, int], Any] = {
"NAME": _format_run_name(run, show_deployment_num),
"SUBMITTED": format_date(run.submitted_at),
Expand All @@ -372,13 +402,35 @@ def get_runs_table(
if not merge_job_rows:
add_row_from_dict(table, run_row)

for job in run.jobs:
# Sort jobs by group index first, then by replica_num within each group
def get_job_sort_key(job: Job) -> tuple:
    """Return the `(group position, replica number)` key used to order jobs.

    The group position is looked up in the enclosing `group_name_to_index`
    mapping by the job's `replica_group` name.
    """
    resolved = (
        group_name_to_index.get(job.job_spec.replica_group)
        if group_name_to_index and job.job_spec.replica_group
        else None
    )
    # Jobs whose group cannot be resolved get a large sentinel so they sort last.
    rank = 999999 if resolved is None else resolved
    return (rank, job.job_spec.replica_num)

sorted_jobs = sorted(run.jobs, key=get_job_sort_key)

last_shown_group_index: Optional[int] = None
for job in sorted_jobs:
latest_job_submission = job.job_submissions[-1]
status_formatted = _format_job_submission_status(latest_job_submission, verbose)

# Get group index for this job
group_index: Optional[int] = None
if group_name_to_index and job.job_spec.replica_group:
group_index = group_name_to_index.get(job.job_spec.replica_group)

job_row: Dict[Union[str, int], Any] = {
"NAME": _format_job_name(
job, latest_job_submission, show_deployment_num, show_replica, show_job
job,
latest_job_submission,
show_deployment_num,
show_replica,
show_job,
group_index=group_index,
last_shown_group_index=last_shown_group_index,
),
"STATUS": status_formatted,
"PROBES": _format_job_probes(
Expand All @@ -390,6 +442,9 @@ def get_runs_table(
"GPU": "-",
"PRICE": "-",
}
# Update last shown group index for next iteration
if group_index is not None:
last_shown_group_index = group_index
jpd = latest_job_submission.job_provisioning_data
if jpd is not None:
shared_offer: Optional[InstanceOfferWithAvailability] = None
Expand Down
157 changes: 157 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,11 @@ class ConfigurationWithCommandsParams(CoreModel):

@root_validator
def check_image_or_commands_present(cls, values):
    """Ensure the configuration has something to run.

    Validation is skipped when `replica_groups` is populated, since each
    group then carries its own commands. Otherwise at least one of
    `commands` or `image` must be truthy.

    Raises:
        ValueError: if neither `commands` nor `image` is set.
    """
    if values.get("replica_groups"):
        return values

    has_commands = bool(values.get("commands"))
    if not has_commands and not values.get("image"):
        raise ValueError("Either `commands` or `image` must be set")
    return values
Expand Down Expand Up @@ -714,6 +719,85 @@ def schema_extra(schema: Dict[str, Any]):
)


class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel):
    # A named group of service replicas that share commands, resources,
    # scaling rules, probes and rate limits.

    # Identifier of the group.
    name: Annotated[
        str,
        Field(description="The name of the replica group"),
    ]
    # Fixed replica count or a min..max range; normalized by `convert_replicas`.
    replicas: Annotated[
        Range[int],
        Field(
            description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
            "If it's a range, the `scaling` property is required"
        ),
    ]
    # Must be present exactly when `replicas` is a range (see `validate_scaling`).
    scaling: Annotated[
        Optional[ScalingSpec],
        Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
    ] = None
    probes: Annotated[
        list[ProbeConfig],
        Field(description="List of probes used to determine job health for this replica group"),
    ] = []
    rate_limits: Annotated[
        list[RateLimit],
        Field(description="Rate limiting rules for this replica group"),
    ] = []
    # TODO: Extract to ConfigurationWithResourcesParams mixin
    resources: Annotated[
        ResourcesSpec,
        Field(description="The resources requirements for replicas in this group"),
    ] = ResourcesSpec()

    @validator("replicas")
    def convert_replicas(cls, v: Range[int]) -> Range[int]:
        """Normalize the replica range: require a maximum, default the minimum to 0.

        Raises:
            ValueError: if the maximum is missing or the minimum is negative.
        """
        lower, upper = v.min, v.max
        if upper is None:
            raise ValueError("The maximum number of replicas is required")
        if lower is None:
            # An open lower bound means "may scale down to zero".
            v.min = 0
        elif lower < 0:
            raise ValueError("The minimum number of replicas must be greater than or equal to 0")
        return v

    @root_validator()
    def override_commands_validation(cls, values):
        """Require a non-empty `commands` list for every replica group.

        Raises:
            ValueError: if `commands` is missing or empty.
        """
        # NOTE(review): this validator has a different name than the parent's
        # `check_image_or_commands_present`, so presumably both run rather than
        # this one replacing the parent's check - confirm.
        if not values.get("commands", []):
            raise ValueError("`commands` must be set for replica groups")
        return values

    @root_validator()
    def validate_scaling(cls, values):
        """Check that `scaling` is given if and only if `replicas` is a real range.

        Raises:
            ValueError: if a range has no `scaling`, or a fixed count has `scaling`.
        """
        replicas = values.get("replicas")
        if not replicas:
            # Nothing to cross-check when `replicas` is absent.
            return values

        scaling = values.get("scaling")
        if replicas.min != replicas.max:
            if not scaling:
                raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
        elif scaling:
            raise ValueError("To use `scaling`, `replicas` must be set to a range.")
        return values

    @validator("rate_limits")
    def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
        """Reject rate limits that reuse the same path prefix.

        Raises:
            ValueError: listing every prefix that occurs more than once.
        """
        prefix_counts = Counter(limit.prefix for limit in v)
        duplicates = [prefix for prefix, seen in prefix_counts.items() if seen > 1]
        if not duplicates:
            return v
        raise ValueError(
            f"Prefixes {duplicates} are used more than once."
            " Each rate limit should have a unique path prefix"
        )

    @validator("probes")
    def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
        """Reject probe lists for which `has_duplicates` reports a repeat.

        Raises:
            ValueError: if the probes are not unique.
        """
        if not has_duplicates(v):
            return v
        raise ValueError("Probes must be unique")


class ServiceConfigurationParams(CoreModel):
port: Annotated[
# NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used.
Expand Down Expand Up @@ -771,6 +855,19 @@ class ServiceConfigurationParams(CoreModel):
Field(description="List of probes used to determine job health"),
] = []

replica_groups: Annotated[
Optional[List[ReplicaGroup]],
Field(
description=(
"List of replica groups. Each group defines replicas with shared configuration "
"(commands, port, resources, scaling, probes, rate_limits). "
"When specified, the top-level `replicas`, `commands`, `port`, `resources`, "
"`scaling`, `probes`, and `rate_limits` are ignored. "
"Each replica group must have a unique name."
)
),
] = None

@validator("port")
def convert_port(cls, v) -> PortMapping:
if isinstance(v, int):
Expand Down Expand Up @@ -807,6 +904,12 @@ def validate_gateway(

@root_validator()
def validate_scaling(cls, values):
replica_groups = values.get("replica_groups")
# If replica_groups are set, we don't need to validate scaling.
# Each replica group has its own scaling.
if replica_groups:
return values

scaling = values.get("scaling")
replicas = values.get("replicas")
if replicas and replicas.min != replicas.max and not scaling:
Expand All @@ -815,6 +918,42 @@ def validate_scaling(cls, values):
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
return values

@root_validator()
def normalize_to_replica_groups(cls, values):
    """Wrap a legacy top-level configuration in a single `default` replica group.

    When `replica_groups` is already populated the values are returned
    untouched. Otherwise one `ReplicaGroup` named "default" is built from the
    top-level `replicas`, `commands`, `resources`, `scaling`, `probes` and
    `rate_limits` values and stored under `replica_groups`.

    Returns:
        The (possibly updated) `values` dict.
    """
    if values.get("replica_groups"):
        return values

    # NOTE(review): `ReplicaGroup` rejects an empty `commands` list, so an
    # image-only service (no `commands`) presumably fails here - confirm that
    # this is intended for legacy configurations.
    values["replica_groups"] = [
        ReplicaGroup(
            name="default",
            replicas=values.get("replicas"),
            commands=values.get("commands"),
            resources=values.get("resources"),
            scaling=values.get("scaling"),
            probes=values.get("probes"),
            rate_limits=values.get("rate_limits"),
        )
    ]
    return values

@validator("rate_limits")
def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
counts = Counter(limit.prefix for limit in v)
Expand All @@ -836,6 +975,24 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
raise ValueError("Probes must be unique")
return v

@validator("replica_groups")
def validate_replica_groups(
    cls, v: Optional[List[ReplicaGroup]]
) -> Optional[List[ReplicaGroup]]:
    """Validate that `replica_groups`, when given, is non-empty with unique names.

    Args:
        v: The parsed replica groups, or `None` when the field is not set.

    Returns:
        `v` unchanged when it passes validation.

    Raises:
        ValueError: if the list is empty or two groups share a name.
    """
    if v is None:
        return v
    if not v:
        raise ValueError("`replica_groups` cannot be an empty list")

    # Counter keeps first-seen order, so the reported duplicates are
    # deterministic and are found in a single pass over the names.
    name_counts = Counter(group.name for group in v)
    duplicates = [name for name, count in name_counts.items() if count > 1]
    if duplicates:
        raise ValueError(
            f"Duplicate replica group names found: {duplicates}. "
            "Each replica group must have a unique name."
        )
    return v


class ServiceConfigurationConfig(
ProfileParamsConfig,
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class JobSpec(CoreModel):
job_num: int
job_name: str
jobs_per_replica: int = 1 # default value for backward compatibility
replica_group: Optional[str] = "default"
app_specs: Optional[List[AppSpec]]
user: Optional[UnixUser] = None # default value for backward compatibility
commands: List[str]
Expand Down
Loading