Skip to content

Commit

Permalink
refactor!: update batch service
Browse files Browse the repository at this point in the history
  • Loading branch information
yankeguo committed Jun 28, 2024
1 parent 19dd77f commit b92d438
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 22 deletions.
75 changes: 56 additions & 19 deletions batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ const (
BatchCompletionWindow24h = "24h"
)

// BatchRequestCounts represents the counts of the batch requests.
type BatchRequestCounts struct {
Total int64 `json:"total"`
Completed int64 `json:"completed"`
Failed int64 `json:"failed"`
}

// BatchItem represents a batch item.
type BatchItem struct {
ID string `json:"id"`
Object any `json:"object"`
Expand All @@ -44,6 +46,7 @@ type BatchItem struct {
Metadata json.RawMessage `json:"metadata"`
}

// BatchCreateService is a service to create a batch.
type BatchCreateService struct {
client *Client

Expand All @@ -53,42 +56,52 @@ type BatchCreateService struct {
metadata any
}

func (c *Client) BatchCreateService() *BatchCreateService {
return &BatchCreateService{client: c}
// NewBatchCreateService creates a new BatchCreateService.
func NewBatchCreateService(client *Client) *BatchCreateService {
return &BatchCreateService{client: client}
}

// SetInputFileID sets the input file id for the batch.
func (s *BatchCreateService) SetInputFileID(inputFileID string) *BatchCreateService {
s.inputFileID = inputFileID
return s
}

// SetEndpoint sets the endpoint for the batch.
func (s *BatchCreateService) SetEndpoint(endpoint string) *BatchCreateService {
s.endpoint = endpoint
return s
}

// SetCompletionWindow sets the completion window for the batch.
func (s *BatchCreateService) SetCompletionWindow(window string) *BatchCreateService {
s.completionWindow = window
return s
}

// SetMetadata sets the metadata for the batch.
func (s *BatchCreateService) SetMetadata(metadata any) *BatchCreateService {
s.metadata = metadata
return s
}

// Do executes the batch create service.
func (s *BatchCreateService) Do(ctx context.Context) (res BatchItem, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)

if resp, err = s.client.request(ctx).SetBody(M{
"input_file_id": s.inputFileID,
"endpoint": s.endpoint,
"completion_window": s.completionWindow,
"metadata": s.metadata,
}).SetResult(&res).SetError(&apiError).Post("batches"); err != nil {
if resp, err = s.client.request(ctx).
SetBody(M{
"input_file_id": s.inputFileID,
"endpoint": s.endpoint,
"completion_window": s.completionWindow,
"metadata": s.metadata,
}).
SetResult(&res).
SetError(&apiError).
Post("batches"); err != nil {
return
}

Expand All @@ -99,28 +112,37 @@ func (s *BatchCreateService) Do(ctx context.Context) (res BatchItem, err error)
return
}

// BatchGetService is a service to get a batch.
type BatchGetService struct {
client *Client
batchID string
}

func (c *Client) BatchGetService(batchID string) *BatchGetService {
return &BatchGetService{client: c, batchID: batchID}
// BatchGetResponse represents the response of the batch get service.
type BatchGetResponse = BatchItem

// NewBatchGetService creates a new BatchGetService.
func NewBatchGetService(client *Client) *BatchGetService {
return &BatchGetService{client: client}
}

// SetBatchID sets the batch id for the batch get service.
func (s *BatchGetService) SetBatchID(batchID string) *BatchGetService {
s.batchID = batchID
return s
}

func (s *BatchGetService) Do(ctx context.Context) (res BatchItem, err error) {
// Do executes the batch get service.
func (s *BatchGetService) Do(ctx context.Context) (res BatchGetResponse, err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)

if resp, err = s.client.request(ctx).
SetPathParam("batch_id", s.batchID).SetResult(&res).SetError(&apiError).
SetPathParam("batch_id", s.batchID).
SetResult(&res).
SetError(&apiError).
Get("batches/{batch_id}"); err != nil {
return
}
Expand All @@ -132,28 +154,34 @@ func (s *BatchGetService) Do(ctx context.Context) (res BatchItem, err error) {
return
}

// BatchCancelService is a service to cancel a batch.
type BatchCancelService struct {
client *Client
batchID string
}

func (c *Client) BatchCancelService(batchID string) *BatchCancelService {
return &BatchCancelService{client: c, batchID: batchID}
// NewBatchCancelService creates a new BatchCancelService.
func NewBatchCancelService(client *Client) *BatchCancelService {
return &BatchCancelService{client: client}
}

// SetBatchID sets the batch id for the batch cancel service.
func (s *BatchCancelService) SetBatchID(batchID string) *BatchCancelService {
s.batchID = batchID
return s
}

// Do executes the batch cancel service.
func (s *BatchCancelService) Do(ctx context.Context) (err error) {
var (
resp *resty.Response
apiError APIErrorResponse
)

if resp, err = s.client.request(ctx).SetBody(M{}).
SetPathParam("batch_id", s.batchID).SetError(&apiError).
if resp, err = s.client.request(ctx).
SetPathParam("batch_id", s.batchID).
SetBody(M{}).
SetError(&apiError).
Post("batches/{batch_id}/cancel"); err != nil {
return
}
Expand All @@ -165,13 +193,15 @@ func (s *BatchCancelService) Do(ctx context.Context) (err error) {
return
}

// BatchListService is a service to list batches.
type BatchListService struct {
client *Client

after *string
limit *int
}

// BatchListResponse represents the response of the batch list service.
type BatchListResponse struct {
Object string `json:"object"`
Data []BatchItem `json:"data"`
Expand All @@ -180,20 +210,24 @@ type BatchListResponse struct {
HasMore bool `json:"has_more"`
}

func (c *Client) BatchListService() *BatchListService {
return &BatchListService{client: c}
// NewBatchListService creates a new BatchListService.
func NewBatchListService(client *Client) *BatchListService {
return &BatchListService{client: client}
}

// SetAfter sets the after cursor for the batch list service.
func (s *BatchListService) SetAfter(after string) *BatchListService {
s.after = &after
return s
}

// SetLimit sets the limit for the batch list service.
func (s *BatchListService) SetLimit(limit int) *BatchListService {
s.limit = &limit
return s
}

// Do executes the batch list service.
func (s *BatchListService) Do(ctx context.Context) (res BatchListResponse, err error) {
var (
resp *resty.Response
Expand All @@ -208,7 +242,10 @@ func (s *BatchListService) Do(ctx context.Context) (res BatchListResponse, err e
req.SetQueryParam("limit", strconv.Itoa(*s.limit))
}

if resp, err = req.SetResult(&res).SetError(&apiError).Get("batches"); err != nil {
if resp, err = req.
SetResult(&res).
SetError(&apiError).
Get("batches"); err != nil {
return
}

Expand Down
9 changes: 9 additions & 0 deletions batch_support.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,25 @@ import (
"io"
)

// BatchSupport is the interface for services with batch support.
type BatchSupport interface {
BatchMethod() string
BatchURL() string
BatchBody() any
}

// BatchFileWriter is a writer for batch files.
type BatchFileWriter struct {
w io.Writer
je *json.Encoder
}

// NewBatchFileWriter creates a new BatchFileWriter.
func NewBatchFileWriter(w io.Writer) *BatchFileWriter {
return &BatchFileWriter{w: w, je: json.NewEncoder(w)}
}

// Write writes a batch file.
func (b *BatchFileWriter) Write(customID string, s BatchSupport) error {
return b.je.Encode(M{
"custom_id": customID,
Expand All @@ -29,26 +33,31 @@ func (b *BatchFileWriter) Write(customID string, s BatchSupport) error {
})
}

// BatchResultResponse is the response of a batch result.
type BatchResultResponse[T any] struct {
StatusCode int `json:"status_code"`
Body T `json:"body"`
}

// BatchResult is the result of a batch.
type BatchResult[T any] struct {
ID string `json:"id"`
CustomID string `json:"custom_id"`
Response BatchResultResponse[T] `json:"response"`
}

// BatchResultReader reads batch results.
type BatchResultReader[T any] struct {
r io.Reader
jd *json.Decoder
}

// NewBatchResultReader creates a new BatchResultReader.
func NewBatchResultReader[T any](r io.Reader) *BatchResultReader[T] {
return &BatchResultReader[T]{r: r, jd: json.NewDecoder(r)}
}

// Read reads a batch result.
func (r *BatchResultReader[T]) Read(out *BatchResult[T]) error {
return r.jd.Decode(out)
}
6 changes: 3 additions & 3 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,22 @@ func TestBatchServiceAll(t *testing.T) {
fileID := res.FileCreateFineTuneResponse.ID
require.NotEmpty(t, fileID)

res1, err := client.BatchCreateService().
res1, err := client.BatchCreate().
SetInputFileID(fileID).
SetCompletionWindow(BatchCompletionWindow24h).
SetEndpoint(BatchEndpointV4ChatCompletions).Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res1.ID)

res2, err := client.BatchGetService(res1.ID).Do(context.Background())
res2, err := client.BatchGet(res1.ID).Do(context.Background())
require.NoError(t, err)
require.Equal(t, res2.ID, res1.ID)

res3, err := client.BatchListService().Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res3.Data)

err = client.BatchCancelService(res1.ID).Do(context.Background())
err = client.BatchCancel(res1.ID).Do(context.Background())
require.NoError(t, err)
}

Expand Down
20 changes: 20 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,26 @@ func NewClient(optFns ...ClientOption) (client *Client, err error) {
return
}

// BatchCreate creates a new BatchCreateService.
func (c *Client) BatchCreate() *BatchCreateService {
return NewBatchCreateService(c)
}

// BatchGet creates a new BatchGetService.
func (c *Client) BatchGet(batchID string) *BatchGetService {
return NewBatchGetService(c).SetBatchID(batchID)
}

// BatchCancel creates a new BatchCancelService.
func (c *Client) BatchCancel(batchID string) *BatchCancelService {
return NewBatchCancelService(c).SetBatchID(batchID)
}

// BatchList creates a new BatchListService.
func (c *Client) BatchListService() *BatchListService {
return NewBatchListService(c)
}

// ChatCompletion creates a new ChatCompletionService.
func (c *Client) ChatCompletion(model string) *ChatCompletionService {
return NewChatCompletionService(c).SetModel(model)
Expand Down

0 comments on commit b92d438

Please sign in to comment.