This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Add more tests
Signed-off-by: Haytham Abuelfutuh <haytham@afutuh.com>
EngHabu committed Sep 22, 2023
1 parent 71c70f1 commit 67de51e
Showing 3 changed files with 68 additions and 5 deletions.
4 changes: 2 additions & 2 deletions go/tasks/pluginmachinery/core/phase.go
@@ -184,7 +184,7 @@ func phaseInfo(p Phase, v uint32, err *core.ExecutionError, info *TaskInfo, clea
}
}

// Return in the case the plugin is not ready to start
// PhaseInfoNotReady represents the case the plugin is not ready to start
func PhaseInfoNotReady(t time.Time, version uint32, reason string) PhaseInfo {
pi := phaseInfo(PhaseNotReady, version, nil, &TaskInfo{OccurredAt: &t}, false)
pi.reason = reason
@@ -198,7 +198,7 @@ func PhaseInfoWaitingForResources(t time.Time, version uint32, reason string) Ph
return pi
}

// Return in the case the plugin is not ready to start
// PhaseInfoWaitingForResourcesInfo represents the case the plugin is not ready to start
func PhaseInfoWaitingForResourcesInfo(t time.Time, version uint32, reason string, info *TaskInfo) PhaseInfo {
pi := phaseInfo(PhaseWaitingForResources, version, nil, info, false)
pi.reason = reason
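For context, here is a minimal usage sketch of the two constructors whose doc comments this hunk renames, based only on the signatures visible above. The caller code, reason strings, and example function name are assumptions for illustration, not part of this commit.

package example

import (
	"time"

	"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core"
)

// examplePhaseInfos shows how a plugin might report the two states documented above.
func examplePhaseInfos() (core.PhaseInfo, core.PhaseInfo) {
	now := time.Now()

	// The plugin is not yet ready to start.
	notReady := core.PhaseInfoNotReady(now, core.DefaultPhaseVersion, "waiting for a precondition")

	// The plugin is waiting for cluster resources and attaches extra task info.
	waiting := core.PhaseInfoWaitingForResourcesInfo(now, core.DefaultPhaseVersion, "insufficient quota", &core.TaskInfo{OccurredAt: &now})

	return notReady, waiting
}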
8 changes: 5 additions & 3 deletions go/tasks/plugins/k8s/ray/ray.go
@@ -59,7 +59,6 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
}

podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx)

if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error())
}
@@ -123,6 +122,7 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
if spec.MinReplicas != 0 {
minReplicas = spec.MinReplicas
}

if spec.MaxReplicas != 0 {
maxReplicas = spec.MaxReplicas
}
@@ -400,9 +400,10 @@ func (plugin rayJobResourceHandler) GetTaskPhase(ctx context.Context, pluginCont
if err != nil {
return pluginsCore.PhaseInfoUndefined, err
}

switch rayJob.Status.JobStatus {
case rayv1alpha1.JobStatusPending:
return pluginsCore.PhaseInfoNotReady(time.Now(), pluginsCore.DefaultPhaseVersion, "job is pending"), nil
return pluginsCore.PhaseInfoInitializing(time.Now(), pluginsCore.DefaultPhaseVersion, "job is pending", info), nil
case rayv1alpha1.JobStatusFailed:
reason := fmt.Sprintf("Failed to create Ray job: %s", rayJob.Name)
return pluginsCore.PhaseInfoFailure(flyteerr.TaskFailedWithError, reason, info), nil
@@ -411,7 +412,8 @@
case rayv1alpha1.JobStatusRunning:
return pluginsCore.PhaseInfoRunning(pluginsCore.DefaultPhaseVersion, info), nil
}
return pluginsCore.PhaseInfoQueued(time.Now(), pluginsCore.DefaultPhaseVersion, "JobCreated"), nil

return pluginsCore.PhaseInfoQueued(rayJob.CreationTimestamp.Time, pluginsCore.DefaultPhaseVersion, "JobCreated"), nil
}

func init() {
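Taken together with the test below, the JobStatus-to-phase mapping after this change can be summarized as follows. This is an illustrative sketch using the rayv1alpha1 and pluginsCore aliases from ray.go, not code from the commit.

// The phase each KubeRay job status maps to in GetTaskPhase after this change.
var expectedPhases = map[rayv1alpha1.JobStatus]pluginsCore.Phase{
	"":                             pluginsCore.PhaseQueued,           // no status yet: falls through to PhaseInfoQueued
	rayv1alpha1.JobStatusPending:   pluginsCore.PhaseInitializing,     // previously PhaseNotReady
	rayv1alpha1.JobStatusRunning:   pluginsCore.PhaseRunning,
	rayv1alpha1.JobStatusSucceeded: pluginsCore.PhaseSuccess,
	rayv1alpha1.JobStatusFailed:    pluginsCore.PhasePermanentFailure, // via PhaseInfoFailure(TaskFailedWithError, ...)
}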
61 changes: 61 additions & 0 deletions go/tasks/plugins/k8s/ray/ray_test.go
@@ -4,6 +4,9 @@ import (
"context"
"testing"

"github.com/flyteorg/flyteplugins/go/tasks/logs"
mocks2 "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/k8s/mocks"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/flytek8s/config"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"
@@ -188,6 +191,64 @@ func TestBuildResourceRay(t *testing.T) {
assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration)
}

func newPluginContext() k8s.PluginContext {
plg := &mocks2.PluginContext{}

taskExecID := &mocks.TaskExecutionID{}
taskExecID.OnGetID().Return(core.TaskExecutionIdentifier{
NodeExecutionId: &core.NodeExecutionIdentifier{
ExecutionId: &core.WorkflowExecutionIdentifier{
Name: "my_name",
Project: "my_project",
Domain: "my_domain",
},
},
})

tskCtx := &mocks.TaskExecutionMetadata{}
tskCtx.OnGetTaskExecutionID().Return(taskExecID)
plg.OnTaskExecutionMetadata().Return(tskCtx)
return plg
}

func init() {
f := defaultConfig
f.Logs = logs.LogConfig{
IsKubernetesEnabled: true,
}

if err := SetConfig(&f); err != nil {
panic(err)
}
}

func TestGetTaskPhase(t *testing.T) {
ctx := context.Background()
rayJobResourceHandler := rayJobResourceHandler{}
pluginCtx := newPluginContext()

testCases := []struct {
rayJobPhase rayv1alpha1.JobStatus
expectedCorePhase pluginsCore.Phase
}{
{"", pluginsCore.PhaseQueued},
{rayv1alpha1.JobStatusPending, pluginsCore.PhaseInitializing},
{rayv1alpha1.JobStatusRunning, pluginsCore.PhaseRunning},
{rayv1alpha1.JobStatusSucceeded, pluginsCore.PhaseSuccess},
{rayv1alpha1.JobStatusFailed, pluginsCore.PhasePermanentFailure},
}

for _, tc := range testCases {
t.Run("TestGetTaskPhase_"+string(tc.rayJobPhase), func(t *testing.T) {
rayObject := &rayv1alpha1.RayJob{}
rayObject.Status.JobStatus = tc.rayJobPhase
phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject)
assert.Nil(t, err)
assert.Equal(t, tc.expectedCorePhase.String(), phaseInfo.Phase().String())
})
}
}

func TestGetPropertiesRay(t *testing.T) {
rayJobResourceHandler := rayJobResourceHandler{}
expected := k8s.PluginProperties{}
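One behavior the new table-driven test does not pin down explicitly is that the queued phase now carries the job's creation timestamp rather than time.Now(). A possible follow-up test is sketched below; it mirrors TestGetTaskPhase above, assumes the file also imports time and k8s.io/apimachinery/pkg/apis/meta/v1 (as metav1), and is not part of this commit.

// Sketch of a follow-up test: the queued branch with a non-zero CreationTimestamp.
func TestGetTaskPhaseQueued(t *testing.T) {
	ctx := context.Background()
	rayJobResourceHandler := rayJobResourceHandler{}
	pluginCtx := newPluginContext()

	// A RayJob with no JobStatus yet, so GetTaskPhase falls through to PhaseInfoQueued,
	// which now uses the object's CreationTimestamp as the occurrence time.
	rayObject := &rayv1alpha1.RayJob{}
	rayObject.CreationTimestamp = metav1.NewTime(time.Now().Add(-time.Minute))

	phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject)
	assert.Nil(t, err)
	assert.Equal(t, pluginsCore.PhaseQueued.String(), phaseInfo.Phase().String())
}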
