Skip to content

Commit

Permalink
Fix cluster pool assignment validation
Browse files Browse the repository at this point in the history
  • Loading branch information
iaroslav-ciupin authored and andrewwdye committed Sep 27, 2024
1 parent ba8b4e3 commit 03986bc
Show file tree
Hide file tree
Showing 4 changed files with 599 additions and 25 deletions.
41 changes: 31 additions & 10 deletions flyteadmin/pkg/manager/impl/execution_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,15 +403,34 @@ func (m *ExecutionManager) getExecutionConfig(ctx context.Context, request *admi
return workflowExecConfig, nil
}

func (m *ExecutionManager) getClusterAssignment(ctx context.Context, request *admin.ExecutionCreateRequest) (
*admin.ClusterAssignment, error) {
if request.Spec.ClusterAssignment != nil {
return request.Spec.ClusterAssignment, nil
func (m *ExecutionManager) getClusterAssignment(ctx context.Context, req *admin.ExecutionCreateRequest) (*admin.ClusterAssignment, error) {
storedAssignment, err := m.fetchClusterAssignment(ctx, req.Project, req.Domain)
if err != nil {
return nil, err
}

reqAssignment := req.GetSpec().GetClusterAssignment()
reqPool := reqAssignment.GetClusterPoolName()
storedPool := storedAssignment.GetClusterPoolName()
if reqPool == "" {
return storedAssignment, nil
}

if storedPool == "" {
return reqAssignment, nil
}

if reqPool != storedPool {
return nil, errors.NewFlyteAdminErrorf(codes.InvalidArgument, "execution with project %q and domain %q cannot run on cluster pool %q, because its configured to run on pool %q", req.Project, req.Domain, reqPool, storedPool)
}

return storedAssignment, nil
}

func (m *ExecutionManager) fetchClusterAssignment(ctx context.Context, project, domain string) (*admin.ClusterAssignment, error) {
resource, err := m.resourceManager.GetResource(ctx, interfaces.ResourceRequest{
Project: request.Project,
Domain: request.Domain,
Project: project,
Domain: domain,
ResourceType: admin.MatchableResource_CLUSTER_ASSIGNMENT,
})
if err != nil && !errors.IsDoesNotExistError(err) {
Expand All @@ -421,11 +440,13 @@ func (m *ExecutionManager) getClusterAssignment(ctx context.Context, request *ad
if resource != nil && resource.Attributes.GetClusterAssignment() != nil {
return resource.Attributes.GetClusterAssignment(), nil
}
clusterPoolAssignment := m.config.ClusterPoolAssignmentConfiguration().GetClusterPoolAssignments()[request.GetDomain()]

return &admin.ClusterAssignment{
ClusterPoolName: clusterPoolAssignment.Pool,
}, nil
var clusterAssignment *admin.ClusterAssignment
domainAssignment := m.config.ClusterPoolAssignmentConfiguration().GetClusterPoolAssignments()[domain]
if domainAssignment.Pool != "" {
clusterAssignment = &admin.ClusterAssignment{ClusterPoolName: domainAssignment.Pool}
}
return clusterAssignment, nil
}

func (m *ExecutionManager) launchSingleTaskExecution(
Expand Down
112 changes: 97 additions & 15 deletions flyteadmin/pkg/manager/impl/execution_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,7 @@ func TestCreateExecution(t *testing.T) {
}}
repository.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func(
ctx context.Context, projectID string) (models.Project, error) {
return transformers.CreateProjectModel(&admin.Project{
Labels: &labels}), nil
return transformers.CreateProjectModel(&admin.Project{Labels: &labels}), nil
}

clusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"}
Expand Down Expand Up @@ -382,8 +381,6 @@ func TestCreateExecution(t *testing.T) {

mockConfig := getMockExecutionsConfigProvider()
mockConfig.(*runtimeMocks.MockConfigurationProvider).AddQualityOfServiceConfiguration(qosProvider)

execManager := NewExecutionManager(repository, r, mockConfig, getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, &mockPublisher, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})
request := testutils.GetExecutionRequest()
request.Spec.Metadata = &admin.ExecutionMetadata{
Principal: "unused - populated from authenticated context",
Expand All @@ -392,16 +389,18 @@ func TestCreateExecution(t *testing.T) {
request.Spec.ClusterAssignment = &clusterAssignment
request.Spec.ExecutionClusterLabel = &admin.ExecutionClusterLabel{Value: executionClusterLabel}

execManager := NewExecutionManager(repository, r, mockConfig, getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, &mockPublisher, nil, &eventWriterMocks.WorkflowExecutionEventWriter{})

identity, err := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
assert.NoError(t, err)
ctx := identity.WithContext(context.Background())
response, err := execManager.CreateExecution(ctx, request, requestedAt)
assert.Nil(t, err)
assert.NoError(t, err)

expectedResponse := &admin.ExecutionCreateResponse{
Id: &executionIdentifier,
}
assert.Nil(t, err)
assert.NoError(t, err)
assert.True(t, proto.Equal(expectedResponse.Id, response.Id))

// TODO: Check for offloaded inputs
Expand Down Expand Up @@ -632,7 +631,6 @@ func TestCreateExecutionInCompatibleInputs(t *testing.T) {
}

func TestCreateExecutionPropellerFailure(t *testing.T) {
clusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"}
repository := getMockRepositoryForExecTest()
setDefaultLpCallbackForExecTest(repository)
expectedErr := flyteAdminErrors.NewFlyteAdminErrorf(codes.Internal, "ABC")
Expand Down Expand Up @@ -666,7 +664,6 @@ func TestCreateExecutionPropellerFailure(t *testing.T) {
Principal: "unused - populated from authenticated context",
}
request.Spec.RawOutputDataConfig = &admin.RawOutputDataConfig{OutputLocationPrefix: rawOutput}
request.Spec.ClusterAssignment = &clusterAssignment

identity, err := auth.NewIdentityContext("", principal, "", time.Now(), sets.NewString(), nil, nil)
assert.NoError(t, err)
Expand Down Expand Up @@ -5467,8 +5464,32 @@ func TestGetClusterAssignment(t *testing.T) {
assert.NoError(t, err)
assert.True(t, proto.Equal(ca, &clusterAssignment))
})
t.Run("value from request", func(t *testing.T) {
reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "swimming-pool"}
t.Run("value from config", func(t *testing.T) {
customCP := "my_cp"
clusterPoolAsstProvider := &runtimeIFaceMocks.ClusterPoolAssignmentConfiguration{}
clusterPoolAsstProvider.OnGetClusterPoolAssignments().Return(runtimeInterfaces.ClusterPoolAssignments{
workflowIdentifier.GetDomain(): runtimeInterfaces.ClusterPoolAssignment{
Pool: customCP,
},
})
mockConfig := getMockExecutionsConfigProvider()
mockConfig.(*runtimeMocks.MockConfigurationProvider).AddClusterPoolAssignmentConfiguration(clusterPoolAsstProvider)

executionManager := ExecutionManager{
resourceManager: &managerMocks.MockResourceManager{},
config: mockConfig,
}

ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Spec: &admin.ExecutionSpec{},
})
assert.NoError(t, err)
assert.Equal(t, customCP, ca.GetClusterPoolName())
})
t.Run("value from request matches value from config", func(t *testing.T) {
reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"}
ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Expand All @@ -5479,12 +5500,30 @@ func TestGetClusterAssignment(t *testing.T) {
assert.NoError(t, err)
assert.True(t, proto.Equal(ca, &reqClusterAssignment))
})
t.Run("value from config", func(t *testing.T) {
customCP := "my_cp"
t.Run("no value in DB nor in config, takes value from request", func(t *testing.T) {
mockConfig := getMockExecutionsConfigProvider()

executionManager := ExecutionManager{
resourceManager: &managerMocks.MockResourceManager{},
config: mockConfig,
}

reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"}
ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Spec: &admin.ExecutionSpec{
ClusterAssignment: &reqClusterAssignment,
},
})
assert.NoError(t, err)
assert.True(t, proto.Equal(ca, &reqClusterAssignment))
})
t.Run("empty value in DB, takes value from request", func(t *testing.T) {
clusterPoolAsstProvider := &runtimeIFaceMocks.ClusterPoolAssignmentConfiguration{}
clusterPoolAsstProvider.OnGetClusterPoolAssignments().Return(runtimeInterfaces.ClusterPoolAssignments{
workflowIdentifier.GetDomain(): runtimeInterfaces.ClusterPoolAssignment{
Pool: customCP,
Pool: "",
},
})
mockConfig := getMockExecutionsConfigProvider()
Expand All @@ -5495,13 +5534,56 @@ func TestGetClusterAssignment(t *testing.T) {
config: mockConfig,
}

reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "gpu"}
ca, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Spec: &admin.ExecutionSpec{},
Spec: &admin.ExecutionSpec{
ClusterAssignment: &reqClusterAssignment,
},
})
assert.NoError(t, err)
assert.Equal(t, customCP, ca.GetClusterPoolName())
assert.True(t, proto.Equal(ca, &reqClusterAssignment))
})
t.Run("value from request doesn't match value from config", func(t *testing.T) {
reqClusterAssignment := admin.ClusterAssignment{ClusterPoolName: "swimming-pool"}
_, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Spec: &admin.ExecutionSpec{
ClusterAssignment: &reqClusterAssignment,
},
})
st, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, codes.InvalidArgument, st.Code())
assert.Equal(t, `execution with project "project" and domain "domain" cannot run on cluster pool "swimming-pool", because its configured to run on pool "gpu"`, st.Message())
})
t.Run("db error", func(t *testing.T) {
expected := errors.New("fail db")
resourceManager.GetResourceFunc = func(ctx context.Context,
request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) {
assert.EqualValues(t, request, managerInterfaces.ResourceRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
ResourceType: admin.MatchableResource_CLUSTER_ASSIGNMENT,
})
return &managerInterfaces.ResourceResponse{
Attributes: &admin.MatchingAttributes{
Target: &admin.MatchingAttributes_ClusterAssignment{
ClusterAssignment: &clusterAssignment,
},
},
}, expected
}

_, err := executionManager.getClusterAssignment(context.TODO(), &admin.ExecutionCreateRequest{
Project: workflowIdentifier.Project,
Domain: workflowIdentifier.Domain,
Spec: &admin.ExecutionSpec{},
})

assert.Equal(t, expected, err)
})
}

Expand Down
2 changes: 2 additions & 0 deletions flyteadmin/pkg/manager/interfaces/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
)

//go:generate mockery -name ResourceInterface -output=../mocks -case=underscore

// ResourceInterface manages project, domain and workflow -specific attributes.
type ResourceInterface interface {
ListAll(ctx context.Context, request *admin.ListMatchableAttributesRequest) (
Expand Down
Loading

0 comments on commit 03986bc

Please sign in to comment.