Skip to content

Commit

Permalink
[WIP] Implementing Silence Processor elixir-nx#379
Browse files Browse the repository at this point in the history
  • Loading branch information
tubedude committed Oct 4, 2024
1 parent 3b56c7f commit 0d2c066
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 18 deletions.
60 changes: 43 additions & 17 deletions lib/bumblebee/audio/speech_to_text_whisper.ex
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
defn_options: [],
preallocate_params: false,
task: :transcribe,
stream: false
stream: false,
logprob_threshold: 0.6,
no_speech_threshold: -1.0
])

%{model: model, params: params, spec: spec} = model_info
Expand Down Expand Up @@ -59,7 +61,7 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
context_num_seconds: context_num_seconds
}

{generate_opts, generation_config} = generate_opts(generation_config, opts)
{generate_opts, generation_config} = generate_opts(model_info, generation_config, opts)
generate_fun = Text.Generation.build_generate(model, spec, generation_config, generate_opts)

generate_fun = fn params, {inputs, seed} ->
Expand Down Expand Up @@ -210,27 +212,51 @@ defmodule Bumblebee.Audio.SpeechToTextWhisper do
end
end

defp generate_opts(generation_config, opts) do
defp generate_opts(model_info, generation_config, opts) do
forced_token_ids = forced_token_ids(opts, generation_config.extra_config)
generation_config = %{generation_config | forced_token_ids: forced_token_ids}

logits_processors =
if opts[:timestamps] do
[
&Bumblebee.Text.Generation.LogitsProcessing.whisper_timestamp_processor(&1, &2,
eos_token_id: generation_config.eos_token_id,
forced_token_ids: generation_config.forced_token_ids,
no_timestamps_token_id: generation_config.extra_config.no_timestamps_token_id,
timestamp_begin_id: generation_config.extra_config.no_timestamps_token_id + 1
)
]
else
[]
end
[]
|> add_timestamp_processor(opts, generation_config)
|> add_silence_processor(opts, model_info, generation_config)

opts = [logits_processors: logits_processors]
{[logits_processors: logits_processors], generation_config}
end

{opts, generation_config}
# Prepends the Whisper timestamp logits processor when the `:timestamps`
# option is truthy; otherwise returns the processor list unchanged.
defp add_timestamp_processor(processors, opts, generation_config) do
  if opts[:timestamps] do
    extra_config = generation_config.extra_config

    timestamp_processor =
      &Bumblebee.Text.Generation.LogitsProcessing.whisper_timestamp_processor(&1, &2,
        eos_token_id: generation_config.eos_token_id,
        forced_token_ids: generation_config.forced_token_ids,
        no_timestamps_token_id: extra_config.no_timestamps_token_id,
        timestamp_begin_id: extra_config.no_timestamps_token_id + 1
      )

    [timestamp_processor | processors]
  else
    processors
  end
end

# Prepends the Whisper silence logits processor when both threshold
# options are set (non-nil); otherwise returns the processor list as-is.
defp add_silence_processor(processors, opts, model_info, generation_config) do
  no_speech_threshold = Keyword.get(opts, :no_speech_threshold)
  logprob_threshold = Keyword.get(opts, :logprob_threshold)

  if no_speech_threshold && logprob_threshold do
    silence_processor =
      &Bumblebee.Text.Generation.LogitsProcessing.whisper_silence_processor(&1, &2,
        no_speech_threshold: no_speech_threshold,
        logprob_threshold: logprob_threshold,
        vocab_size: model_info.spec.vocab_size,
        suppress_tokens: Nx.tensor(generation_config.suppressed_token_ids)
      )

    [silence_processor | processors]
  else
    processors
  end
end

defp forced_token_ids(opts, extra_config) do
Expand Down
64 changes: 64 additions & 0 deletions lib/bumblebee/text/generation/logits_processing.ex
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,70 @@ defmodule Bumblebee.Text.Generation.LogitsProcessing do
end
end

# Suppresses the configured tokens whenever the silence heuristic fires:
# high no-speech probability AND low average sequence log-probability.
# Otherwise the logits pass through untouched.
defn whisper_silence_processor(logits, context, opts \\ []) do
  opts =
    keyword!(opts, [:no_speech_threshold, :logprob_threshold, :vocab_size, :suppress_tokens])

  log_probs = Axon.Activations.log_softmax(logits)

  silence_detected =
    compute_no_speech_probability(logits) > opts[:no_speech_threshold] and
      compute_avg_logprob(log_probs, context.sequence) < opts[:logprob_threshold]

  Nx.select(
    silence_detected,
    suppress_logits(logits, opts[:vocab_size], opts[:suppress_tokens]),
    logits
  )
end

defnp compute_no_speech_probability(logits) do
  # Probability (not log-probability) of the "no speech" entry.
  #
  # The caller compares this value against a probability-scale threshold
  # (Whisper conventionally uses 0.6), so softmax is required here. The
  # previous log_softmax produced values <= 0, which can never exceed a
  # positive probability threshold, silently disabling the check.
  #
  # NOTE(review): indexing position 0 assumes the no-speech entry is the
  # first element of `logits` — confirm against the tokenizer's actual
  # no_speech token id.
  probs = Axon.Activations.softmax(logits)
  probs[0]
end

# Mean log-probability over the tokens in `sequence`, gathered from the
# `scores` (log-probability) tensor and averaged by sequence length.
defnp compute_avg_logprob(scores, sequence) do
  scores
  |> Nx.take(sequence)
  |> Nx.sum()
  |> Nx.divide(Nx.size(sequence))
end

# Replaces the logits at the `suppress_tokens` positions with negative
# infinity, leaving all other vocabulary entries unchanged.
defnp suppress_logits(logits, vocab_size, suppress_tokens) do
  # Scatter `true` into a boolean vocab-sized mask at each suppressed index.
  mask =
    {vocab_size}
    |> then(&Nx.broadcast(Nx.tensor(false, type: {:u, 8}), &1))
    |> Nx.indexed_put(
      Nx.new_axis(suppress_tokens, -1),
      Nx.broadcast(Nx.tensor(true, type: {:u, 8}), Nx.shape(suppress_tokens))
    )

  neg_infinity =
    Nx.broadcast(Nx.Constants.neg_infinity(Nx.type(logits)), Nx.shape(logits))

  Nx.select(mask, neg_infinity, logits)
end

defnp force_timestamp_pair(logits, context, begin_idx, eos_token_id, timestamp_begin_id) do
# Force timestamp tokens to appear in pairs, end followed by
# start, except directly before the EOS token
Expand Down
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defmodule Bumblebee.MixProject do
use Mix.Project

@version "0.5.3"
@version "0.5.4"
@description "Pre-trained and transformer Neural Network models in Axon"

def project do
Expand Down

0 comments on commit 0d2c066

Please sign in to comment.