From c9843dc0f536e37baa218e0e0364f1e474d847da Mon Sep 17 00:00:00 2001 From: Alex Guerrieri Date: Fri, 5 Jul 2024 09:10:10 +0200 Subject: [PATCH] Export test server --- common_test.go | 115 ------------------------------------- handler_test.go | 79 ++++++++++++------------- test/test_server.go | 136 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 176 insertions(+), 154 deletions(-) create mode 100644 test/test_server.go diff --git a/common_test.go b/common_test.go index ac0fd6d..7dfa9e6 100644 --- a/common_test.go +++ b/common_test.go @@ -4,10 +4,8 @@ import ( "context" "encoding/json" "log/slog" - "net" "net/http" "net/http/httptest" - "sync" "sync/atomic" "testing" "time" @@ -15,12 +13,10 @@ import ( "github.com/0xsequence/quotacontrol" "github.com/0xsequence/quotacontrol/middleware" "github.com/0xsequence/quotacontrol/proto" - "github.com/alicebob/miniredis/v2" "github.com/go-chi/jwtauth/v5" "github.com/goware/logger" "github.com/goware/cachestore/redis" - redisclient "github.com/redis/go-redis/v9" "github.com/stretchr/testify/require" ) @@ -45,117 +41,6 @@ func newQuotaClient(cfg quotacontrol.Config, service proto.Service) *quotacontro return quotacontrol.NewClient(logger, service, cfg) } -func newTestServer(t *testing.T, cfg *quotacontrol.Config) *testServer { - s := miniredis.NewMiniRedis() - s.Start() - t.Cleanup(s.Close) - cfg.Redis.Host = s.Host() - cfg.Redis.Port = uint16(s.Server().Addr().Port) - client := redisclient.NewClient(&redisclient.Options{Addr: s.Addr()}) - - store := quotacontrol.NewMemoryStore() - - listener, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err) - cfg.URL = "http://" + listener.Addr().String() - - t.Cleanup(func() { require.NoError(t, listener.Close()) }) - - qc := testServer{ - logger: logger.NewLogger(logger.LogLevel_DEBUG), - listener: listener, - cache: client, - store: store, - notifications: make(map[uint64][]proto.EventType), - } - - qcCache := quotacontrol.Cache{ - QuotaCache: quotacontrol.NewRedisCache(client, time.Minute), - UsageCache: quotacontrol.NewRedisCache(client, time.Minute), - PermissionCache: quotacontrol.NewRedisCache(client, time.Minute), - } - qcStore := quotacontrol.Store{ - LimitStore: store, - AccessKeyStore: store, - UsageStore: store, - CycleStore: store, - PermissionStore: store, - } - - logger := qc.logger.With(slog.String("server", "server")) - qc.QuotaControl = quotacontrol.NewHandler(logger, qcCache, qcStore, nil) - - go func() { - http.Serve(listener, proto.NewQuotaControlServer(&qc)) - }() - - return &qc -} - -// testServer is a wrapper of quotacontrol that tracks the events that are notified and allows to inject errors -type testServer struct { - logger logger.Logger - listener net.Listener - cache *redisclient.Client - store *quotacontrol.MemoryStore - - proto.QuotaControl - - sync.Mutex - notifications map[uint64][]proto.EventType - - ErrGetProjectQuota error - ErrGetAccessQuota error - ErrPrepareUsage error - PrepareUsageDelay time.Duration -} - -func (qc *testServer) FlushCache() { - qc.cache.FlushAll(context.Background()) -} - -func (qc *testServer) GetProjectQuota(ctx context.Context, projectID uint64, now time.Time) (*proto.AccessQuota, error) { - if qc.ErrGetProjectQuota != nil { - return nil, qc.ErrGetProjectQuota - } - return qc.QuotaControl.GetProjectQuota(ctx, projectID, now) -} - -func (qc *testServer) GetAccessQuota(ctx context.Context, accessKey string, now time.Time) (*proto.AccessQuota, error) { - if qc.ErrGetAccessQuota != nil { - return nil, qc.ErrGetAccessQuota - } - 
return qc.QuotaControl.GetAccessQuota(ctx, accessKey, now) -} - -func (qc *testServer) PrepareUsage(ctx context.Context, projectID uint64, cycle *proto.Cycle, now time.Time) (bool, error) { - if qc.ErrPrepareUsage != nil { - return false, qc.ErrPrepareUsage - } - if qc.PrepareUsageDelay > 0 { - go func() { - time.Sleep(qc.PrepareUsageDelay) - qc.ClearUsage(ctx, projectID, now) - }() - return true, nil - } - return qc.QuotaControl.PrepareUsage(ctx, projectID, cycle, now) -} - -func (q *testServer) getEvents(projectID uint64) []proto.EventType { - q.Lock() - v := q.notifications[projectID] - q.Unlock() - return v -} - -func (q *testServer) NotifyEvent(ctx context.Context, projectID uint64, eventType proto.EventType) (bool, error) { - q.Lock() - q.notifications[projectID] = append(q.notifications[projectID], eventType) - q.Unlock() - return true, nil -} - type hitCounter int64 func (c *hitCounter) GetValue() int64 { diff --git a/handler_test.go b/handler_test.go index adafd11..efadccf 100644 --- a/handler_test.go +++ b/handler_test.go @@ -11,6 +11,7 @@ import ( . "github.com/0xsequence/quotacontrol" "github.com/0xsequence/quotacontrol/middleware" "github.com/0xsequence/quotacontrol/proto" + "github.com/0xsequence/quotacontrol/test" "github.com/go-chi/chi/v5" "github.com/go-chi/jwtauth/v5" "github.com/stretchr/testify/assert" @@ -23,7 +24,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { auth := jwtauth.New("HS256", []byte("secret"), nil) cfg := newConfig() - server := newTestServer(t, &cfg) + server := test.NewServer(t, &cfg) now := time.Now() project := uint64(7) @@ -39,9 +40,9 @@ func TestMiddlewareUseAccessKey(t *testing.T) { } ctx := context.Background() - err := server.store.SetAccessLimit(ctx, project, &limit) + err := server.Store.SetAccessLimit(ctx, project, &limit) require.NoError(t, err) - err = server.store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: key, ProjectID: project}) + err = server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: key, ProjectID: project}) require.NoError(t, err) client := newQuotaClient(cfg, service) @@ -65,7 +66,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { go client.Run(context.Background()) ctx := middleware.WithTime(context.Background(), now) - server.notifications = make(map[uint64][]proto.EventType) + server.FlushNotifications() // Spend Free CU for i := int64(1); i < limit.FreeWarn; i++ { @@ -75,7 +76,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), headers.Get(middleware.HeaderQuotaLimit)) assert.Equal(t, strconv.FormatInt(limit.FreeMax-i, 10), headers.Get(middleware.HeaderQuotaRemaining)) assert.Equal(t, "", headers.Get(middleware.HeaderQuotaOverage)) - assert.Empty(t, server.getEvents(project), i) + assert.Empty(t, server.GetEvents(project), i) expectedUsage.Add(proto.AccessUsage{ValidCompute: 1}) } @@ -86,7 +87,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), headers.Get(middleware.HeaderQuotaLimit)) assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaRemaining)) assert.Equal(t, "", headers.Get(middleware.HeaderQuotaOverage)) - assert.Contains(t, server.getEvents(project), proto.EventType_FreeMax) + assert.Contains(t, server.GetEvents(project), proto.EventType_FreeMax) expectedUsage.Add(proto.AccessUsage{ValidCompute: 1}) // Get close to soft quota @@ -97,7 +98,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), 
headers.Get(middleware.HeaderQuotaLimit)) assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaRemaining)) assert.Equal(t, strconv.FormatInt(i-limit.FreeWarn, 10), headers.Get(middleware.HeaderQuotaOverage)) - assert.Len(t, server.getEvents(project), 1) + assert.Len(t, server.GetEvents(project), 1) expectedUsage.Add(proto.AccessUsage{OverCompute: 1}) } @@ -108,7 +109,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), headers.Get(middleware.HeaderQuotaLimit)) assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaRemaining)) assert.Equal(t, strconv.FormatInt(limit.OverWarn-limit.FreeWarn, 10), headers.Get(middleware.HeaderQuotaOverage)) - assert.Contains(t, server.getEvents(project), proto.EventType_OverWarn) + assert.Contains(t, server.GetEvents(project), proto.EventType_OverWarn) expectedUsage.Add(proto.AccessUsage{OverCompute: 1}) // Get close to hard quota @@ -119,7 +120,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), headers.Get(middleware.HeaderQuotaLimit)) assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaRemaining)) assert.Equal(t, strconv.FormatInt(i-limit.FreeWarn, 10), headers.Get(middleware.HeaderQuotaOverage)) - assert.Len(t, server.getEvents(project), 2) + assert.Len(t, server.GetEvents(project), 2) expectedUsage.Add(proto.AccessUsage{OverCompute: 1}) } @@ -130,7 +131,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), headers.Get(middleware.HeaderQuotaLimit)) assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaRemaining)) assert.Equal(t, strconv.FormatInt(limit.OverMax-limit.FreeWarn, 10), headers.Get(middleware.HeaderQuotaOverage)) - assert.Contains(t, server.getEvents(project), proto.EventType_OverMax) + assert.Contains(t, server.GetEvents(project), proto.EventType_OverMax) expectedUsage.Add(proto.AccessUsage{OverCompute: 1}) // Denied @@ -146,7 +147,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { // check the usage client.Stop(context.Background()) - usage, err := server.store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour)) + usage, err := server.Store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour)) assert.NoError(t, err) assert.Equal(t, int64(expectedUsage.GetTotalUsage()), counter.GetValue()) assert.Equal(t, &expectedUsage, &usage) @@ -154,7 +155,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { t.Run("ChangeLimits", func(t *testing.T) { // Increase CreditsOverageLimit which should still allow requests to go through, etc. 
- err = server.store.SetAccessLimit(ctx, project, &proto.Limit{ + err = server.Store.SetAccessLimit(ctx, project, &proto.Limit{ RateLimit: 100, OverWarn: 5, OverMax: 110, @@ -166,7 +167,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { go client.Run(context.Background()) ctx := middleware.WithTime(context.Background(), now) - server.notifications = make(map[uint64][]proto.EventType) + server.FlushNotifications() ok, headers, err := executeRequest(ctx, r, "", key, "") assert.NoError(t, err) @@ -174,7 +175,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { assert.Equal(t, "0", headers.Get(middleware.HeaderQuotaLimit)) client.Stop(context.Background()) - usage, err := server.store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour)) + usage, err := server.Store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour)) assert.NoError(t, err) expectedUsage.Add(proto.AccessUsage{ValidCompute: 0, OverCompute: 1, LimitedCompute: 0}) assert.Equal(t, int64(expectedUsage.GetTotalUsage()), counter.GetValue()) @@ -200,14 +201,14 @@ func TestMiddlewareUseAccessKey(t *testing.T) { } client.Stop(context.Background()) - usage, err := server.store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour)) + usage, err := server.Store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour)) assert.NoError(t, err) assert.Equal(t, int64(expectedUsage.GetTotalUsage()), counter.GetValue()) assert.Equal(t, &expectedUsage, &usage) }) t.Run("ServerErrors", func(t *testing.T) { - server.FlushCache() + server.FlushCache(ctx) go client.Run(context.Background()) @@ -228,7 +229,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { } server.ErrGetAccessQuota = nil - server.FlushCache() + server.FlushCache(ctx) for _, err := range errList { server.ErrPrepareUsage = err @@ -240,14 +241,14 @@ func TestMiddlewareUseAccessKey(t *testing.T) { server.ErrPrepareUsage = nil client.Stop(context.Background()) - usage, err := server.store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour)) + usage, err := server.Store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour)) assert.NoError(t, err) assert.Equal(t, int64(expectedUsage.GetTotalUsage()), counter.GetValue()) assert.Equal(t, &expectedUsage, &usage) }) t.Run("ServerTimeout", func(t *testing.T) { - server.FlushCache() + server.FlushCache(ctx) go client.Run(context.Background()) @@ -260,7 +261,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { assert.NoError(t, err) client.Stop(context.Background()) - usage, err := server.store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour)) + usage, err := server.Store.GetAccountUsage(ctx, project, proto.Ptr(service), now.Add(-time.Hour), now.Add(time.Hour)) assert.NoError(t, err) assert.Equal(t, int64(expectedUsage.GetTotalUsage()), counter.GetValue()) assert.Equal(t, &expectedUsage, &usage) @@ -269,7 +270,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { func TestDefaultKey(t *testing.T) { cfg := newConfig() - server := newTestServer(t, &cfg) + server := test.NewServer(t, &cfg) now := time.Now() project := uint64(7) @@ -294,9 +295,9 @@ func TestDefaultKey(t *testing.T) { // populate store ctx := context.Background() - err := server.store.SetAccessLimit(ctx, project, &limit) + err := server.Store.SetAccessLimit(ctx, project, &limit) require.NoError(t, err) - err = 
server.store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: keys[0], ProjectID: project}) + err = server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: keys[0], ProjectID: project}) require.NoError(t, err) client := newQuotaClient(cfg, *service) @@ -323,7 +324,7 @@ func TestDefaultKey(t *testing.T) { require.ErrorIs(t, err, proto.ErrAtLeastOneKey) assert.False(t, ok) newAccess := proto.AccessKey{Active: true, AccessKey: keys[1], ProjectID: project} - err = server.store.InsertAccessKey(ctx, &newAccess) + err = server.Store.InsertAccessKey(ctx, &newAccess) require.NoError(t, err) ok, err = server.DisableAccessKey(ctx, keys[0]) @@ -350,7 +351,7 @@ func TestJWT(t *testing.T) { counter := spendingCounter(0) cfg := newConfig() - server := newTestServer(t, &cfg) + server := test.NewServer(t, &cfg) client := newQuotaClient(cfg, service) r := chi.NewRouter() @@ -371,7 +372,7 @@ func TestJWT(t *testing.T) { OverWarn: 7, OverMax: 10, } - server.store.SetAccessLimit(ctx, project, &limit) + server.Store.SetAccessLimit(ctx, project, &limit) token := mustJWT(t, auth, middleware.Claims{"project": project, "account": account}) @@ -383,7 +384,7 @@ func TestJWT(t *testing.T) { assert.False(t, ok) assert.Equal(t, "", headers.Get(middleware.HeaderQuotaLimit)) }) - server.store.SetUserPermission(ctx, project, account, proto.UserPermission_READ_WRITE, proto.ResourceAccess{ProjectID: project}) + server.Store.SetUserPermission(ctx, project, account, proto.UserPermission_READ_WRITE, proto.ResourceAccess{ProjectID: project}) t.Run("AuthorizedUser", func(t *testing.T) { ok, headers, err := executeRequest(ctx, r, "", "", token) require.NoError(t, err) @@ -397,7 +398,7 @@ func TestJWT(t *testing.T) { assert.False(t, ok) assert.Equal(t, "", headers.Get(middleware.HeaderQuotaLimit)) }) - server.store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: key, ProjectID: project}) + server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: key, ProjectID: project}) t.Run("AccessKeyFound", func(t *testing.T) { ok, _, err := executeRequest(ctx, r, "", key, token) require.NoError(t, err) @@ -424,7 +425,7 @@ func TestJWTAccess(t *testing.T) { counter := hitCounter(0) cfg := newConfig() - server := newTestServer(t, &cfg) + server := test.NewServer(t, &cfg) client := newQuotaClient(cfg, service) r := chi.NewRouter() @@ -443,7 +444,7 @@ func TestJWTAccess(t *testing.T) { OverWarn: 7, OverMax: 10, } - server.store.SetAccessLimit(ctx, project, &limit) + server.Store.SetAccessLimit(ctx, project, &limit) token := mustJWT(t, auth, middleware.Claims{"account": account, "project": project}) @@ -456,7 +457,7 @@ func TestJWTAccess(t *testing.T) { assert.Equal(t, "", headers.Get(middleware.HeaderQuotaLimit)) }) - server.store.SetUserPermission(ctx, project, account, proto.UserPermission_READ, proto.ResourceAccess{ProjectID: project}) + server.Store.SetUserPermission(ctx, project, account, proto.UserPermission_READ, proto.ResourceAccess{ProjectID: project}) t.Run("LowPermission", func(t *testing.T) { ok, headers, err := executeRequest(ctx, r, "", "", token) require.ErrorIs(t, err, proto.ErrUnauthorizedUser) @@ -465,8 +466,8 @@ func TestJWTAccess(t *testing.T) { assert.Equal(t, strconv.FormatInt(limit.RateLimit, 10), headers.Get(RateLimitHeader)) }) - server.store.SetUserPermission(ctx, project, account, proto.UserPermission_READ_WRITE, proto.ResourceAccess{ProjectID: project}) - server.FlushCache() + server.Store.SetUserPermission(ctx, project, account, 
proto.UserPermission_READ_WRITE, proto.ResourceAccess{ProjectID: project})
+	server.FlushCache(ctx)
 	t.Run("EnoughPermission", func(t *testing.T) {
 		ok, headers, err := executeRequest(ctx, r, "", "", token)
 		require.NoError(t, err)
@@ -476,8 +477,8 @@ func TestJWTAccess(t *testing.T) {
 		expectedHits++
 	})
 
-	server.store.SetUserPermission(ctx, project, account, proto.UserPermission_ADMIN, proto.ResourceAccess{ProjectID: project})
-	server.FlushCache()
+	server.Store.SetUserPermission(ctx, project, account, proto.UserPermission_ADMIN, proto.ResourceAccess{ProjectID: project})
+	server.FlushCache(ctx)
 	t.Run("MorePermission", func(t *testing.T) {
 		ok, headers, err := executeRequest(ctx, r, "", "", token)
 		require.NoError(t, err)
@@ -502,7 +503,7 @@ func TestSession(t *testing.T) {
 	counter := hitCounter(0)
 
 	cfg := newConfig()
-	server := newTestServer(t, &cfg)
+	server := test.NewServer(t, &cfg)
 	client := newQuotaClient(cfg, service)
 
 	const (
@@ -541,9 +542,9 @@ func TestSession(t *testing.T) {
 		OverWarn:  7,
 		OverMax:   10,
 	}
-	server.store.SetAccessLimit(ctx, project, &limit)
-	server.store.SetUserPermission(ctx, project, address, proto.UserPermission_READ, proto.ResourceAccess{ProjectID: project})
-	server.store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: key, ProjectID: project})
+	server.Store.SetAccessLimit(ctx, project, &limit)
+	server.Store.SetUserPermission(ctx, project, address, proto.UserPermission_READ, proto.ResourceAccess{ProjectID: project})
+	server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: key, ProjectID: project})
 
 	testCases := []struct {
 		AccessKey string
diff --git a/test/test_server.go b/test/test_server.go
new file mode 100644
index 0000000..080ebed
--- /dev/null
+++ b/test/test_server.go
@@ -0,0 +1,136 @@
+package test
+
+import (
+	"context"
+	"log/slog"
+	"net"
+	"net/http"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/0xsequence/quotacontrol"
+	"github.com/0xsequence/quotacontrol/proto"
+	"github.com/alicebob/miniredis/v2"
+	"github.com/goware/logger"
+	redisclient "github.com/redis/go-redis/v9"
+	"github.com/stretchr/testify/require"
+)
+
+func NewServer(t *testing.T, cfg *quotacontrol.Config) *TestServer {
+	s := miniredis.NewMiniRedis()
+	s.Start()
+	t.Cleanup(s.Close)
+	cfg.Redis.Host = s.Host()
+	cfg.Redis.Port = uint16(s.Server().Addr().Port)
+	client := redisclient.NewClient(&redisclient.Options{Addr: s.Addr()})
+
+	store := quotacontrol.NewMemoryStore()
+
+	listener, err := net.Listen("tcp", "localhost:0")
+	require.NoError(t, err)
+	cfg.URL = "http://" + listener.Addr().String()
+
+	t.Cleanup(func() { require.NoError(t, listener.Close()) })
+
+	qc := TestServer{
+		logger:        logger.NewLogger(logger.LogLevel_DEBUG),
+		listener:      listener,
+		cache:         client,
+		Store:         store,
+		notifications: make(map[uint64][]proto.EventType),
+	}
+
+	qcCache := quotacontrol.Cache{
+		QuotaCache:      quotacontrol.NewRedisCache(client, time.Minute),
+		UsageCache:      quotacontrol.NewRedisCache(client, time.Minute),
+		PermissionCache: quotacontrol.NewRedisCache(client, time.Minute),
+	}
+	qcStore := quotacontrol.Store{
+		LimitStore:      store,
+		AccessKeyStore:  store,
+		UsageStore:      store,
+		CycleStore:      store,
+		PermissionStore: store,
+	}
+
+	logger := qc.logger.With(slog.String("server", "server"))
+	qc.QuotaControl = quotacontrol.NewHandler(logger, qcCache, qcStore, nil)
+
+	go func() {
+		http.Serve(listener, proto.NewQuotaControlServer(&qc))
+	}()
+
+	return &qc
+}
+
+// TestServer wraps quotacontrol; it records notified events and lets tests inject errors.
+type TestServer struct {
+	logger   logger.Logger
+	listener net.Listener
+	cache    *redisclient.Client
+
+	Store *quotacontrol.MemoryStore
+
+	proto.QuotaControl
+
+	mu            sync.Mutex
+	notifications map[uint64][]proto.EventType
+
+	ErrGetProjectQuota error
+	ErrGetAccessQuota  error
+	ErrPrepareUsage    error
+	PrepareUsageDelay  time.Duration
+}
+
+func (qc *TestServer) FlushNotifications() {
+	qc.mu.Lock()
+	qc.notifications = make(map[uint64][]proto.EventType)
+	qc.mu.Unlock()
+}
+
+func (qc *TestServer) FlushCache(ctx context.Context) {
+	qc.cache.FlushAll(ctx)
+}
+
+func (qc *TestServer) GetProjectQuota(ctx context.Context, projectID uint64, now time.Time) (*proto.AccessQuota, error) {
+	if qc.ErrGetProjectQuota != nil {
+		return nil, qc.ErrGetProjectQuota
+	}
+	return qc.QuotaControl.GetProjectQuota(ctx, projectID, now)
+}
+
+func (qc *TestServer) GetAccessQuota(ctx context.Context, accessKey string, now time.Time) (*proto.AccessQuota, error) {
+	if qc.ErrGetAccessQuota != nil {
+		return nil, qc.ErrGetAccessQuota
+	}
+	return qc.QuotaControl.GetAccessQuota(ctx, accessKey, now)
+}
+
+func (qc *TestServer) PrepareUsage(ctx context.Context, projectID uint64, cycle *proto.Cycle, now time.Time) (bool, error) {
+	if qc.ErrPrepareUsage != nil {
+		return false, qc.ErrPrepareUsage
+	}
+	if qc.PrepareUsageDelay > 0 {
+		go func() {
+			time.Sleep(qc.PrepareUsageDelay)
+			qc.ClearUsage(ctx, projectID, now)
+		}()
+		return true, nil
+	}
+	return qc.QuotaControl.PrepareUsage(ctx, projectID, cycle, now)
+}
+
+func (qc *TestServer) GetEvents(projectID uint64) []proto.EventType {
+	qc.mu.Lock()
+	v := qc.notifications[projectID]
+	qc.mu.Unlock()
+	return v
+}
+
+func (qc *TestServer) NotifyEvent(ctx context.Context, projectID uint64, eventType proto.EventType) (bool, error) {
+	qc.mu.Lock()
+	qc.notifications[projectID] = append(qc.notifications[projectID], eventType)
+	qc.mu.Unlock()
+	return true, nil
+}
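
Usage sketch: with the server exported, downstream packages can boot it directly instead of copying the helper. This is a minimal sketch assuming a zero-value `quotacontrol.Config` is a workable starting point (`NewServer` overwrites `cfg.Redis` and `cfg.URL` itself); the project ID, access key, limit values, and injected error are illustrative only:

```go
package example_test

import (
	"context"
	"errors"
	"testing"

	"github.com/0xsequence/quotacontrol"
	"github.com/0xsequence/quotacontrol/proto"
	"github.com/0xsequence/quotacontrol/test"
	"github.com/stretchr/testify/require"
)

func TestWithQuotaControl(t *testing.T) {
	// NewServer starts miniredis plus an HTTP listener and points cfg at
	// them; cleanup is registered via t.Cleanup inside NewServer.
	cfg := quotacontrol.Config{} // assumption: remaining defaults are fine for this sketch
	server := test.NewServer(t, &cfg)

	ctx := context.Background()
	project := uint64(7) // illustrative project ID

	// Seed limits and an access key through the exported MemoryStore.
	limit := proto.Limit{RateLimit: 100, FreeWarn: 5, FreeMax: 5, OverWarn: 7, OverMax: 10}
	require.NoError(t, server.Store.SetAccessLimit(ctx, project, &limit))
	require.NoError(t, server.Store.InsertAccessKey(ctx, &proto.AccessKey{
		Active: true, AccessKey: "abc", ProjectID: project,
	}))

	// ... run the code under test against cfg.URL ...

	// Assert on notified events, then reset server state between sub-tests.
	require.Empty(t, server.GetEvents(project))
	server.FlushNotifications()
	server.FlushCache(ctx)

	// Failure injection: PrepareUsage fails until the field is cleared.
	server.ErrPrepareUsage = errors.New("injected failure") // illustrative error
	defer func() { server.ErrPrepareUsage = nil }()
}
```

Resetting an `Err*` field to `nil` restores passthrough to the embedded `proto.QuotaControl`, which is how the `ServerErrors` and `ServerTimeout` sub-tests in handler_test.go sequence failure and recovery.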