diff --git a/go/tasks/pluginmachinery/core/phase.go b/go/tasks/pluginmachinery/core/phase.go
index 51d3e4e81..93fd3067d 100644
--- a/go/tasks/pluginmachinery/core/phase.go
+++ b/go/tasks/pluginmachinery/core/phase.go
@@ -184,7 +184,7 @@ func phaseInfo(p Phase, v uint32, err *core.ExecutionError, info *TaskInfo, clea
 	}
 }
 
-// Return in the case the plugin is not ready to start
+// PhaseInfoNotReady represents the case where the plugin is not ready to start
 func PhaseInfoNotReady(t time.Time, version uint32, reason string) PhaseInfo {
 	pi := phaseInfo(PhaseNotReady, version, nil, &TaskInfo{OccurredAt: &t}, false)
 	pi.reason = reason
@@ -198,7 +198,7 @@ func PhaseInfoWaitingForResources(t time.Time, version uint32, reason string) Ph
 	return pi
 }
 
-// Return in the case the plugin is not ready to start
+// PhaseInfoWaitingForResourcesInfo represents the case where the plugin is waiting for resources
 func PhaseInfoWaitingForResourcesInfo(t time.Time, version uint32, reason string, info *TaskInfo) PhaseInfo {
 	pi := phaseInfo(PhaseWaitingForResources, version, nil, info, false)
 	pi.reason = reason
diff --git a/go/tasks/plugins/k8s/ray/ray.go b/go/tasks/plugins/k8s/ray/ray.go
index ec3cd3e82..cb898a492 100644
--- a/go/tasks/plugins/k8s/ray/ray.go
+++ b/go/tasks/plugins/k8s/ray/ray.go
@@ -59,7 +59,6 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 	}
 
 	podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)
-
 	if err != nil {
 		return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
 	}
@@ -123,6 +122,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 	if spec.MinReplicas != 0 {
 		minReplicas = spec.MinReplicas
 	}
+
 	if spec.MaxReplicas != 0 {
 		maxReplicas = spec.MaxReplicas
 	}
@@ -400,9 +400,10 @@ func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginCont
 	if err != nil {
 		return pluginsCore.PhaseInfoUndefined, err
 	}
+
 	switch rayJob.Status.JobStatus {
 	case rayv1alpha1.JobStatusPending:
-		return pluginsCore.PhaseInfoNotReady(time.Now(), pluginsCore.DefaultPhaseVersion, "job is pending"), nil
+		return pluginsCore.PhaseInfoInitializing(time.Now(), pluginsCore.DefaultPhaseVersion, "job is pending", info), nil
 	case rayv1alpha1.JobStatusFailed:
 		reason := fmt.Sprintf("Failed to create Ray job: %s", rayJob.Name)
 		return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
@@ -411,7 +412,8 @@ func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginCont
 	case rayv1alpha1.JobStatusRunning:
 		return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
 	}
-	return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "JobCreated"), nil
+
+	return pluginsCore.PhaseInfoQueued(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil
 }
 
 func init() {
diff --git a/go/tasks/plugins/k8s/ray/ray_test.go b/go/tasks/plugins/k8s/ray/ray_test.go
index f49b8d5d5..2742995ef 100644
--- a/go/tasks/plugins/k8s/ray/ray_test.go
+++ b/go/tasks/plugins/k8s/ray/ray_test.go
@@ -4,6 +4,9 @@ import (
 	"context"
 	"testing"
 
+	"github.com/flyteorg/flyteplugins/go/tasks/logs"
+	mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s/mocks"
+
 	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"
 	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
 
@@ -188,6 +191,64 @@ func TestBuildResourceRay(t *testing.T) {
 	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration)
 }
 
+func newPluginContext() k8s.PluginContext {
+	plg := &mocks2.PluginContext{}
+
+	taskExecID := &mocks.TaskExecutionID{}
+	taskExecID.OnGetID().Return(core.TaskExecutionIdentifier{
+		NodeExecutionId: &core.NodeExecutionIdentifier{
+			ExecutionId: &core.WorkflowExecutionIdentifier{
+				Name:    "my_name",
+				Project: "my_project",
+				Domain:  "my_domain",
+			},
+		},
+	})
+
+	tskCtx := &mocks.TaskExecutionMetadata{}
+	tskCtx.OnGetTaskExecutionID().Return(taskExecID)
+	plg.OnTaskExecutionMetadata().Return(tskCtx)
+	return plg
+}
+
+func init() {
+	f := defaultConfig
+	f.Logs = logs.LogConfig{
+		IsKubernetesEnabled: true,
+	}
+
+	if err := SetConfig(&f); err != nil {
+		panic(err)
+	}
+}
+
+func TestGetTaskPhase(t *testing.T) {
+	ctx := context.Background()
+	rayJobResourceHandler := rayJobResourceHandler{}
+	pluginCtx := newPluginContext()
+
+	testCases := []struct {
+		rayJobPhase       rayv1alpha1.JobStatus
+		expectedCorePhase pluginsCore.Phase
+	}{
+		{"", pluginsCore.PhaseQueued},
+		{rayv1alpha1.JobStatusPending, pluginsCore.PhaseInitializing},
+		{rayv1alpha1.JobStatusRunning, pluginsCore.PhaseRunning},
+		{rayv1alpha1.JobStatusSucceeded, pluginsCore.PhaseSuccess},
+		{rayv1alpha1.JobStatusFailed, pluginsCore.PhasePermanentFailure},
+	}
+
+	for _, tc := range testCases {
+		t.Run("TestGetTaskPhase_"+string(tc.rayJobPhase), func(t *testing.T) {
+			rayObject := &rayv1alpha1.RayJob{}
+			rayObject.Status.JobStatus = tc.rayJobPhase
+			phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject)
+			assert.Nil(t, err)
+			assert.Equal(t, tc.expectedCorePhase.String(), phaseInfo.Phase().String())
+		})
+	}
+}
+
 func TestGetPropertiesRay(t *testing.T) {
 	rayJobResourceHandler := rayJobResourceHandler{}
 	expected := k8s.PluginProperties{}
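
Note on the queued fallthrough in GetTaskPhase: the phase is now stamped with rayJob.CreationTimestamp.Time instead of time.Now(), so repeated reconcile loops report a stable queue time rather than one that drifts forward on every poll. Below is a minimal regression-test sketch for that behavior, which could sit alongside TestGetTaskPhase in ray_test.go. The test name is hypothetical; it assumes "time" and metav1 ("k8s.io/apimachinery/pkg/apis/meta/v1") are added to the imports, and that PhaseInfoQueued records its timestamp in TaskInfo.OccurredAt the way the other constructors in phase.go do.

// Sketch only: hypothetical test, not part of this change.
func TestGetTaskPhaseQueuedTimestamp(t *testing.T) {
	ctx := context.Background()
	handler := rayJobResourceHandler{}
	pluginCtx := newPluginContext()

	// Pin the creation time so the assertion is deterministic.
	created := metav1.NewTime(time.Date(2022, time.August, 1, 12, 0, 0, 0, time.UTC))
	rayObject := &rayv1alpha1.RayJob{}
	rayObject.CreationTimestamp = created
	// An empty JobStatus falls through the switch to the queued case.

	phaseInfo, err := handler.GetTaskPhase(ctx, pluginCtx, rayObject)
	assert.Nil(t, err)
	assert.Equal(t, pluginsCore.PhaseQueued, phaseInfo.Phase())
	// The queue timestamp should come from the resource, not the wall clock.
	// Assumes PhaseInfoQueued stores its time in TaskInfo.OccurredAt.
	assert.Equal(t, created.Time, *phaseInfo.Info().OccurredAt)
}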