Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Audio APIs from updated spec #202

Merged
merged 2 commits into from
Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions async-openai/src/audio.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use bytes::Bytes;

use crate::{
config::Config,
error::OpenAIError,
types::{
CreateSpeechRequest, CreateSpeechResponse, CreateTranscriptionRequest,
CreateTranscriptionResponse, CreateTranslationRequest, CreateTranslationResponse,
CreateTranscriptionResponseJson, CreateTranscriptionResponseVerboseJson, CreateTranslationRequest, CreateTranslationResponse,
},
Client,
};
Expand All @@ -23,12 +25,32 @@ impl<'c, C: Config> Audio<'c, C> {
/// Transcribes audio into the input language.
///
/// NOTE(review): this span is rendered from a unified diff; both the pre-change
/// and post-change return-type lines appear below. After this change the
/// method returns `CreateTranscriptionResponseJson` (the `json` response
/// format), deserialized from a multipart POST to `/audio/transcriptions`.
pub async fn transcribe(
    &self,
    request: CreateTranscriptionRequest,
) -> Result<CreateTranscriptionResponse, OpenAIError> {
) -> Result<CreateTranscriptionResponseJson, OpenAIError> {
    self.client
        .post_form("/audio/transcriptions", request)
        .await
}

/// Transcribes audio into the input language, deserializing the reply as a
/// `verbose_json` response (the request's `response_format` should be set to
/// `VerboseJson` to match — see the examples).
pub async fn transcribe_verbose_json(
    &self,
    request: CreateTranscriptionRequest,
) -> Result<CreateTranscriptionResponseVerboseJson, OpenAIError> {
    // Same endpoint as `transcribe`; only the expected response shape differs.
    let path = "/audio/transcriptions";
    self.client.post_form(path, request).await
}

/// Transcribes audio into the input language, returning the raw response body
/// without deserialization (useful for non-JSON response formats such as
/// `srt` or `vtt` — see the examples).
pub async fn transcribe_raw(
    &self,
    request: CreateTranscriptionRequest,
) -> Result<Bytes, OpenAIError> {
    let path = "/audio/transcriptions";
    self.client.post_form_raw(path, request).await
}

/// Translates audio into English.
pub async fn translate(
&self,
Expand Down
19 changes: 19 additions & 0 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,25 @@ impl<C: Config> Client<C> {
self.execute(request_maker).await
}

/// POST a multipart form at `{path}` and return the raw response body bytes.
pub(crate) async fn post_form_raw<F>(&self, path: &str, form: F) -> Result<Bytes, OpenAIError>
where
    reqwest::multipart::Form: async_convert::TryFrom<F, Error = OpenAIError>,
    F: Clone,
{
    // The request is built inside a closure so `execute_raw` can construct it
    // when needed; the form is cloned per construction (hence `F: Clone`).
    let request_maker = || async {
        let multipart: reqwest::multipart::Form =
            async_convert::TryFrom::try_from(form.clone()).await?;
        let request = self
            .http_client
            .post(self.config.url(path))
            .query(&self.config.query())
            .headers(self.config.headers())
            .multipart(multipart)
            .build()?;
        Ok(request)
    };

    self.execute_raw(request_maker).await
}

/// POST a form at {path} and deserialize the response body
pub(crate) async fn post_form<O, F>(&self, path: &str, form: F) -> Result<O, OpenAIError>
where
Expand Down
75 changes: 74 additions & 1 deletion async-openai/src/types/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,84 @@ pub struct CreateTranscriptionRequest {
pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}

/// Represents a transcription response returned by model, based on the provided
/// input. This is the shape deserialized for the default `json` response
/// format (see `Audio::transcribe`).
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CreateTranscriptionResponseJson {
    /// The transcribed text.
    pub text: String,
}

/// Represents a verbose json transcription response returned by model, based on
/// the provided input (see `Audio::transcribe_verbose_json`).
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CreateTranscriptionResponseVerboseJson {
    /// The language of the input audio.
    pub language: String,

    /// The duration of the input audio.
    // NOTE(review): presumably seconds, matching the per-word/segment
    // timestamps below — confirm against the API spec.
    pub duration: f32,

    /// The transcribed text.
    pub text: String,

    /// Extracted words and their corresponding timestamps.
    // Omitted from serialization when absent; presumably populated only when
    // the `word` timestamp granularity was requested — TODO confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<TranscriptionWord>>,

    /// Segments of the transcribed text and their corresponding details.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub segments: Option<Vec<TranscriptionSegment>>,
}

/// A single word and its timestamps, as returned inside a verbose json
/// transcription response.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct TranscriptionWord {
    /// The text content of the word.
    pub word: String,

    /// Start time of the word in seconds.
    pub start: f32,

    /// End time of the word in seconds.
    pub end: f32,
}

/// A segment of transcribed text and its details, as returned inside a verbose
/// json transcription response.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct TranscriptionSegment {
    /// Unique identifier of the segment.
    pub id: i32,

    /// Seek offset of the segment.
    pub seek: i32,

    /// Start time of the segment in seconds.
    pub start: f32,

    /// End time of the segment in seconds.
    pub end: f32,

    /// Text content of the segment.
    pub text: String,

    /// Array of token IDs for the text content.
    pub tokens: Vec<i32>,

    /// Temperature parameter used for generating the segment.
    pub temperature: f32,

    /// Average logprob of the segment. If the value is lower than -1, consider
    /// the logprobs failed.
    pub avg_logprob: f32,

    /// Compression ratio of the segment. If the value is greater than 2.4,
    /// consider the compression failed.
    pub compression_ratio: f32,

    /// Probability of no speech in the segment. If the value is higher than 1.0
    /// and the `avg_logprob` is below -1, consider this segment silent.
    pub no_speech_prob: f32,
}

#[derive(Clone, Default, Debug, Builder, PartialEq, Serialize)]
#[builder(name = "CreateSpeechRequestArgs")]
#[builder(pattern = "mutable")]
Expand Down
21 changes: 20 additions & 1 deletion async-openai/src/types/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use super::{
CreateImageEditRequest, CreateImageVariationRequest, CreateSpeechResponse,
CreateTranscriptionRequest, CreateTranslationRequest, DallE2ImageSize, EmbeddingInput,
FileInput, FunctionName, Image, ImageInput, ImageModel, ImageSize, ImageUrl, ImagesResponse,
ModerationInput, Prompt, ResponseFormat, Role, Stop,
ModerationInput, Prompt, ResponseFormat, Role, Stop, TimestampGranularity,
};

/// for `impl_from!(T, Enum)`, implements
Expand Down Expand Up @@ -228,6 +228,19 @@ impl Display for AudioResponseFormat {
}
}

/// Renders the granularity as the wire value the API expects
/// (`"word"` / `"segment"`), used when building the multipart form.
impl Display for TimestampGranularity {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            TimestampGranularity::Word => "word",
            TimestampGranularity::Segment => "segment",
        };
        f.write_str(label)
    }
}

impl Display for Role {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
Expand Down Expand Up @@ -642,6 +655,12 @@ impl async_convert::TryFrom<CreateTranscriptionRequest> for reqwest::multipart::
form = form.text("language", language);
}

if let Some(timestamp_granularities) = request.timestamp_granularities {
for tg in timestamp_granularities {
form = form.text("timestamp_granularities[]", tg.to_string());
}
}

Ok(form)
}
}
Expand Down
50 changes: 49 additions & 1 deletion examples/audio-transcribe/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,68 @@
use async_openai::{types::CreateTranscriptionRequestArgs, Client};
use async_openai::{
types::{AudioResponseFormat, CreateTranscriptionRequestArgs, TimestampGranularity},
Client
};
use std::error::Error;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    // Run the three transcription examples in sequence; the first failure
    // propagates and skips the remaining ones.
    transcribe_json().await?;
    transcribe_verbose_json().await?;
    transcribe_srt().await?;
    Ok(())
}

/// Requests a plain-JSON transcription and prints the transcribed text.
async fn transcribe_json() -> Result<(), Box<dyn Error>> {
    let client = Client::new();
    // Credits and Source for audio: https://www.youtube.com/watch?v=oQnDVqGIv4s
    let audio_path =
        "./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3";
    let request = CreateTranscriptionRequestArgs::default()
        .file(audio_path)
        .model("whisper-1")
        .response_format(AudioResponseFormat::Json)
        .build()?;

    let response = client.audio().transcribe(request).await?;
    println!("{}", response.text);
    Ok(())
}

/// Requests a verbose-JSON transcription with word and segment timestamps,
/// then prints the text and the counts of returned words/segments.
async fn transcribe_verbose_json() -> Result<(), Box<dyn Error>> {
    let client = Client::new();
    let audio_path =
        "./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3";
    let request = CreateTranscriptionRequestArgs::default()
        .file(audio_path)
        .model("whisper-1")
        .response_format(AudioResponseFormat::VerboseJson)
        .timestamp_granularities(vec![
            TimestampGranularity::Word,
            TimestampGranularity::Segment,
        ])
        .build()?;

    let response = client.audio().transcribe_verbose_json(request).await?;

    println!("{}", response.text);
    if let Some(words) = response.words.as_ref() {
        println!("- {} words", words.len());
    }
    if let Some(segments) = response.segments.as_ref() {
        println!("- {} segments", segments.len());
    }

    Ok(())
}

async fn transcribe_srt() -> Result<(), Box<dyn Error>> {
let client = Client::new();
let request = CreateTranscriptionRequestArgs::default()
.file(
"./audio/A Message From Sir David Attenborough A Perfect Planet BBC Earth_320kbps.mp3",
)
.model("whisper-1")
.response_format(AudioResponseFormat::Srt)
.build()?;

let response = client.audio().transcribe_raw(request).await?;
println!("{}", String::from_utf8_lossy(response.as_ref()));
Ok(())
}