Merged
1 change: 1 addition & 0 deletions slice/config/rbac/role.yaml
@@ -65,6 +65,7 @@ rules:
verbs:
- get
- list
- update
- watch
- apiGroups:
- kueue.x-k8s.io
97 changes: 72 additions & 25 deletions slice/internal/controller/workload_controller.go
@@ -94,7 +94,7 @@ func NewWorkloadReconciler(cl client.Client, record record.EventRecorder) *Workl
// +kubebuilder:rbac:groups=accelerator.gke.io,resources=slices,verbs=get;list;watch;create;update;patch;delete
// +kubebuilder:rbac:groups=accelerator.gke.io,resources=slices/finalizers,verbs=update
// +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update;patch
// +kubebuilder:rbac:groups=jobset.x-k8s.io,resources=jobsets,verbs=get;list;watch
// +kubebuilder:rbac:groups=jobset.x-k8s.io,resources=jobsets,verbs=get;list;watch;update
// +kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch

func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
@@ -162,11 +162,11 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
return ctrl.Result{}, err
}

deleted, toDelete, initializing, _ := r.groupSlices(slices)
if len(deleted) > 0 {
grouped := r.groupSlices(slices)
if len(grouped.deleted) > 0 {
log.V(3).Info(
"Waiting for deleted Slices to be cleaned up; skipping reconciliation for now",
"deletedSlices", klog.KObjSlice(deleted),
"deletedSlices", klog.KObjSlice(grouped.deleted),
)
return ctrl.Result{}, err
}
@@ -176,25 +176,33 @@ func (r *WorkloadReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
return ctrl.Result{}, err
}

if len(grouped.active) == len(slices) && len(slices) > 0 {
log.V(3).Info("Annotating owner before unsuspending")
err := r.updateOwnerBeforeUnsuspend(ctx, wl)
if err != nil {
return ctrl.Result{}, err
}
}

err = r.syncAdmissionCheckStatus(ctx, wl, ac, slices)
if err != nil {
return ctrl.Result{}, client.IgnoreNotFound(err)
}

if len(toDelete) > 0 {
if len(grouped.toDelete) > 0 {
log.V(3).Info(
"Deleting Slices",
"slices", klog.KObjSlice(toDelete),
"slices", klog.KObjSlice(grouped.toDelete),
)
err = r.deleteSlices(ctx, toDelete)
err = r.deleteSlices(ctx, grouped.toDelete)
if err != nil {
return ctrl.Result{}, err
}
}
if len(initializing) > 0 {
if len(grouped.initializing) > 0 {
log.V(3).Info(
"Waiting for Slices to be initialized",
"slices", klog.KObjSlice(initializing),
"slices", klog.KObjSlice(grouped.initializing),
)
return ctrl.Result{RequeueAfter: initializationRetryAfter}, nil
}
@@ -263,22 +271,22 @@ func (r *WorkloadReconciler) cleanupSlices(ctx context.Context, wl *kueue.Worklo
return false, err
}

deleted, toDelete, initializing, other := r.groupSlices(slices)
grouped := r.groupSlices(slices)

if len(deleted) == len(slices) {
if len(grouped.deleted) == len(slices) {
log.V(3).Info("All slices already deleted; finishing cleanup")
return true, nil
}

if len(other)+len(toDelete)+len(initializing) > 0 {
if len(grouped.active)+len(grouped.toDelete)+len(grouped.initializing) > 0 {
terminated, err := r.ownerPodsFinished(ctx, wl)
if err != nil || !terminated {
return false, err
}
}
// after pods are terminated we should cleanup all the slices (including active and initializing ones)
toDelete = append(toDelete, other...)
toDelete = append(toDelete, initializing...)
toDelete := append(grouped.toDelete, grouped.active...)
toDelete = append(toDelete, grouped.initializing...)
log.V(3).Info("Deleting Slices", "slices", klog.KObjSlice(toDelete))
err = r.deleteSlices(ctx, toDelete)
if err != nil {
@@ -302,6 +310,13 @@ func (r *WorkloadReconciler) findWorkloadSlices(ctx context.Context, wl *kueue.W
return slices.Items, nil
}

type groupedSlices struct {
deleted []v1beta1.Slice
toDelete []v1beta1.Slice
initializing []v1beta1.Slice
active []v1beta1.Slice
}

// groupSlices categorizes a list of Slice objects into four groups based on their state.
// It separates slices into deleted (marked for deletion), ones that should be deleted
// (errored and stale), ones that are initializing, and other (active) slices.
@@ -311,24 +326,23 @@ func (r *WorkloadReconciler) findWorkloadSlices(ctx context.Context, wl *kueue.W
// slices - A slice of v1beta1.Slice objects to be categorized.
//
// Returns:
// - A slice containing deleted Slice objects (with non-zero DeletionTimestamp).
// - A slice containing Slice objects that should be deleted (errored and stale slices).
// - A slice sontaining initializing Slices objects (activating and slices without ready state yet)
// - A slice containing other Slice objects (active slices).
func (r *WorkloadReconciler) groupSlices(slices []v1beta1.Slice) (deleted []v1beta1.Slice, toDelete []v1beta1.Slice, initializing []v1beta1.Slice, other []v1beta1.Slice) {
//
// A groupedSlices struct containing categorized slices.
func (r *WorkloadReconciler) groupSlices(slices []v1beta1.Slice) groupedSlices {
gs := groupedSlices{}
for _, slice := range slices {
switch core.GetSliceState(slice) {
case core.SliceStateDeleted:
deleted = append(deleted, slice)
gs.deleted = append(gs.deleted, slice)
case core.SliceStateFailed, core.SliceStateStale:
toDelete = append(toDelete, slice)
gs.toDelete = append(gs.toDelete, slice)
case core.SliceStateCreated, core.SliceStateActivating:
initializing = append(initializing, slice)
default:
other = append(other, slice)
gs.initializing = append(gs.initializing, slice)
case core.SliceStateActive, core.SliceStateActiveDegraded:
gs.active = append(gs.active, slice)
}
}
return deleted, toDelete, initializing, other
return gs
}

func (r *WorkloadReconciler) deleteSlices(ctx context.Context, slices []v1beta1.Slice) error {
@@ -409,6 +423,39 @@ func (r *WorkloadReconciler) finalizeWorkload(ctx context.Context, wl *kueue.Wor
return nil
}

func (r *WorkloadReconciler) updateOwnerBeforeUnsuspend(ctx context.Context, wl *kueue.Workload) error {
// For now, we only support JobSets.
if isJobSetOwner(wl) {
return r.updateJobSetBeforeUnsuspend(ctx, wl)
}
return nil
}

func (r *WorkloadReconciler) updateJobSetBeforeUnsuspend(ctx context.Context, wl *kueue.Workload) error {
owner := metav1.GetControllerOf(wl)
log := ctrl.LoggerFrom(ctx).WithValues("jobSet", klog.KRef(wl.Namespace, owner.Name))
jobSet := &jobset.JobSet{}
jobSetKey := types.NamespacedName{Name: owner.Name, Namespace: wl.Namespace}
if err := r.client.Get(ctx, jobSetKey, jobSet); err != nil {
log.Error(err, "Failed to get JobSet")
return err
}
for i := range jobSet.Spec.ReplicatedJobs {
rj := &jobSet.Spec.ReplicatedJobs[i]
topology := rj.Template.Spec.Template.Annotations[core.TPUTopologyAnnotation]
log.V(5).Info("Copying topology annotation as nodeSelector", "topology", topology)
if rj.Template.Spec.Template.Spec.NodeSelector == nil {
rj.Template.Spec.Template.Spec.NodeSelector = make(map[string]string)
}
rj.Template.Spec.Template.Spec.NodeSelector[core.TPUTopologyAnnotation] = topology
}
if err := r.client.Update(ctx, jobSet); err != nil {
log.Error(err, "Failed to update JobSet")
return err
}
return nil
}

func validateRelevantWorkload(wl *kueue.Workload, nodes map[string]corev1.Node) error {
if !hasSupportedOwner(wl) {
return errors.New("does not have a supported owner")
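The new updateOwnerBeforeUnsuspend path calls an isJobSetOwner helper that is not part of this diff. A minimal sketch of what such a check could look like, assuming the usual controller-reference pattern (the real helper may differ):

```go
// Hypothetical sketch only; isJobSetOwner is not shown in this PR.
// Assumes the imports already used in workload_controller.go:
//   metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
//   "k8s.io/apimachinery/pkg/runtime/schema"
func isJobSetOwner(wl *kueue.Workload) bool {
	// Look up the controlling owner reference of the Workload.
	owner := metav1.GetControllerOf(wl)
	if owner == nil {
		return false
	}
	// Parse the owner's apiVersion into group/version.
	gv, err := schema.ParseGroupVersion(owner.APIVersion)
	if err != nil {
		return false
	}
	// Match the JobSet kind guarded by the jobset.x-k8s.io RBAC rules above.
	return gv.Group == "jobset.x-k8s.io" && owner.Kind == "JobSet"
}
```

With a guard like this in place, updateJobSetBeforeUnsuspend can assume metav1.GetControllerOf(wl) is non-nil, which matches how it dereferences owner.Name without a nil check.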
89 changes: 84 additions & 5 deletions slice/internal/controller/workload_controller_test.go

@PBundyra (Collaborator) commented on Dec 9, 2025:

Is the change reflected in the tests?

Author (Collaborator) replied:

Added Unit test ("should set nodeSelector at Jobset when all slices are ready")

Author (Collaborator) replied:

Also, it is tested in e2e tests

@@ -154,6 +154,7 @@ func TestWorkloadReconciler(t *testing.T) {
objs []client.Object
wantWorkloads []kueue.Workload
wantSlices []slice.Slice
wantJobSets []jobset.JobSet
wantErr error
wantEvents []utiltesting.EventRecord
wantResult controllerruntime.Result
@@ -305,7 +306,7 @@ func TestWorkloadReconciler(t *testing.T) {
request: baseRequest,
objs: []client.Object{
baseAdmissionCheckWrapper.DeepCopy(),
baseJobSetWrapper.DeepCopy(),
baseJobSetWrapper.Clone().Obj(),
basePod1Wrapper.DeepCopy(),
baseWorkloadWrapper.Clone().
PodSets(basePodSets...).
@@ -330,7 +331,8 @@
*baseSlice1Wrapper.Clone().Degraded().Obj(),
*baseSlice2Wrapper.Clone().Degraded().Obj(),
},
wantResult: reconcile.Result{RequeueAfter: cleanupRetryAfter},
wantJobSets: []jobset.JobSet{*baseJobSetWrapper.Clone().Obj()},
wantResult: reconcile.Result{RequeueAfter: cleanupRetryAfter},
},
"should delete the finalizer because the Pod Status Succeeded": {
request: baseRequest,
@@ -355,6 +357,7 @@
Active(false).
Obj(),
},
wantJobSets: []jobset.JobSet{*baseJobSetWrapper.Clone().Obj()},
},
"should delete the finalizer because the Pod Status PodFailed": {
request: baseRequest,
@@ -380,6 +383,7 @@
Active(false).
Obj(),
},
wantJobSets: []jobset.JobSet{*baseJobSetWrapper.Clone().Obj()},
},
"shouldn't delete the finalizer because Pods still running": {
request: baseRequest,
@@ -411,13 +415,14 @@
*baseSlice1Wrapper.DeepCopy(),
*baseSlice2Wrapper.DeepCopy(),
},
wantResult: reconcile.Result{RequeueAfter: cleanupRetryAfter},
wantJobSets: []jobset.JobSet{*baseJobSetWrapper.Clone().Obj()},
wantResult: reconcile.Result{RequeueAfter: cleanupRetryAfter},
},
"shouldn't delete the finalizer because one of the Pods still running": {
request: baseRequest,
objs: []client.Object{
baseAdmissionCheckWrapper.DeepCopy(),
baseJobSetWrapper.DeepCopy(),
baseJobSetWrapper.Clone().Obj(),
basePod1Wrapper.Clone().StatusPhase(corev1.PodSucceeded).Obj(),
basePod2Wrapper.DeepCopy(),
baseWorkloadWrapper.Clone().
@@ -443,7 +448,8 @@
*baseSlice1Wrapper.DeepCopy(),
*baseSlice2Wrapper.DeepCopy(),
},
wantResult: reconcile.Result{RequeueAfter: initializationRetryAfter},
wantJobSets: []jobset.JobSet{*baseJobSetWrapper.Clone().Obj()},
wantResult: reconcile.Result{RequeueAfter: initializationRetryAfter},
},
"shouldn't add finalizer because invalid TPU topology annotation": {
request: baseRequest,
@@ -1142,6 +1148,7 @@
ControllerReference(jobSetGVK, baseJobSetName, baseJobSetName).
Finalizers(SliceControllerName).
Obj(),
baseJobSetWrapper.Clone().Obj(),
baseSlice1Wrapper.Clone().Active().Obj(),
baseSlice2Wrapper.Clone().Active().Obj(),
},
@@ -1158,6 +1165,7 @@
*baseSlice1Wrapper.Clone().Active().Obj(),
*baseSlice2Wrapper.Clone().Active().Obj(),
},
wantJobSets: []jobset.JobSet{*baseJobSetWrapper.Clone().Obj()},
wantEvents: []utiltesting.EventRecord{
buildEventRecord(corev1.NamespaceDefault, corev1.EventTypeNormal, AdmissionCheckUpdatedEventType,
fmt.Sprintf(`Admission check %q updated state from "Pending" to "Ready"`, baseACName)),
@@ -1175,6 +1183,7 @@
ControllerReference(jobSetGVK, baseJobSetName, baseJobSetName).
Finalizers(SliceControllerName).
Obj(),
baseJobSetWrapper.Clone().Obj(),
baseSlice1Wrapper.Clone().Active().Obj(),
baseSlice2Wrapper.Clone().Degraded().Obj(),
},
@@ -1190,11 +1199,70 @@
wantSlices: []slice.Slice{
*baseSlice1Wrapper.Clone().Active().Obj(),
*baseSlice2Wrapper.Clone().Degraded().Obj()},
wantJobSets: []jobset.JobSet{*baseJobSetWrapper.Clone().Obj()},
wantEvents: []utiltesting.EventRecord{
buildEventRecord(corev1.NamespaceDefault, corev1.EventTypeNormal, AdmissionCheckUpdatedEventType,
fmt.Sprintf(`Admission check %q updated state from "Pending" to "Ready"`, baseACName)),
},
},
"should set nodeSelector at Jobset when all slices are ready": {
request: baseRequest,
objs: []client.Object{
worker1Node.DeepCopy(),
worker2Node.DeepCopy(),
baseAdmissionCheckWrapper.DeepCopy(),
baseWorkloadWrapper.Clone().
PodSets(basePodSets...).
ReserveQuota(baseAdmission, now).
ControllerReference(jobSetGVK, baseJobSetName, baseJobSetName).
Finalizers(SliceControllerName).
Obj(),
baseJobSetWrapper.Clone().ReplicatedJobs(
utiltestingjobsjobset.ReplicatedJobRequirements{
Name: "ps1",
PodAnnotations: map[string]string{core.TPUTopologyAnnotation: "4x4x12"},
},
utiltestingjobsjobset.ReplicatedJobRequirements{
Name: "ps2",
PodAnnotations: map[string]string{core.TPUTopologyAnnotation: "4x4x12"},
},
).Obj(),
baseSlice1Wrapper.Clone().Active().Obj(),
baseSlice2Wrapper.Clone().Active().Obj(),
},
wantWorkloads: []kueue.Workload{
*baseWorkloadWrapper.Clone().
PodSets(basePodSets...).
ReserveQuota(baseAdmission, now).
ControllerReference(jobSetGVK, baseJobSetName, baseJobSetName).
Finalizers(SliceControllerName).
AdmissionCheck(buildAdmissionCheckState(kueue.CheckStateReady, `Slices are in states: 2 ACTIVE`)).
Obj(),
},
wantSlices: []slice.Slice{
*baseSlice1Wrapper.Clone().Active().Obj(),
*baseSlice2Wrapper.Clone().Active().Obj(),
},
wantJobSets: []jobset.JobSet{*baseJobSetWrapper.Clone().ReplicatedJobs(
utiltestingjobsjobset.ReplicatedJobRequirements{
Name: "ps1",
PodAnnotations: map[string]string{
core.TPUTopologyAnnotation: "4x4x12",
},
NodeSelector: map[string]string{core.TPUTopologyAnnotation: "4x4x12"},
},
utiltestingjobsjobset.ReplicatedJobRequirements{
Name: "ps2",
PodAnnotations: map[string]string{
core.TPUTopologyAnnotation: "4x4x12",
},
NodeSelector: map[string]string{core.TPUTopologyAnnotation: "4x4x12"},
},
).Obj()},
wantEvents: []utiltesting.EventRecord{
buildEventRecord(corev1.NamespaceDefault, corev1.EventTypeNormal, AdmissionCheckUpdatedEventType, fmt.Sprintf(`Admission check %q updated state from "Pending" to "Ready"`, baseACName)),
},
},
"should update the Workload's AdmissionCheckState when one Slice is in the Failed state": {
request: baseRequest,
objs: []client.Object{
@@ -1268,6 +1336,7 @@
ControllerReference(jobSetGVK, baseJobSetName, baseJobSetName).
Finalizers(SliceControllerName).
Obj(),
baseJobSetWrapper.Clone().Obj(),
baseSlice1Wrapper.Clone().Active().Obj(),
baseSlice2Wrapper.Clone().Active().Obj(),
},
@@ -1284,6 +1353,7 @@
*baseSlice1Wrapper.Clone().Active().Obj(),
*baseSlice2Wrapper.Clone().Active().Obj(),
},
wantJobSets: []jobset.JobSet{*baseJobSetWrapper.Clone().Obj()},
wantEvents: []utiltesting.EventRecord{
buildEventRecord(corev1.NamespaceDefault, corev1.EventTypeNormal, AdmissionCheckUpdatedEventType,
fmt.Sprintf(`Admission check %q updated state from "Pending" to "Ready"`, baseACName)),
@@ -1424,6 +1494,15 @@
t.Errorf("Workloads after reconcile (-want,+got):\n%s", diff)
}

jobSets := &jobset.JobSetList{}
err = kClient.List(ctx, jobSets)
if err != nil {
t.Errorf("Error listing jobsets: %v", err)
}
if diff := cmp.Diff(tc.wantJobSets, jobSets.Items, baseCmpOpts); diff != "" {
t.Errorf("JobSets after reconcile (-want,+got):\n%s", diff)
}

slices := &slice.SliceList{}
err = kClient.List(ctx, slices)
if err != nil {
5 changes: 5 additions & 0 deletions slice/internal/util/testingjobs/jobset/wrappers.go
@@ -66,6 +66,11 @@ func (j *JobSetWrapper) Obj() *jobsetapi.JobSet {
return &j.JobSet
}

// Clone returns a deep copy of the JobSetWrapper.
func (j *JobSetWrapper) Clone() *JobSetWrapper {
return &JobSetWrapper{*j.DeepCopy()}
}

func (j *JobSetWrapper) UID(uid string) *JobSetWrapper {
j.ObjectMeta.UID = types.UID(uid)
return j
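A note on why Clone was added, based on my reading of the wrapper (not stated in the PR): JobSetWrapper appears to embed jobsetapi.JobSet, so the promoted DeepCopy returns *jobsetapi.JobSet and cannot be chained with further wrapper methods, whereas Clone keeps the wrapper type. That is what the new test cases rely on, e.g.:

```go
// Usage sketch mirroring the new test case; the names come from the test file above.
js := baseJobSetWrapper.Clone().
	ReplicatedJobs(
		utiltestingjobsjobset.ReplicatedJobRequirements{
			Name:           "ps1",
			PodAnnotations: map[string]string{core.TPUTopologyAnnotation: "4x4x12"},
		},
	).
	Obj() // *jobsetapi.JobSet ready to hand to the fake client

// DeepCopy, by contrast, already yields *jobsetapi.JobSet, so no further
// wrapper methods can be chained after it.
plain := baseJobSetWrapper.DeepCopy()
_, _ = js, plain
```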