Skip to content

Commit

Permalink
Export test server
Browse files Browse the repository at this point in the history
  • Loading branch information
klaidliadon committed Jul 5, 2024
1 parent f39c6fd commit c9843dc
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 154 deletions.
115 changes: 0 additions & 115 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,19 @@ import (
"context"
"encoding/json"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/0xsequence/quotacontrol"
"github.com/0xsequence/quotacontrol/middleware"
"github.com/0xsequence/quotacontrol/proto"
"github.com/alicebob/miniredis/v2"
"github.com/go-chi/jwtauth/v5"
"github.com/goware/logger"

"github.com/goware/cachestore/redis"
redisclient "github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)

Expand All @@ -45,117 +41,6 @@ func newQuotaClient(cfg quotacontrol.Config, service proto.Service) *quotacontro
return quotacontrol.NewClient(logger, service, cfg)
}

func newTestServer(t *testing.T, cfg *quotacontrol.Config) *testServer {
s := miniredis.NewMiniRedis()
s.Start()
t.Cleanup(s.Close)
cfg.Redis.Host = s.Host()
cfg.Redis.Port = uint16(s.Server().Addr().Port)
client := redisclient.NewClient(&redisclient.Options{Addr: s.Addr()})

store := quotacontrol.NewMemoryStore()

listener, err := net.Listen("tcp", "localhost:0")
require.NoError(t, err)
cfg.URL = "http://" + listener.Addr().String()

t.Cleanup(func() { require.NoError(t, listener.Close()) })

qc := testServer{
logger: logger.NewLogger(logger.LogLevel_DEBUG),
listener: listener,
cache: client,
store: store,
notifications: make(map[uint64][]proto.EventType),
}

qcCache := quotacontrol.Cache{
QuotaCache: quotacontrol.NewRedisCache(client, time.Minute),
UsageCache: quotacontrol.NewRedisCache(client, time.Minute),
PermissionCache: quotacontrol.NewRedisCache(client, time.Minute),
}
qcStore := quotacontrol.Store{
LimitStore: store,
AccessKeyStore: store,
UsageStore: store,
CycleStore: store,
PermissionStore: store,
}

logger := qc.logger.With(slog.String("server", "server"))
qc.QuotaControl = quotacontrol.NewHandler(logger, qcCache, qcStore, nil)

go func() {
http.Serve(listener, proto.NewQuotaControlServer(&qc))
}()

return &qc
}

// testServer is a wrapper of quotacontrol that tracks the events that are notified and allows to inject errors
type testServer struct {
logger logger.Logger
listener net.Listener
cache *redisclient.Client
store *quotacontrol.MemoryStore

proto.QuotaControl

sync.Mutex
notifications map[uint64][]proto.EventType

ErrGetProjectQuota error
ErrGetAccessQuota error
ErrPrepareUsage error
PrepareUsageDelay time.Duration
}

func (qc *testServer) FlushCache() {
qc.cache.FlushAll(context.Background())
}

func (qc *testServer) GetProjectQuota(ctx context.Context, projectID uint64, now time.Time) (*proto.AccessQuota, error) {
if qc.ErrGetProjectQuota != nil {
return nil, qc.ErrGetProjectQuota
}
return qc.QuotaControl.GetProjectQuota(ctx, projectID, now)
}

func (qc *testServer) GetAccessQuota(ctx context.Context, accessKey string, now time.Time) (*proto.AccessQuota, error) {
if qc.ErrGetAccessQuota != nil {
return nil, qc.ErrGetAccessQuota
}
return qc.QuotaControl.GetAccessQuota(ctx, accessKey, now)
}

func (qc *testServer) PrepareUsage(ctx context.Context, projectID uint64, cycle *proto.Cycle, now time.Time) (bool, error) {
if qc.ErrPrepareUsage != nil {
return false, qc.ErrPrepareUsage
}
if qc.PrepareUsageDelay > 0 {
go func() {
time.Sleep(qc.PrepareUsageDelay)
qc.ClearUsage(ctx, projectID, now)
}()
return true, nil
}
return qc.QuotaControl.PrepareUsage(ctx, projectID, cycle, now)
}

func (q *testServer) getEvents(projectID uint64) []proto.EventType {
q.Lock()
v := q.notifications[projectID]
q.Unlock()
return v
}

func (q *testServer) NotifyEvent(ctx context.Context, projectID uint64, eventType proto.EventType) (bool, error) {
q.Lock()
q.notifications[projectID] = append(q.notifications[projectID], eventType)
q.Unlock()
return true, nil
}

type hitCounter int64

func (c *hitCounter) GetValue() int64 {
Expand Down
Loading

0 comments on commit c9843dc

Please sign in to comment.