Skip to content

Commit

Permalink
Update Audio APIs from updated spec (#202)
Browse files Browse the repository at this point in the history
* Implement CreateTranscriptRequest::response_granularities

This PR adds support for `AudioResponseFormat::VerboseJson` and
`TimestampGranularity`, including updated example code. These were
defined as types before, but not fully implemented.

Implements #201.

* Modify transcription API to be more like spec

- Rename `CreateTranscriptionResponse` to `CreateTranscriptionResponseJson` (to match API spec)
- Add `CreateTranscriptionResponseVerboseJson` and `transcribe_verbose_json`
- Add `transcribe_raw` for SRT output
- Add `post_form_raw`
- Update example code
  • Loading branch information
emk authored Mar 24, 2024
1 parent e4a428f commit db4c213
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 5 deletions.
26 changes: 24 additions & 2 deletions async-openai/src/audio.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use bytes::Bytes;

use crate::{
config::Config,
error::OpenAIError,
types::{
CreateSpeechRequest, CreateSpeechResponse, CreateTranscriptionRequest,
CreateTranscriptionResponse, CreateTranslationRequest, CreateTranslationResponse,
CreateTranscriptionResponseJson, CreateTranscriptionResponseVerboseJson, CreateTranslationRequest, CreateTranslationResponse,
},
Client,
};
Expand All @@ -23,12 +25,32 @@ impl<'c, C: Config> Audio<'c, C> {
pub async fn transcribe(
&self,
request: CreateTranscriptionRequest,
) -> Result<CreateTranscriptionResponse, OpenAIError> {
) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
self.client
.post_form("/audio/transcriptions", request)
.await
}

/// Transcribes audio into the input language, returning the `verbose_json`
/// response format (which can include per-word and per-segment timestamps
/// when the corresponding granularities were requested).
pub async fn transcribe_verbose_json(
    &self,
    request: CreateTranscriptionRequest,
) -> Result<CreateTranscriptionResponseVerboseJson, OpenAIError> {
    let path = "/audio/transcriptions";
    self.client.post_form(path, request).await
}

/// Transcribes audio into the input language and returns the raw response
/// body, for non-JSON response formats (e.g. SRT or VTT subtitles).
pub async fn transcribe_raw(
    &self,
    request: CreateTranscriptionRequest,
) -> Result<Bytes, OpenAIError> {
    let path = "/audio/transcriptions";
    self.client.post_form_raw(path, request).await
}

/// Translates audio into English.
pub async fn translate(
&self,
Expand Down
19 changes: 19 additions & 0 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,25 @@ impl<C: Config> Client<C> {
self.execute(request_maker).await
}

/// POST a multipart form at {path} and return the raw response body as bytes.
///
/// Used for endpoints whose response is not JSON (e.g. SRT subtitle output),
/// where deserialization via `post_form` would be inappropriate.
pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
where
    reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
    F: Clone,
{
    // The request maker is a closure so `execute_raw` can rebuild the
    // request if it needs to retry; hence the `form.clone()` per attempt.
    let request_maker = || async {
        let multipart = async_convert::TryFrom::try_from(form.clone()).await?;
        let request = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers())
            .multipart(multipart)
            .build()?;
        Ok(request)
    };

    self.execute_raw(request_maker).await
}

/// POST a form at {path} and deserialize the response body
pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
where
Expand Down
75 changes: 74 additions & 1 deletion async-openai/src/types/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,84 @@ pub struct CreateTranscriptionRequest {
pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}

/// Represents a transcription response returned by model, based on the provided
/// input.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CreateTranscriptionResponse {
pub struct CreateTranscriptionResponseJson {
/// The transcribed text.
pub text: String,
}

/// Represents a verbose json transcription response returned by model, based on
/// the provided input.
///
/// Produced by the `verbose_json` response format; see
/// `Audio::transcribe_verbose_json`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CreateTranscriptionResponseVerboseJson {
    /// The language of the input audio.
    pub language: String,

    /// The duration of the input audio.
    // NOTE(review): presumably in seconds, per the OpenAI API spec — confirm.
    pub duration: f32,

    /// The transcribed text.
    pub text: String,

    /// Extracted words and their corresponding timestamps.
    /// Present only when the `word` timestamp granularity was requested
    /// (hence the `Option`, omitted from serialization when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<TranscriptionWord>>,

    /// Segments of the transcribed text and their corresponding details.
    /// Also optional; omitted from serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub segments: Option<Vec<TranscriptionSegment>>,
}

/// A single transcribed word together with its timing information.
///
/// Returned in `CreateTranscriptionResponseVerboseJson::words` when word-level
/// timestamp granularity is requested.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct TranscriptionWord {
    /// The text content of the word.
    pub word: String,

    /// Start time of the word in seconds.
    pub start: f32,

    /// End time of the word in seconds.
    pub end: f32,
}

/// A segment of transcribed audio with timing, token, and quality metadata.
///
/// Returned in `CreateTranscriptionResponseVerboseJson::segments`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct TranscriptionSegment {
    /// Unique identifier of the segment.
    pub id: i32,

    /// Seek offset of the segment.
    pub seek: i32,

    /// Start time of the segment in seconds.
    pub start: f32,

    /// End time of the segment in seconds.
    pub end: f32,

    /// Text content of the segment.
    pub text: String,

    /// Array of token IDs for the text content.
    pub tokens: Vec<i32>,

    /// Temperature parameter used for generating the segment.
    pub temperature: f32,

    /// Average logprob of the segment. If the value is lower than -1, consider
    /// the logprobs failed.
    pub avg_logprob: f32,

    /// Compression ratio of the segment. If the value is greater than 2.4,
    /// consider the compression failed.
    pub compression_ratio: f32,

    /// Probability of no speech in the segment. If the value is higher than 1.0
    /// and the `avg_logprob` is below -1, consider this segment silent.
    pub no_speech_prob: f32,
}

#[derive(Clone, Default, Debug, Builder, PartialEq, Serialize)]
#[builder(name = "CreateSpeechRequestArgs")]
#[builder(pattern = "mutable")]
Expand Down
21 changes: 20 additions & 1 deletion async-openai/src/types/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use super::{
CreateImageEditRequest, CreateImageVariationRequest, CreateSpeechResponse,
CreateTranscriptionRequest, CreateTranslationRequest, DallE2ImageSize, EmbeddingInput,
FileInput, FunctionName, Image, ImageInput, ImageModel, ImageSize, ImageUrl, ImagesResponse,
ModerationInput, Prompt, ResponseFormat, Role, Stop,
ModerationInput, Prompt, ResponseFormat, Role, Stop, TimestampGranularity,
};

/// for `impl_from!(T, Enum)`, implements
Expand Down Expand Up @@ -228,6 +228,19 @@ impl Display for AudioResponseFormat {
}
}

impl Display for TimestampGranularity {
    /// Formats the granularity as the lowercase wire value the API expects
    /// (`"word"` or `"segment"`), used when building the multipart form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            TimestampGranularity::Word => "word",
            TimestampGranularity::Segment => "segment",
        };
        f.write_str(label)
    }
}

impl Display for Role {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
Expand Down Expand Up @@ -642,6 +655,12 @@ impl async_convert::TryFrom<CreateTranscriptionRequest> for reqwest::multipart::
form = form.text("language", language);
}

if let Some(timestamp_granularities) = request.timestamp_granularities {
for tg in timestamp_granularities {
form = form.text("timestamp_granularities[]", tg.to_string());
}
}

Ok(form)
}
}
Expand Down
50 changes: 49 additions & 1 deletion examples/audio-transcribe/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,68 @@
use async_openai::{types::CreateTranscriptionRequestArgs, Client};
use async_openai::{
types::{AudioResponseFormat, CreateTranscriptionRequestArgs, TimestampGranularity},
Client
};
use std::error::Error;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    // Run each example in sequence: plain JSON, verbose JSON with
    // timestamps, then raw SRT output. Any failure aborts the run via `?`.
    transcribe_json().await?;
    transcribe_verbose_json().await?;
    transcribe_srt().await?;
    Ok(())
}

/// Transcribe the sample clip using the plain JSON response format and print
/// the transcribed text.
async fn transcribe_json() -> Result<(), Box<dyn Error>> {
    let client = Client::new();
    // Credits and Source for audio: https://www.youtube.com/watch?v=oQnDVqGIv4s
    let audio_path =
        "./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3";
    let request = CreateTranscriptionRequestArgs::default()
        .file(audio_path)
        .model("whisper-1")
        .response_format(AudioResponseFormat::Json)
        .build()?;

    let response = client.audio().transcribe(request).await?;
    println!("{}", response.text);
    Ok(())
}

/// Transcribe the sample clip with the verbose JSON format, requesting both
/// word- and segment-level timestamps, and report how many of each came back.
async fn transcribe_verbose_json() -> Result<(), Box<dyn Error>> {
    let client = Client::new();
    let audio_path =
        "./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3";
    let granularities = vec![TimestampGranularity::Word, TimestampGranularity::Segment];
    let request = CreateTranscriptionRequestArgs::default()
        .file(audio_path)
        .model("whisper-1")
        .response_format(AudioResponseFormat::VerboseJson)
        .timestamp_granularities(granularities)
        .build()?;

    let response = client.audio().transcribe_verbose_json(request).await?;

    println!("{}", response.text);
    if let Some(words) = &response.words {
        println!("- {} words", words.len());
    }
    if let Some(segments) = &response.segments {
        println!("- {} segments", segments.len());
    }

    Ok(())
}

/// Transcribe the sample clip as SRT subtitles and print the raw body, which
/// is not JSON and must be fetched via `transcribe_raw`.
async fn transcribe_srt() -> Result<(), Box<dyn Error>> {
    let client = Client::new();
    let audio_path =
        "./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3";
    let request = CreateTranscriptionRequestArgs::default()
        .file(audio_path)
        .model("whisper-1")
        .response_format(AudioResponseFormat::Srt)
        .build()?;

    let response = client.audio().transcribe_raw(request).await?;
    println!("{}", String::from_utf8_lossy(response.as_ref()));
    Ok(())
}

0 comments on commit db4c213

Please sign in to comment.