diff --git a/t5x/export_lib.py b/t5x/export_lib.py
index bb4d2720e..8a989f32a 100644
--- a/t5x/export_lib.py
+++ b/t5x/export_lib.py
@@ -21,6 +21,8 @@
 import json
 import os
 import os.path
+import random
+import string
 import typing
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
@@ -1075,7 +1077,7 @@ def create_postprocessor(
   if inference_mode == 'predict':

     def postprocessor(
-        values: Tuple[Any, Any]
+        values: Tuple[Any, Any],
     ) -> Union[Tuple[Any, Any], Mapping[str, Any]]:
       tokens, scores = values
       if decode_outputs:
@@ -1175,22 +1177,96 @@ def _request_to_prediction_log(


 def generate_examples_with_sequence_lengths(
-    sequence_lengths: list[int], single_token_example: WarmupExample = 'Q'
+    sequence_lengths: list[int],
+    single_token_example: WarmupExample | None = None,
+    vocabulary: seqio.Vocabulary | None = None,
+    character_set: str | Sequence[str] = string.ascii_letters + string.digits,
+    leave_room_for_eos: bool = False,
+    chars_per_token_upper_bound: int = 10,
 ) -> list[WarmupExamples]:
-  """Creates synthetic sequences of specified sizes by repeating a single token.
+  """Creates synthetic sequences that have the requested number of tokens.
+
+  The examples are generated by one of the following methods:
+  - If `single_token_example` is set: repeat a single token that is known to
+    be always `N` tokens long when repeated `N` times.
+  - If `vocabulary` is set: generate a random string that is `N` tokens long,
+    as measured by `vocabulary`.

   Args:
     sequence_lengths: The sequence lengths to generate examples for.
     single_token_example: An example such that `N*ex` is always `N` tokens
       long. This is used to build sequences of a specified size. Defaults to `'Q'`,
-      which satisfies this property for the tokenizer used by pretrained T5X
-      models.
+      which satisfies this property for the tokenizer used by pretrained English
+      T5X models. **NOTE**: This is brittle to variations in the tokenizer, so
+      prefer using `vocabulary` instead.
+    vocabulary: The seqio.Vocabulary used by the model.
+    character_set: The set of characters to use when generating random strings.
+      Defaults to letters and digits.
+    leave_room_for_eos: Whether the model will add EOS after the example. If
+      true, the generated examples will be one token shorter than the requested
+      length.
+    chars_per_token_upper_bound: The upper bound on the number of characters
+      contained in a single token for this vocabulary. This determines how large
+      a random string to start with when trying to generate a specific number
+      of tokens.

   Returns:
     A list of WarmupExamples batches with lengths in tokens equal to
     `sequence_lengths`.
   """
-  return [[single_token_example * l] for l in sequence_lengths]
+  if leave_room_for_eos:
+    sequence_lengths = [l - 1 for l in sequence_lengths]
+  if single_token_example and vocabulary:
+    raise ValueError(
+        'Only one of `single_token_example` and `vocabulary` can be set.'
+    )
+  elif vocabulary:
+    # Generate a random string that is exactly `N` tokens long, as measured by
+    # the provided tokenizer.
+    # TODO: b/331419045 - Add support for models with pretokenized inputs
+    # (dtype=tf.int32).
+
+    def _generate_example(num_tokens: int) -> WarmupExamples:
+      if num_tokens == 0:
+        return ['']
+      random_string = ''.join(
+          random.choice(character_set)
+          for _ in range(num_tokens * chars_per_token_upper_bound)
+      )
+      all_ids = vocabulary.encode(random_string)
+      if len(all_ids) < num_tokens:
+        # Even if chars_per_token_upper_bound is set high enough, this can
+        # happen with unknown tokens. See b/294826076#comment4 (encoding Chinese
+        # characters with an English tokenizer).
+        raise ValueError(
+            'Generated a random warmup example that is shorter than'
+            f' {num_tokens} tokens. Make sure the characters in character_set'
+            ' are valid in the vocabulary, or increase'
+            ' chars_per_token_upper_bound.'
+        )
+      # The +1 keeps the exact-fit window reachable when
+      # len(all_ids) == num_tokens.
+      for start_index in range(len(all_ids) - num_tokens + 1):
+        # Truncating may return an empty string for some IDs, for example
+        # vocabulary.decode([3]), so search for an ID subarray whose
+        # resulting string actually has the correct length.
+        truncated_ids = all_ids[start_index : start_index + num_tokens]
+        example = vocabulary.decode(truncated_ids)
+        generated_num_tokens = len(vocabulary.encode(example))
+        if generated_num_tokens == num_tokens:
+          return [example]
+      raise ValueError(
+          f'Could not generate a valid string with {num_tokens} tokens. This'
+          ' may happen if the characters in character_set are not'
+          ' representative of the vocabulary, or if'
+          ' chars_per_token_upper_bound is too small.'
+      )
+
+    return [_generate_example(length) for length in sequence_lengths]
+  else:
+    single_token_example = single_token_example or 'Q'
+    logging.warning(
+        'Using single_token_example to generate warmup examples is brittle to'
+        ' variations in the tokenizer. Prefer explicitly passing a vocabulary'
+        ' instead.'
+    )
+    return [[single_token_example * l] for l in sequence_lengths]


 def write_warmup_examples(
@@ -1240,10 +1316,11 @@ def write_warmup_examples(
     input_tensor_dtype: The dtype of the input tensor.
   """
   if generate_examples_fn:
-    logging.warning(
-        'Ignoring provided warmup batch. Using `generate_examples_fn` to'
-        ' generate warmup examples instead.'
-    )
+    if text_batch:
+      logging.warning(
+          'Ignoring provided warmup batch. Using `generate_examples_fn` to'
+          ' generate warmup examples instead.'
+      )
     warmup_examples = generate_examples_fn()
   else:
     warmup_examples = [text_batch]
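
The core of the change is the decode-then-re-encode search inside `_generate_example`. The standalone sketch below isolates that logic with `encode`/`decode` passed in as plain callables; the function and parameter names are illustrative only, not part of the diff. It slides a `num_tokens`-wide window over the candidate IDs and keeps the first window whose decoded string round-trips to exactly `num_tokens` tokens.

from typing import Callable, Optional, Sequence


def find_exact_length_text(
    all_ids: Sequence[int],
    num_tokens: int,
    encode: Callable[[str], Sequence[int]],
    decode: Callable[[Sequence[int]], str],
) -> Optional[str]:
  # The +1 makes the exact-fit window (len(all_ids) == num_tokens) reachable.
  for start in range(len(all_ids) - num_tokens + 1):
    candidate = decode(all_ids[start : start + num_tokens])
    # Decoding can drop IDs (e.g. pieces that render as ''), so verify the
    # round trip before accepting the candidate.
    if len(encode(candidate)) == num_tokens:
      return candidate
  return None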
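
Usage sketch (illustrative, not part of the change): how the new `vocabulary` path might be exercised. The SentencePiece model path is a placeholder, and `seqio.SentencePieceVocabulary` stands in for whatever `seqio.Vocabulary` the exported model actually uses.

import seqio
from t5x import export_lib

# Placeholder path; any seqio.Vocabulary works here.
vocab = seqio.SentencePieceVocabulary('/path/to/sentencepiece.model')

# One warmup batch per target length, measured in tokens by `vocab` rather
# than by repeating a hard-coded single-token string.
warmup_batches = export_lib.generate_examples_with_sequence_lengths(
    sequence_lengths=[8, 64, 512],
    vocabulary=vocab,
    leave_room_for_eos=True,  # each example is generated one token short
)
assert len(warmup_batches) == 3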