diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 68dc828f7..48d121bb9 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -281,16 +281,38 @@ def _format_job_name( show_deployment_num: bool, show_replica: bool, show_job: bool, + group_index: Optional[int] = None, + last_shown_group_index: Optional[int] = None, ) -> str: name_parts = [] + prefix = "" if show_replica: - name_parts.append(f"replica={job.job_spec.replica_num}") + # Show group information if replica groups are used + if group_index is not None: + # Show group=X replica=Y when group changes, or just replica=Y when same group + if group_index != last_shown_group_index: + # First job in group: use 3 spaces indent + prefix = " " + name_parts.append(f"group={group_index} replica={job.job_spec.replica_num}") + else: + # Subsequent job in same group: align "replica=" with first job's "replica=" + # Calculate padding: width of " group={last_shown_group_index} " + padding_width = 3 + len(f"group={last_shown_group_index}") + 1 + prefix = " " * padding_width + name_parts.append(f"replica={job.job_spec.replica_num}") + else: + # Legacy behavior: no replica groups + prefix = " " + name_parts.append(f"replica={job.job_spec.replica_num}") + else: + prefix = " " + if show_job: name_parts.append(f"job={job.job_spec.job_num}") name_suffix = ( f" deployment={latest_job_submission.deployment_num}" if show_deployment_num else "" ) - name_value = " " + (" ".join(name_parts) if name_parts else "") + name_value = prefix + (" ".join(name_parts) if name_parts else "") name_value += name_suffix return name_value @@ -359,6 +381,14 @@ def get_runs_table( ) merge_job_rows = len(run.jobs) == 1 and not show_deployment_num + # Replica Group Changes: Build mapping from replica group names to indices + group_name_to_index: Dict[str, int] = {} + # Replica Group Changes: Check if replica_groups attribute exists (only available for ServiceConfiguration) + replica_groups = getattr(run.run_spec.configuration, "replica_groups", None) + if replica_groups: + for idx, group in enumerate(replica_groups): + group_name_to_index[group.name] = idx + run_row: Dict[Union[str, int], Any] = { "NAME": _format_run_name(run, show_deployment_num), "SUBMITTED": format_date(run.submitted_at), @@ -372,13 +402,35 @@ def get_runs_table( if not merge_job_rows: add_row_from_dict(table, run_row) - for job in run.jobs: + # Sort jobs by group index first, then by replica_num within each group + def get_job_sort_key(job: Job) -> tuple: + group_index = None + if group_name_to_index and job.job_spec.replica_group: + group_index = group_name_to_index.get(job.job_spec.replica_group) + # Use a large number for jobs without groups to put them at the end + return (group_index if group_index is not None else 999999, job.job_spec.replica_num) + + sorted_jobs = sorted(run.jobs, key=get_job_sort_key) + + last_shown_group_index: Optional[int] = None + for job in sorted_jobs: latest_job_submission = job.job_submissions[-1] status_formatted = _format_job_submission_status(latest_job_submission, verbose) + # Get group index for this job + group_index: Optional[int] = None + if group_name_to_index and job.job_spec.replica_group: + group_index = group_name_to_index.get(job.job_spec.replica_group) + job_row: Dict[Union[str, int], Any] = { "NAME": _format_job_name( - job, latest_job_submission, show_deployment_num, show_replica, show_job + job, + latest_job_submission, + show_deployment_num, + show_replica, + show_job, + 
group_index=group_index, + last_shown_group_index=last_shown_group_index, ), "STATUS": status_formatted, "PROBES": _format_job_probes( @@ -390,6 +442,9 @@ def get_runs_table( "GPU": "-", "PRICE": "-", } + # Update last shown group index for next iteration + if group_index is not None: + last_shown_group_index = group_index jpd = latest_job_submission.job_provisioning_data if jpd is not None: shared_offer: Optional[InstanceOfferWithAvailability] = None diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 158c59b34..93dd9909d 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -612,6 +612,11 @@ class ConfigurationWithCommandsParams(CoreModel): @root_validator def check_image_or_commands_present(cls, values): + # If replica_groups is present, skip validation - commands come from replica groups + replica_groups = values.get("replica_groups") + if replica_groups: + return values + if not values.get("commands") and not values.get("image"): raise ValueError("Either `commands` or `image` must be set") return values @@ -714,6 +719,85 @@ def schema_extra(schema: Dict[str, Any]): ) +class ReplicaGroup(ConfigurationWithCommandsParams, CoreModel): + name: Annotated[ + str, + Field(description="The name of the replica group"), + ] + replicas: Annotated[ + Range[int], + Field( + description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). " + "If it's a range, the `scaling` property is required" + ), + ] + scaling: Annotated[ + Optional[ScalingSpec], + Field(description="The auto-scaling rules. Required if `replicas` is set to a range"), + ] = None + probes: Annotated[ + list[ProbeConfig], + Field(description="List of probes used to determine job health for this replica group"), + ] = [] + rate_limits: Annotated[ + list[RateLimit], + Field(description="Rate limiting rules for this replica group"), + ] = [] + # TODO: Extract to ConfigurationWithResourcesParams mixin + resources: Annotated[ + ResourcesSpec, + Field(description="The resources requirements for replicas in this group"), + ] = ResourcesSpec() + + @validator("replicas") + def convert_replicas(cls, v: Range[int]) -> Range[int]: + if v.max is None: + raise ValueError("The maximum number of replicas is required") + if v.min is None: + v.min = 0 + if v.min < 0: + raise ValueError("The minimum number of replicas must be greater than or equal to 0") + return v + + @root_validator() + def override_commands_validation(cls, values): + """ + Override parent validator from ConfigurationWithCommandsParams. + ReplicaGroup always requires commands (no image option). 
+ """ + commands = values.get("commands", []) + if not commands: + raise ValueError("`commands` must be set for replica groups") + return values + + @root_validator() + def validate_scaling(cls, values): + scaling = values.get("scaling") + replicas = values.get("replicas") + if replicas and replicas.min != replicas.max and not scaling: + raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") + if replicas and replicas.min == replicas.max and scaling: + raise ValueError("To use `scaling`, `replicas` must be set to a range.") + return values + + @validator("rate_limits") + def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]: + counts = Counter(limit.prefix for limit in v) + duplicates = [prefix for prefix, count in counts.items() if count > 1] + if duplicates: + raise ValueError( + f"Prefixes {duplicates} are used more than once." + " Each rate limit should have a unique path prefix" + ) + return v + + @validator("probes") + def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: + if has_duplicates(v): + raise ValueError("Probes must be unique") + return v + + class ServiceConfigurationParams(CoreModel): port: Annotated[ # NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used. @@ -771,6 +855,19 @@ class ServiceConfigurationParams(CoreModel): Field(description="List of probes used to determine job health"), ] = [] + replica_groups: Annotated[ + Optional[List[ReplicaGroup]], + Field( + description=( + "List of replica groups. Each group defines replicas with shared configuration " + "(commands, port, resources, scaling, probes, rate_limits). " + "When specified, the top-level `replicas`, `commands`, `port`, `resources`, " + "`scaling`, `probes`, and `rate_limits` are ignored. " + "Each replica group must have a unique name." + ) + ), + ] = None + @validator("port") def convert_port(cls, v) -> PortMapping: if isinstance(v, int): @@ -807,6 +904,12 @@ def validate_gateway( @root_validator() def validate_scaling(cls, values): + replica_groups = values.get("replica_groups") + # If replica_groups are set, we don't need to validate scaling. + # Each replica group has its own scaling. + if replica_groups: + return values + scaling = values.get("scaling") replicas = values.get("replicas") if replicas and replicas.min != replicas.max and not scaling: @@ -815,6 +918,42 @@ def validate_scaling(cls, values): raise ValueError("To use `scaling`, `replicas` must be set to a range.") return values + @root_validator() + def normalize_to_replica_groups(cls, values): + replica_groups = values.get("replica_groups") + if replica_groups: + return values + + # TEMP: prove we’re here and see the inputs + print( + "[normalize_to_replica_groups]", + "commands:", + values.get("commands"), + "replicas:", + values.get("replicas"), + "resources:", + values.get("resources"), + "scaling:", + values.get("scaling"), + "probes:", + values.get("probes"), + "rate_limits:", + values.get("rate_limits"), + ) + # If replica_groups is not set, we need to normalize the configuration to replica groups. 
+ values["replica_groups"] = [ + ReplicaGroup( + name="default", + replicas=values.get("replicas"), + commands=values.get("commands"), + resources=values.get("resources"), + scaling=values.get("scaling"), + probes=values.get("probes"), + rate_limits=values.get("rate_limits"), + ) + ] + return values + @validator("rate_limits") def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]: counts = Counter(limit.prefix for limit in v) @@ -836,6 +975,24 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: raise ValueError("Probes must be unique") return v + @validator("replica_groups") + def validate_replica_groups( + cls, v: Optional[List[ReplicaGroup]] + ) -> Optional[List[ReplicaGroup]]: + if v is None: + return v + if not v: + raise ValueError("`replica_groups` cannot be an empty list") + # Check for duplicate names + names = [group.name for group in v] + if len(names) != len(set(names)): + duplicates = [name for name in set(names) if names.count(name) > 1] + raise ValueError( + f"Duplicate replica group names found: {duplicates}. " + "Each replica group must have a unique name." + ) + return v + class ServiceConfigurationConfig( ProfileParamsConfig, diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 13e6a1572..9b8dfae4f 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -253,6 +253,7 @@ class JobSpec(CoreModel): job_num: int job_name: str jobs_per_replica: int = 1 # default value for backward compatibility + replica_group: Optional[str] = "default" app_specs: Optional[List[AppSpec]] user: Optional[UnixUser] = None # default value for backward compatibility commands: List[str] diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index 4ab2633d9..4c074da25 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -1,5 +1,6 @@ import asyncio import datetime +import json from typing import List, Optional, Set, Tuple from sqlalchemy import and_, or_, select @@ -8,6 +9,7 @@ import dstack._internal.server.services.services.autoscalers as autoscalers from dstack._internal.core.errors import ServerError +from dstack._internal.core.models.configurations import ReplicaGroup from dstack._internal.core.models.profiles import RetryEvent, StopCriteria from dstack._internal.core.models.runs import ( Job, @@ -37,6 +39,7 @@ from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.prometheus.client_metrics import run_metrics from dstack._internal.server.services.runs import ( + create_group_run_spec, fmt, process_terminating_run, run_model_to_run, @@ -45,6 +48,7 @@ is_replica_registered, retry_run_replica_jobs, scale_run_replicas, + scale_run_replicas_per_group, ) from dstack._internal.server.services.secrets import get_project_secrets_mapping from dstack._internal.server.services.services import update_service_desired_replica_count @@ -190,7 +194,7 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): logger.debug("%s: retrying run is not yet ready for resubmission", fmt(run_model)) return - run_model.desired_replica_count = 1 + # run_model.desired_replica_count = 1 if run.run_spec.configuration.type == "service": run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0 await update_service_desired_replica_count( @@ 
-201,11 +205,21 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): last_scaled_at=None, ) - if run_model.desired_replica_count == 0: - # stay zero scaled - return + if run_model.desired_replica_count == 0: + # stay zero scaled + return - await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count) + # Per group scaling because single replica is also normalized to replica groups. + replica_groups = run.run_spec.configuration.replica_groups or [] + counts = ( + json.loads(run_model.desired_replica_counts) + if run_model.desired_replica_counts + else {} + ) + await scale_run_replicas_per_group(session, run_model, replica_groups, counts) + else: + run_model.desired_replica_count = 1 + await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count) run_model.status = RunStatus.SUBMITTED logger.info("%s: run status has changed PENDING -> SUBMITTED", fmt(run_model)) @@ -449,6 +463,32 @@ async def _handle_run_replicas( # FIXME: should only include scaling events, not retries and deployments last_scaled_at=max((r.timestamp for r in replicas_info), default=None), ) + replica_groups = run_spec.configuration.replica_groups or [] + if replica_groups: + counts = ( + json.loads(run_model.desired_replica_counts) + if run_model.desired_replica_counts + else {} + ) + await scale_run_replicas_per_group(session, run_model, replica_groups, counts) + + # Handle per-group rolling deployment + await _update_jobs_to_new_deployment_in_place( + session=session, + run_model=run_model, + run_spec=run_spec, + replica_groups=replica_groups, + ) + # Process per-group rolling deployment + for group in replica_groups: + await _handle_rolling_deployment_for_group( + session=session, + run_model=run_model, + group=group, + base_run_spec=run_spec, + desired_replica_counts=counts, + ) + return max_replica_count = run_model.desired_replica_count if _has_out_of_date_replicas(run_model): @@ -514,7 +554,10 @@ async def _handle_run_replicas( async def _update_jobs_to_new_deployment_in_place( - session: AsyncSession, run_model: RunModel, run_spec: RunSpec + session: AsyncSession, + run_model: RunModel, + run_spec: RunSpec, + replica_groups: Optional[List] = None, ) -> None: """ Bump deployment_num for jobs that do not require redeployment. @@ -523,14 +566,30 @@ async def _update_jobs_to_new_deployment_in_place( session=session, project=run_model.project, ) + base_run_spec = run_spec + for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): if all(j.status.is_finished() for j in job_models): continue if all(j.deployment_num == run_model.deployment_num for j in job_models): continue + + # Determine which group this replica belongs to + replica_group_name = None + group_run_spec = base_run_spec + + if replica_groups: + job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data) + replica_group_name = job_spec.replica_group or "default" + + for group in replica_groups: + if group.name == replica_group_name: + group_run_spec = create_group_run_spec(base_run_spec, group) + break + # FIXME: Handle getting image configuration errors or skip it. 
new_job_specs = await get_job_specs_from_run_spec( - run_spec=run_spec, + run_spec=group_run_spec, secrets=secrets, replica_num=replica_num, ) @@ -548,8 +607,15 @@ async def _update_jobs_to_new_deployment_in_place( job_model.deployment_num = run_model.deployment_num -def _has_out_of_date_replicas(run: RunModel) -> bool: +def _has_out_of_date_replicas(run: RunModel, group_filter: Optional[str] = None) -> bool: for job in run.jobs: + # Filter jobs by group if specified + if group_filter is not None: + job_spec = JobSpec.__response__.parse_raw(job.job_spec_data) + # Handle None case: treat None as "default" for backward compatibility + job_replica_group = job_spec.replica_group or "default" + if job_replica_group != group_filter: + continue if job.deployment_num < run.deployment_num and not ( job.status.is_finished() or job.termination_reason == JobTerminationReason.SCALED_DOWN ): @@ -612,3 +678,109 @@ def _should_stop_on_master_done(run: Run) -> bool: if is_master_job(job) and job.job_submissions[-1].status == JobStatus.DONE: return True return False + + +async def _handle_rolling_deployment_for_group( + session: AsyncSession, + run_model: RunModel, + group: ReplicaGroup, + base_run_spec: RunSpec, + desired_replica_counts: dict, +) -> None: + """ + Handle rolling deployment for a single replica group. + """ + from dstack._internal.server.services.runs.replicas import ( + _build_replica_lists, + scale_run_replicas_for_group, + ) + + group_desired = desired_replica_counts.get(group.name, group.replicas.min or 0) + + # Check if group has out-of-date replicas + if not _has_out_of_date_replicas(run_model, group_filter=group.name): + return # Group is up-to-date + + # Calculate max replicas (allow surge during deployment) + group_max_replica_count = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE + + # Count non-terminated replicas for this group only + + non_terminated_replica_count = len( + { + j.replica_num + for j in run_model.jobs + if not j.status.is_finished() and _job_belongs_to_group(job=j, group_name=group.name) + } + ) + + # Start new up-to-date replicas if needed + if non_terminated_replica_count < group_max_replica_count: + active_replicas, inactive_replicas = _build_replica_lists( + run_model=run_model, + jobs=run_model.jobs, + group_filter=group.name, + ) + + await scale_run_replicas_for_group( + session=session, + run_model=run_model, + group=group, + replicas_diff=group_max_replica_count - non_terminated_replica_count, + base_run_spec=base_run_spec, + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + ) + + # Stop out-of-date replicas that are not registered + replicas_to_stop_count = 0 + for _, jobs in group_jobs_by_replica_latest(run_model.jobs): + job_spec = JobSpec.__response__.parse_raw(jobs[0].job_spec_data) + if job_spec.replica_group != group.name: + continue + # Check if replica is out-of-date and not registered + if ( + any(j.deployment_num < run_model.deployment_num for j in jobs) + and any( + j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses() + for j in jobs + ) + and not is_replica_registered(jobs) + ): + replicas_to_stop_count += 1 + + # Stop excessive registered out-of-date replicas + non_terminating_registered_replicas_count = 0 + for _, jobs in group_jobs_by_replica_latest(run_model.jobs): + # Filter by group + job_spec = JobSpec.__response__.parse_raw(jobs[0].job_spec_data) + if job_spec.replica_group != group.name: + continue + + if is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs): + 
non_terminating_registered_replicas_count += 1 + + replicas_to_stop_count += max(0, non_terminating_registered_replicas_count - group_desired) + + if replicas_to_stop_count > 0: + # Build lists again to get current state + active_replicas, inactive_replicas = _build_replica_lists( + run_model=run_model, + jobs=run_model.jobs, + group_filter=group.name, + ) + + await scale_run_replicas_for_group( + session=session, + run_model=run_model, + group=group, + replicas_diff=-replicas_to_stop_count, + base_run_spec=base_run_spec, + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + ) + + +def _job_belongs_to_group(job: JobModel, group_name: str) -> bool: + job_spec = JobSpec.__response__.parse_raw(job.job_spec_data) + return job_spec.replica_group == group_name diff --git a/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py b/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py new file mode 100644 index 000000000..f615560cb --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py @@ -0,0 +1,26 @@ +"""add runmodel desired_replica_counts + +Revision ID: 706e0acc3a7d +Revises: 22d74df9897e +Create Date: 2025-12-18 10:54:13.508297 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "706e0acc3a7d" +down_revision = "22d74df9897e" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("runs", schema=None) as batch_op: + batch_op.add_column(sa.Column("desired_replica_counts", sa.Text(), nullable=True)) + + +def downgrade() -> None: + with op.batch_alter_table("runs", schema=None) as batch_op: + batch_op.drop_column("desired_replica_counts") diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index b64d58b9e..8d39d5dcb 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -385,7 +385,7 @@ class RunModel(BaseModel): priority: Mapped[int] = mapped_column(Integer, default=0) deployment_num: Mapped[int] = mapped_column(Integer) desired_replica_count: Mapped[int] = mapped_column(Integer) - + desired_replica_counts: Mapped[Optional[str]] = mapped_column(Text, nullable=True) jobs: Mapped[List["JobModel"]] = relationship( back_populates="run", lazy="selectin", order_by="[JobModel.replica_num, JobModel.job_num]" ) diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py index 18c0847b4..3304e24a7 100644 --- a/src/dstack/_internal/server/services/runs/__init__.py +++ b/src/dstack/_internal/server/services/runs/__init__.py @@ -18,6 +18,7 @@ ServerClientError, ) from dstack._internal.core.models.common import ApplyAction +from dstack._internal.core.models.configurations import ReplicaGroup from dstack._internal.core.models.profiles import ( RetryEvent, ) @@ -460,12 +461,8 @@ async def submit_run( submitted_at = common_utils.get_current_datetime() initial_status = RunStatus.SUBMITTED - initial_replicas = 1 if run_spec.merged_profile.schedule is not None: initial_status = RunStatus.PENDING - initial_replicas = 0 - elif run_spec.configuration.type == "service": - initial_replicas = run_spec.configuration.replicas.min or 0 run_model = RunModel( id=uuid.uuid4(), @@ -493,12 +490,50 @@ async def submit_run( if run_spec.configuration.type == "service": await services.register_service(session, 
run_model, run_spec) + service_config = run_spec.configuration - for replica_num in range(initial_replicas): + global_replica_num = 0 # Global counter across all groups for unique replica_num + + for replica_group in service_config.replica_groups: + if run_spec.merged_profile.schedule is not None: + group_initial_replicas = 0 + else: + group_initial_replicas = replica_group.replicas.min or 0 + + # Each replica in this group gets the same group-specific configuration + for group_replica_num in range(group_initial_replicas): + group_run_spec = create_group_run_spec( + base_run_spec=run_spec, + replica_group=replica_group, + ) + jobs = await get_jobs_from_run_spec( + run_spec=group_run_spec, + secrets=secrets, + replica_num=global_replica_num, + ) + + for job in jobs: + job.job_spec.replica_group = replica_group.name + job_model = create_job_model_for_new_submission( + run_model=run_model, + job=job, + status=JobStatus.SUBMITTED, + ) + session.add(job_model) + events.emit( + session, + f"Job created on run submission. Status: {job_model.status.upper()}", + actor=events.SystemActor(), + targets=[ + events.Target.from_model(job_model), + ], + ) + global_replica_num += 1 + else: jobs = await get_jobs_from_run_spec( run_spec=run_spec, secrets=secrets, - replica_num=replica_num, + replica_num=0, ) for job in jobs: job_model = create_job_model_for_new_submission( @@ -522,6 +557,31 @@ async def submit_run( return common_utils.get_or_error(run) +def create_group_run_spec( + base_run_spec: RunSpec, + replica_group: ReplicaGroup, +) -> RunSpec: + # Create a copy of the configuration as a dict + config_dict = base_run_spec.configuration.dict() + + # Override with group-specific values (only if provided) + if replica_group.commands: + config_dict["commands"] = replica_group.commands + + if replica_group.resources: + config_dict["resources"] = replica_group.resources + + # Create new configuration object with merged values + # Use the same class as base (ServiceConfiguration) + new_config = base_run_spec.configuration.__class__.parse_obj(config_dict) + + # Create new RunSpec with modified configuration + # Preserve all other RunSpec properties (repo_data, file_archives, etc.) 
+ run_spec_dict = base_run_spec.dict() + run_spec_dict["configuration"] = new_config + return RunSpec.parse_obj(run_spec_dict) + + def create_job_model_for_new_submission( run_model: RunModel, job: Job, diff --git a/src/dstack/_internal/server/services/runs/replicas.py b/src/dstack/_internal/server/services/runs/replicas.py index b1c33c90c..61cf82e70 100644 --- a/src/dstack/_internal/server/services/runs/replicas.py +++ b/src/dstack/_internal/server/services/runs/replicas.py @@ -1,15 +1,20 @@ -from typing import List +from typing import Dict, List, Optional, Tuple from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.models.runs import JobStatus, JobTerminationReason, RunSpec +from dstack._internal.core.models.configurations import ReplicaGroup +from dstack._internal.core.models.runs import JobSpec, JobStatus, JobTerminationReason, RunSpec from dstack._internal.server.models import JobModel, RunModel from dstack._internal.server.services.jobs import ( get_jobs_from_run_spec, group_jobs_by_replica_latest, ) from dstack._internal.server.services.logging import fmt -from dstack._internal.server.services.runs import create_job_model_for_new_submission, logger +from dstack._internal.server.services.runs import ( + create_group_run_spec, + create_job_model_for_new_submission, + logger, +) from dstack._internal.server.services.secrets import get_project_secrets_mapping @@ -21,8 +26,28 @@ async def retry_run_replica_jobs( session=session, project=run_model.project, ) + + # Determine replica group from existing job + base_run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + job_spec = JobSpec.parse_raw(latest_jobs[0].job_spec_data) + replica_group_name = job_spec.replica_group + replica_group = None + + # Find matching replica group + if replica_group_name and base_run_spec.configuration.replica_groups: + for group in base_run_spec.configuration.replica_groups: + if group.name == replica_group_name: + replica_group = group + break + + run_spec = ( + base_run_spec + if replica_group is None + else create_group_run_spec(base_run_spec, replica_group) + ) + new_jobs = await get_jobs_from_run_spec( - run_spec=RunSpec.__response__.parse_raw(run_model.run_spec), + run_spec=run_spec, secrets=secrets, replica_num=latest_jobs[0].replica_num, ) @@ -38,6 +63,10 @@ async def retry_run_replica_jobs( job_model.status = JobStatus.TERMINATING job_model.termination_reason = JobTerminationReason.TERMINATED_BY_SERVER + # Set replica_group on retried jobs to maintain group identity + if replica_group_name: + new_job.job_spec.replica_group = replica_group_name + new_job_model = create_job_model_for_new_submission( run_model=run_model, job=new_job, @@ -55,7 +84,6 @@ def is_replica_registered(jobs: list[JobModel]) -> bool: async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replicas_diff: int): if replicas_diff == 0: - # nothing to do return logger.info( @@ -65,14 +93,48 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica abs(replicas_diff), ) + active_replicas, inactive_replicas = _build_replica_lists(run_model, run_model.jobs) + run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + + if replicas_diff < 0: + _scale_down_replicas(active_replicas, abs(replicas_diff)) + else: + await _scale_up_replicas( + session, + run_model, + active_replicas, + inactive_replicas, + replicas_diff, + run_spec, + group_name=None, + ) + + +def _build_replica_lists( + run_model: RunModel, + jobs: List[JobModel], + group_filter: Optional[str] = 
None, +) -> Tuple[ + List[Tuple[int, bool, int, List[JobModel]]], List[Tuple[int, bool, int, List[JobModel]]] +]: # lists of (importance, is_out_of_date, replica_num, jobs) active_replicas = [] inactive_replicas = [] - for replica_num, replica_jobs in group_jobs_by_replica_latest(run_model.jobs): + for replica_num, replica_jobs in group_jobs_by_replica_latest(jobs): + # Filter by group if specified + if group_filter is not None: + try: + job_spec = JobSpec.parse_raw(replica_jobs[0].job_spec_data) + if job_spec.replica_group != group_filter: + continue + except Exception: + continue + statuses = set(job.status for job in replica_jobs) deployment_num = replica_jobs[0].deployment_num # same for all jobs is_out_of_date = deployment_num < run_model.deployment_num + if {JobStatus.TERMINATING, *JobStatus.finished_statuses()} & statuses: # if there are any terminating or finished jobs, the replica is inactive inactive_replicas.append((0, is_out_of_date, replica_num, replica_jobs)) @@ -89,47 +151,159 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica # all jobs are running and ready, the replica is active and has the importance of 3 active_replicas.append((3, is_out_of_date, replica_num, replica_jobs)) - # sort by is_out_of_date (up-to-date first), importance (desc), and replica_num (asc) + # Sort by is_out_of_date (up-to-date first), importance (desc), and replica_num (asc) active_replicas.sort(key=lambda r: (r[1], -r[0], r[2])) - run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) - if replicas_diff < 0: - for _, _, _, replica_jobs in reversed(active_replicas[-abs(replicas_diff) :]): - # scale down the less important replicas first - for job in replica_jobs: - if job.status.is_finished() or job.status == JobStatus.TERMINATING: - continue - job.status = JobStatus.TERMINATING - job.termination_reason = JobTerminationReason.SCALED_DOWN - # background task will process the job later - else: - scheduled_replicas = 0 + return active_replicas, inactive_replicas - # rerun inactive replicas - for _, _, _, replica_jobs in inactive_replicas: - if scheduled_replicas == replicas_diff: - break - await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) - scheduled_replicas += 1 +def _scale_down_replicas( + active_replicas: List[Tuple[int, bool, int, List[JobModel]]], + count: int, +) -> None: + """Scale down by terminating the least important replicas""" + if count <= 0: + return + + for _, _, _, replica_jobs in reversed(active_replicas[-count:]): + for job in replica_jobs: + if job.status.is_finished() or job.status == JobStatus.TERMINATING: + continue + job.status = JobStatus.TERMINATING + job.termination_reason = JobTerminationReason.SCALED_DOWN + + +async def _scale_up_replicas( + session: AsyncSession, + run_model: RunModel, + active_replicas: List[Tuple[int, bool, int, List[JobModel]]], + inactive_replicas: List[Tuple[int, bool, int, List[JobModel]]], + replicas_diff: int, + run_spec: RunSpec, + group_name: Optional[str] = None, +) -> None: + """Scale up by retrying inactive replicas and creating new ones""" + if replicas_diff <= 0: + return + + scheduled_replicas = 0 + + # Retry inactive replicas first + for _, _, _, replica_jobs in inactive_replicas: + if scheduled_replicas == replicas_diff: + break + await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) + scheduled_replicas += 1 + + # Create new replicas + if scheduled_replicas < replicas_diff: secrets = await get_project_secrets_mapping( session=session, 
project=run_model.project, ) - for replica_num in range( - len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff - ): - # FIXME: Handle getting image configuration errors or skip it. + max_replica_num = max((job.replica_num for job in run_model.jobs), default=-1) + + new_replicas_needed = replicas_diff - scheduled_replicas + for i in range(new_replicas_needed): + new_replica_num = max_replica_num + 1 + i jobs = await get_jobs_from_run_spec( run_spec=run_spec, secrets=secrets, - replica_num=replica_num, + replica_num=new_replica_num, ) for job in jobs: + # Set replica_group if specified + if group_name is not None: + job.job_spec.replica_group = group_name job_model = create_job_model_for_new_submission( run_model=run_model, job=job, status=JobStatus.SUBMITTED, ) session.add(job_model) + run_model.jobs.append(job_model) + + +async def scale_run_replicas_per_group( + session: AsyncSession, + run_model: RunModel, + replica_groups: List[ReplicaGroup], + desired_replica_counts: Dict[str, int], +) -> None: + """Scale each replica group independently""" + if not replica_groups: + return + + for group in replica_groups: + group_desired = desired_replica_counts.get(group.name, group.replicas.min or 0) + + # Build replica lists filtered by this group + active_replicas, inactive_replicas = _build_replica_lists( + run_model=run_model, jobs=run_model.jobs, group_filter=group.name + ) + + # Count active replicas + active_group_count = len(active_replicas) + group_diff = group_desired - active_group_count + + if group_diff != 0: + # Check if rolling deployment is in progress for THIS GROUP + from dstack._internal.server.background.tasks.process_runs import ( + _has_out_of_date_replicas, + ) + + group_has_out_of_date = _has_out_of_date_replicas(run_model, group_filter=group.name) + + # During rolling deployment, don't scale down old replicas + # Let rolling deployment handle stopping old replicas + if group_diff < 0 and group_has_out_of_date: + # Skip scaling down during rolling deployment + continue + await scale_run_replicas_for_group( + session=session, + run_model=run_model, + group=group, + replicas_diff=group_diff, + base_run_spec=RunSpec.__response__.parse_raw(run_model.run_spec), + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + ) + + +async def scale_run_replicas_for_group( + session: AsyncSession, + run_model: RunModel, + group: ReplicaGroup, + replicas_diff: int, + base_run_spec: RunSpec, + active_replicas: List[Tuple[int, bool, int, List[JobModel]]], + inactive_replicas: List[Tuple[int, bool, int, List[JobModel]]], +) -> None: + """Scale a specific replica group up or down""" + if replicas_diff == 0: + return + + logger.info( + "%s: scaling %s %s replica(s) for group '%s'", + fmt(run_model), + "UP" if replicas_diff > 0 else "DOWN", + abs(replicas_diff), + group.name, + ) + + # Get group-specific run_spec + group_run_spec = create_group_run_spec(base_run_spec, group) + + if replicas_diff < 0: + _scale_down_replicas(active_replicas, abs(replicas_diff)) + else: + await _scale_up_replicas( + session=session, + run_model=run_model, + active_replicas=active_replicas, + inactive_replicas=inactive_replicas, + replicas_diff=replicas_diff, + run_spec=group_run_spec, + group_name=group.name, + ) diff --git a/src/dstack/_internal/server/services/runs/spec.py b/src/dstack/_internal/server/services/runs/spec.py index 73b6d9fc7..53d0c2192 100644 --- a/src/dstack/_internal/server/services/runs/spec.py +++ b/src/dstack/_internal/server/services/runs/spec.py 
@@ -50,6 +50,7 @@
         "env",
         "shell",
         "commands",
+        "replica_groups",
     ],
 }
diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py
index 05c1fa909..d029c145e 100644
--- a/src/dstack/_internal/server/services/services/__init__.py
+++ b/src/dstack/_internal/server/services/services/__init__.py
@@ -2,6 +2,7 @@
 Application logic related to `type: service` runs.
 """
 
+import json
 import uuid
 from datetime import datetime
 from typing import Optional
@@ -299,13 +300,39 @@ async def update_service_desired_replica_count(
     configuration: ServiceConfiguration,
     last_scaled_at: Optional[datetime],
 ) -> None:
-    scaler = get_service_scaler(configuration)
     stats = None
     if run_model.gateway_id is not None:
         conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
         stats = await conn.get_stats(run_model.project.name, run_model.run_name)
-    run_model.desired_replica_count = scaler.get_desired_count(
-        current_desired_count=run_model.desired_replica_count,
-        stats=stats,
-        last_scaled_at=last_scaled_at,
-    )
+    if configuration.replica_groups:
+        desired_replica_counts = {}
+        total = 0
+        prev_counts = (
+            json.loads(run_model.desired_replica_counts)
+            if run_model.desired_replica_counts
+            else {}
+        )
+        for group in configuration.replica_groups:
+            # Build a temporary per-group config so the autoscaler computes this group's desired count.
+            group_config = configuration.copy(
+                exclude={"replica_groups"},
+                update={"replicas": group.replicas, "scaling": group.scaling},
+            )
+            scaler = get_service_scaler(group_config)
+            group_desired = scaler.get_desired_count(
+                current_desired_count=prev_counts.get(group.name, group.replicas.min or 0),
+                stats=stats,
+                last_scaled_at=last_scaled_at,
+            )
+            desired_replica_counts[group.name] = group_desired
+            total += group_desired
+        run_model.desired_replica_counts = json.dumps(desired_replica_counts)
+        run_model.desired_replica_count = total
+    else:
+        # TODO: This branch should not be needed since configurations are normalized to replica_groups.
+        scaler = get_service_scaler(configuration)
+        run_model.desired_replica_count = scaler.get_desired_count(
+            current_desired_count=run_model.desired_replica_count,
+            stats=stats,
+            last_scaled_at=last_scaled_at,
+        )
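A few usage sketches follow to make the new behavior concrete; none of this code is part of the patch, and any identifier not shown in the diff above is an assumption.

First, a minimal sketch of the new `replica_groups` field, built from the `ReplicaGroup` / `ServiceConfigurationParams` models in this diff. The import path and `parse_obj` call mirror how these pydantic models are used elsewhere in the codebase; the ports, commands, resource shorthands, and scaling values are illustrative and not verified against every validator.

```python
# Sketch: a service with two replica groups (values are illustrative).
from dstack._internal.core.models.configurations import ServiceConfiguration

config = ServiceConfiguration.parse_obj(
    {
        "type": "service",
        "port": 8000,
        "replica_groups": [
            {
                "name": "cpu",
                "replicas": 1,
                "commands": ["python -m app.server"],
                "resources": {"cpu": 2, "memory": "8GB"},
            },
            {
                # A range requires `scaling`, enforced by ReplicaGroup.validate_scaling.
                "name": "gpu",
                "replicas": "0..2",
                "scaling": {"metric": "rps", "target": 10},
                "commands": ["python -m app.server --gpu"],
                "resources": {"gpu": "24GB"},
            },
        ],
    }
)
print([g.name for g in config.replica_groups])  # ["cpu", "gpu"]
```

Because `check_image_or_commands_present` now returns early when `replica_groups` is set, the top-level `commands`/`image` can be omitted here.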
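The `normalize_to_replica_groups` root validator backfills `replica_groups` for legacy configurations, so a plain service definition should come out as a single group named `default`. A small sketch of that invariant, under the same assumptions as above:

```python
from dstack._internal.core.models.configurations import ServiceConfiguration

legacy = ServiceConfiguration.parse_obj(
    {
        "type": "service",
        "port": 8000,
        "commands": ["python -m app.server"],
        "replicas": 2,
    }
)
# The validator wraps the top-level commands/replicas/resources/scaling/probes/rate_limits
# into one ReplicaGroup named "default".
assert len(legacy.replica_groups) == 1
assert legacy.replica_groups[0].name == "default"
```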
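`RunModel.desired_replica_counts` (added by the migration) stores the per-group desired counts as a JSON object keyed by group name, while the existing integer `desired_replica_count` keeps the total. A stand-alone sketch of the read/write pattern used in `update_service_desired_replica_count` and `process_runs`; the stand-in model below exists only for illustration:

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class RunModelStandIn:
    """Illustration-only stand-in for server.models.RunModel."""

    desired_replica_count: int = 0
    desired_replica_counts: Optional[str] = None


run_model = RunModelStandIn()

# Write path (autoscaler): one entry per replica group, total kept for compatibility.
per_group = {"cpu": 1, "gpu": 2}
run_model.desired_replica_counts = json.dumps(per_group)
run_model.desired_replica_count = sum(per_group.values())

# Read path (background processing): tolerate runs created before the migration.
counts = json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {}
print(counts.get("gpu", 0))  # 2
```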
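`create_group_run_spec` re-parses a copy of the base configuration with only `commands` and `resources` overridden from the group, so per-group job specs differ only in those fields. A plain-dict sketch of that merge step (the helper name is hypothetical):

```python
def merge_group_overrides(config_dict: dict, group: dict) -> dict:
    """Dict-level equivalent of create_group_run_spec's override step."""
    merged = dict(config_dict)
    if group.get("commands"):
        merged["commands"] = group["commands"]
    if group.get("resources"):
        merged["resources"] = group["resources"]
    return merged


base = {"type": "service", "port": 8000, "commands": ["python -m app.server"], "resources": {"cpu": 2}}
gpu_group = {"name": "gpu", "commands": ["python -m app.server --gpu"], "resources": {"gpu": "24GB"}}
print(merge_group_overrides(base, gpu_group))
# commands and resources come from the group; port and everything else from the base config
```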
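`submit_run` now draws `replica_num` from a single counter that advances across all groups, so replica numbers stay unique within the run while `job_spec.replica_group` records group membership. Schematically (group sizes are made up):

```python
groups = [("cpu", 2), ("gpu", 1)]  # (group name, initial replica count)
global_replica_num = 0
assignments = []
for name, initial_replicas in groups:
    for _ in range(initial_replicas):
        assignments.append((name, global_replica_num))
        global_replica_num += 1
print(assignments)  # [('cpu', 0), ('cpu', 1), ('gpu', 2)]
```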
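In the CLI table, `_format_job_name` prints `group=X replica=Y` on the first row of each group and pads subsequent rows so that `replica=` lines up underneath; the padding is `3 + len("group=<index>") + 1` spaces. A self-contained rendering of the intended layout:

```python
rows = [(0, 0), (0, 1), (1, 2)]  # (group_index, replica_num), pre-sorted as in get_job_sort_key
last_group = None
for group_index, replica_num in rows:
    if group_index != last_group:
        print(f"   group={group_index} replica={replica_num}")
        last_group = group_index
    else:
        padding = " " * (3 + len(f"group={group_index}") + 1)
        print(f"{padding}replica={replica_num}")
# Output:
#    group=0 replica=0
#            replica=1
#    group=1 replica=2
```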