Skip to content

Commit

Permalink
feat: add image generation service
Browse files Browse the repository at this point in the history
  • Loading branch information
yankeguo committed Jun 26, 2024
1 parent 2dcd82a commit 9f3f54f
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 2 deletions.
18 changes: 16 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ go get -u github.com/yankeguo/zhipu

```go
// this will use environment variables ZHIPUAI_API_KEY
client := zhipu.NewClient()
client, err := zhipu.NewClient()
// or you can specify the API key
client = zhipu.NewClient(zhipu.WithAPIKey("your api key"))
client, err = zhipu.NewClient(zhipu.WithAPIKey("your api key"))
```

### Use the client
Expand Down Expand Up @@ -64,6 +64,20 @@ if err != nil {
}
```

**Embedding**

```go
service := client.EmbeddingService("embedding-v2").SetInput("你好呀")
service.Do(context.Background())
```

**Image Generation**

```go
service := client.ImageGenerationService("cogview-3").SetPrompt("一只可爱的小猫咪")
service.Do(context.Background())
```

> [!NOTE]
>
> More APIs are coming soon.
Expand Down
72 changes: 72 additions & 0 deletions image_generation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package zhipu

import (
"context"

"github.com/go-resty/resty/v2"
)

type ImageGenerationService struct {
client *Client

model string
prompt string
userID string
}

type ImageGenerationResponse struct {
Created int64 `json:"created"`
Data []ImageURL `json:"data"`
}

func (c *Client) ImageGenerationService(model string) *ImageGenerationService {
return &ImageGenerationService{
client: c,
model: model,
}
}

// SetModel sets the model parameter
func (s *ImageGenerationService) SetModel(model string) *ImageGenerationService {
s.model = model
return s
}

// SetPrompt sets the prompt parameter
func (s *ImageGenerationService) SetPrompt(prompt string) *ImageGenerationService {
s.prompt = prompt
return s
}

// SetUserID sets the userID parameter
func (s *ImageGenerationService) SetUserID(userID string) *ImageGenerationService {
s.userID = userID
return s
}

func (s *ImageGenerationService) Do(ctx context.Context) (res ImageGenerationResponse, err error) {
var (
resp *resty.Response
apiError APIError
)

body := M{
"model": s.model,
"prompt": s.prompt,
}

if s.userID != "" {
body["user_id"] = s.userID
}

if resp, err = s.client.R(ctx).SetBody(body).SetResult(&res).SetError(&apiError).Post("images/generations"); err != nil {
return
}

if resp.IsError() {
err = apiError
return
}

return
}
21 changes: 21 additions & 0 deletions image_generation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package zhipu

import (
"context"
"testing"

"github.com/stretchr/testify/require"
)

func TestImageGenerationService(t *testing.T) {
client, err := NewClient()
require.NoError(t, err)

s := client.ImageGenerationService("cogview-3")
s.SetPrompt("一只可爱的小猫")

res, err := s.Do(context.Background())
require.NoError(t, err)
require.NotEmpty(t, res.Data)
t.Log(res.Data[0].URL)
}

0 comments on commit 9f3f54f

Please sign in to comment.