diff --git a/guides/img/getting_started/class-diagram.png b/guides/img/getting_started/class-diagram.png new file mode 100644 index 0000000000..7f14614d5a Binary files /dev/null and b/guides/img/getting_started/class-diagram.png differ diff --git a/guides/img/getting_started/getting_started_11_1.png b/guides/img/getting_started/getting_started_11_1.png new file mode 100644 index 0000000000..eb146c8420 Binary files /dev/null and b/guides/img/getting_started/getting_started_11_1.png differ diff --git a/guides/img/getting_started/getting_started_55_1.png b/guides/img/getting_started/getting_started_55_1.png new file mode 100644 index 0000000000..1082b84e2c Binary files /dev/null and b/guides/img/getting_started/getting_started_55_1.png differ diff --git a/guides/ipynb/keras_hub/getting_started.ipynb b/guides/ipynb/keras_hub/getting_started.ipynb index e585a8a7e6..4b2f7d731f 100644 --- a/guides/ipynb/keras_hub/getting_started.ipynb +++ b/guides/ipynb/keras_hub/getting_started.ipynb @@ -8,9 +8,9 @@ "source": [ "# Getting Started with KerasHub\n", "\n", - "**Author:** [Jonathan Bischof](https://github.com/jbischof)
\n", + "**Author:** [Matthew Watson](https://github.com/mattdangerw/), [Jonathan Bischof](https://github.com/jbischof)
\n", "**Date created:** 2022/12/15
\n", - "**Last modified:** 2023/07/01
\n", + "**Last modified:** 2024/10/17
\n", "**Description:** An introduction to the KerasHub API." ] }, @@ -20,38 +20,44 @@ "colab_type": "text" }, "source": [ - "## Introduction\n", + "**KerasHub** is a pretrained modeling library that aims to be simple, flexible, and fast.\n", + "The library provides [Keras 3](https://keras.io/keras_3/) implementations of popular\n", + "model architectures, paired with a collection of pretrained checkpoints available on\n", + "[Kaggle](https://www.kaggle.com/organizations/keras/models). Models can be used for both\n", + "training and inference on any of the TensorFlow, Jax, and Torch backends.\n", "\n", - "KerasHub is a natural language processing library that supports users through\n", - "their entire development cycle. Our workflows are built from modular components\n", - "that have state-of-the-art preset weights and architectures when used\n", - "out-of-the-box and are easily customizable when more control is needed.\n", + "KerasHub is an extension of the core Keras API; KerasHub components are provided as\n", + "`keras.Layer`s and `keras.Model`s. If you are familiar with Keras, congratulations! You\n", + "already understand most of KerasHub.\n", "\n", - "This library is an extension of the core Keras API; all high-level modules are\n", - "[`Layers`](/api/layers/) or [`Models`](/api/models/). If you are familiar with Keras,\n", - "congratulations! You already understand most of KerasHub.\n", + "This guide is meant to be an accessible introduction to the entire library. We will start\n", + "by using high-level APIs to classify images and generate text, then progressively show\n", + "deeper customization of models and training. Throughout the guide, we use Professor Keras,\n", + "the official Keras mascot, as a visual reference for the complexity of the material:\n", "\n", - "KerasHub uses Keras 3 to work with any of TensorFlow, Pytorch and Jax. In the\n", - "guide below, we will use the `jax` backend for training our models, and\n", - "[tf.data](https://www.tensorflow.org/guide/data) for efficiently running our\n", - "input preprocessing. But feel free to mix things up! This guide runs in\n", - "TensorFlow or PyTorch backends with zero changes, simply update the\n", - "`KERAS_BACKEND` below.\n", + "![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_evolution.png)\n", "\n", - "This guide demonstrates our modular approach using a sentiment analysis example at six\n", - "levels of complexity:\n", - "\n", - "* Inference with a pretrained classifier\n", - "* Fine tuning a pretrained backbone\n", - "* Fine tuning with user-controlled preprocessing\n", - "* Fine tuning a custom model\n", - "* Pretraining a backbone model\n", - "* Build and train your own transformer from scratch\n", - "\n", - "Throughout our guide, we use Professor Keras, the official Keras mascot, as a visual\n", - "reference for the complexity of the material:\n", - "\n", - "\"drawing\"" + "As always, we'll keep our Keras guides focused on real-world code examples. You can play\n", + "with the code here at any time by clicking the Colab link at the top of the guide." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Installation and Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "To begin, let's install keras-hub. The library is available on PyPI, so we can simply\n", + "install it with pip." 
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [
- "!pip install -q --upgrade keras-hub\n",
- "!pip install -q --upgrade keras # Upgrade to Keras 3."
+ "!pip install --upgrade --quiet keras-hub-nightly keras-nightly"
] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [
+ "Keras 3 was built to work on top of TensorFlow, Jax, and Torch backends. You should\n",
+ "specify the backend first thing when writing Keras code, before any library imports. We\n",
+ "will use the Jax backend for this guide, but you can use `torch` or `tensorflow` without\n",
+ "changing a single line in the rest of this guide. That's the power of Keras 3!\n",
+ "\n",
+ "We will also set `XLA_PYTHON_CLIENT_MEM_FRACTION`, which frees up the whole GPU for\n",
+ "Jax to use from the start."
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [
 "import os\n",
 "\n",
 "os.environ[\"KERAS_BACKEND\"] = \"jax\"  # or \"tensorflow\" or \"torch\"\n",
- "\n",
- "import keras_hub\n",
- "import keras\n",
- "\n",
- "# Use mixed precision to speed up all training in this guide.\n",
- "keras.mixed_precision.set_global_policy(\"mixed_float16\")"
+ "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.0\""
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "## API quickstart\n",
- "\n",
- "Our highest level API is `keras_hub.models`. These symbols cover the complete user\n",
- "journey of converting strings to tokens, tokens to dense features, and dense features to\n",
- "task-specific output. For each `XX` architecture (e.g., `Bert`), we offer the following\n",
- "modules:\n",
- "\n",
- "* **Tokenizer**: `keras_hub.models.XXTokenizer`\n",
+ "Lastly, we need to do some extra setup to access the models used in this guide. Many\n",
+ "popular open LLMs, such as Gemma from Google and Llama from Meta, require accepting\n",
+ "a community license before accessing the model weights. We will be using Gemma in this\n",
+ "guide, so we will follow these steps:\n",
+ "\n",
+ "1. Go to the [Gemma 2](https://www.kaggle.com/models/keras/gemma2) model page, and accept\n",
+ "   the license at the banner at the top.\n",
+ "2. Generate a Kaggle API key by going to [Kaggle settings](https://www.kaggle.com/settings)\n",
+ "   and clicking the \"Create New Token\" button under the \"API\" section.\n",
+ "3. Inside your Colab notebook, click on the key icon on the left-hand toolbar. Add two\n",
+ "   secrets: `KAGGLE_USERNAME` with your username, and `KAGGLE_KEY` with the API key you just\n",
+ "   created. Make these secrets visible to the notebook you are running."
] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [
+ "## API Quickstart\n",
+ "\n",
+ "Before we begin, let's take a look at the key classes we will use in the KerasHub library.\n",
+ "\n",
+ "* **Task**: e.g., `keras_hub.models.CausalLM`, `keras_hub.models.ImageClassifier`, and\n",
+ "`keras_hub.models.TextClassifier`.\n",
+ "  * **What it does**: A task maps from raw image, audio, and text inputs to model\n",
+ "  predictions.\n",
+ "  * **Why it's important**: A task is the highest-level entry point to the KerasHub API. It\n",
+ "  encapsulates both preprocessing and modeling into a single, easy-to-use class.
Tasks can\n",
+ "  be used both for fine-tuning and inference.\n",
+ "  * **Has a**: `backbone` and `preprocessor`.\n",
+ "  * **Inherits from**: `keras.Model`.\n",
+ "* **Backbone**: `keras_hub.models.Backbone`.\n",
+ "  * **What it does**: Maps preprocessed tensor inputs to the latent space of the model.\n",
+ "  * **Why it's important**: The backbone encapsulates the architecture and parameters of a\n",
+ "  pretrained model in a way that is unspecialized to any particular task. A backbone can\n",
+ "  be combined with arbitrary preprocessing and \"head\" layers mapping dense features to\n",
+ "  predictions to accomplish any ML task.\n",
+ "  * **Inherits from**: `keras.Model`.\n",
+ "* **Preprocessor**: e.g., `keras_hub.models.CausalLMPreprocessor`,\n",
+ "  `keras_hub.models.ImageClassifierPreprocessor`, and\n",
+ "  `keras_hub.models.TextClassifierPreprocessor`.\n",
+ "  * **What it does**: A preprocessor maps from raw image, audio and text inputs to\n",
+ "  preprocessed tensor inputs.\n",
+ "  * **Why it's important**: A preprocessing layer encapsulates all task-specific\n",
+ "  preprocessing, e.g. image resizing and text tokenization, in a way that can be used\n",
+ "  standalone to precompute preprocessed inputs. Note that if you are using a high-level\n",
+ "  task class, this preprocessing is already baked in by default.\n",
+ "  * **Has a**: `tokenizer`, `audio_converter`, and/or `image_converter`.\n",
+ "  * **Inherits from**: `keras.layers.Layer`.\n",
+ "* **Tokenizer**: `keras_hub.tokenizers.Tokenizer`.\n",
 "  * **What it does**: Converts strings to sequences of token ids.\n",
- "  * **Why it's important**: The raw bytes of a string are too high dimensional to be useful\n",
- "  features so we first map them to a small number of tokens, for example `\"The quick brown\n",
- "  fox\"` to `[\"the\", \"qu\", \"##ick\", \"br\", \"##own\", \"fox\"]`.\n",
+ "  * **Why it's important**: The raw bytes of a string are an inefficient representation of\n",
+ "  text input, so we first map string inputs to integer token ids. This class encapsulates\n",
+ "  the mapping of strings to ints and the reverse (via the `detokenize()` method).\n",
 "  * **Inherits from**: `keras.layers.Layer`.\n",
- "* **Preprocessor**: `keras_hub.models.XXPreprocessor`\n",
- "  * **What it does**: Converts strings to a dictionary of preprocessed tensors consumed by\n",
- "  the backbone, starting with tokenization.\n",
- "  * **Why it's important**: Each model uses special tokens and extra tensors to understand\n",
- "  the input such as delimiting input segments and identifying padding tokens. Padding each\n",
- "  sequence to the same length improves computational efficiency.\n",
- "  * **Has a**: `XXTokenizer`.\n",
+ "* **ImageConverter**: `keras_hub.layers.ImageConverter`.\n",
+ "  * **What it does**: Resizes and rescales image input.\n",
+ "  * **Why it's important**: Image models often need to normalize image inputs to a specific\n",
+ "  range, or resize inputs to a specific size. This class encapsulates the image-specific\n",
+ "  preprocessing.\n",
+ "  * **Inherits from**: `keras.layers.Layer`.\n",
+ "* **AudioConverter**: `keras_hub.layers.AudioConverter`.\n",
+ "  * **What it does**: Converts raw audio to model-ready input.\n",
+ "  * **Why it's important**: Audio models often need to preprocess raw audio input before\n",
+ "  passing it to a model, e.g. by computing a spectrogram of the audio signal.
This class\n",
+ "  encapsulates the audio-specific preprocessing in an easy-to-use layer.\n",
 "  * **Inherits from**: `keras.layers.Layer`.\n",
- "* **Backbone**: `keras_hub.models.XXBackbone`\n",
- "  * **What it does**: Converts preprocessed tensors to dense features. *Does not handle\n",
- "  strings; call the preprocessor first.*\n",
- "  * **Why it's important**: The backbone distills the input tokens into dense features that\n",
- "  can be used in downstream tasks. It is generally pretrained on a language modeling task\n",
- "  using massive amounts of unlabeled data. Transferring this information to a new task is a\n",
- "  major breakthrough in modern NLP.\n",
- "  * **Inherits from**: `keras.Model`.\n",
- "* **Task**: e.g., `keras_hub.models.XXClassifier`\n",
- "  * **What it does**: Converts strings to task-specific output (e.g., classification\n",
- "  probabilities).\n",
- "  * **Why it's important**: Task models combine string preprocessing and the backbone model\n",
- "  with task-specific `Layers` to solve a problem such as sentence classification, token\n",
- "  classification, or text generation. The additional `Layers` must be fine-tuned on labeled\n",
- "  data.\n",
- "  * **Has a**: `XXBackbone` and `XXPreprocessor`.\n",
- "  * **Inherits from**: `keras.Model`.\n",
 "\n",
- "Here is the modular hierarchy for `BertClassifier` (all relationships are compositional):\n",
+ "All of the classes listed here have a `from_preset()` constructor, which will instantiate\n",
+ "the component with weights and state for the given pre-trained model identifier. E.g.\n",
+ "`keras_hub.tokenizers.Tokenizer.from_preset(\"gemma2_2b_en\")` will create a layer that\n",
+ "tokenizes text using a Gemma2 tokenizer vocabulary.\n",
 "\n",
- "\"drawing\"\n",
+ "The figure below shows how all these core classes interact. Arrows indicate composition,\n",
+ "not inheritance (e.g., a task *has a* backbone).\n",
 "\n",
- "All modules can be used independently and have a `from_preset()` method in addition to\n",
- "the standard constructor that instantiates the class with **preset** architecture and\n",
- "weights (see examples below)."
+ "![png](/img/guides/getting_started/class-diagram.png)"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "## Data\n",
 "\n",
- "We will use a running example of sentiment analysis of IMDB movie reviews. In this task,\n",
- "we use the text to predict whether the review was positive (`label = 1`) or negative\n",
- "(`label = 0`).\n",
- "\n",
- "We load the data using `keras.utils.text_dataset_from_directory`, which utilizes the\n",
- "powerful `tf.data.Dataset` format for examples."
+ "## Classify an image\n",
 "\n",
+ "![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_beginner.png)"
] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [
+ "Enough setup! Let's have some fun with pre-trained models. Let's load a test image of a\n",
+ "California Quail and classify it."
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [
- "!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n",
- "!tar -xf aclImdb_v1.tar.gz\n",
- "!# Remove unsupervised examples\n",
- "!rm -r aclImdb/train/unsup"
+ "import keras\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "image_url = \"https://upload.wikimedia.org/wikipedia/commons/a/aa/California_quail.jpg\"\n",
+ "image_path = keras.utils.get_file(origin=image_url)\n",
+ "image = keras.utils.load_img(image_path)\n",
+ "plt.imshow(image)"
] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [
+ "We can use a ResNet vision model trained on the ImageNet-1k database. This model will\n",
+ "give each input sample an output label from `[0, 1000)`, where each label corresponds to\n",
+ "some real-world entity, like a \"milk can\" or a \"porcupine.\" The dataset actually has a\n",
+ "specific label for quail, at index 85. Let's download the model and predict a label."
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [
- "BATCH_SIZE = 16\n",
- "imdb_train = keras.utils.text_dataset_from_directory(\n",
- "    \"aclImdb/train\",\n",
- "    batch_size=BATCH_SIZE,\n",
- ")\n",
- "imdb_test = keras.utils.text_dataset_from_directory(\n",
- "    \"aclImdb/test\",\n",
- "    batch_size=BATCH_SIZE,\n",
- ")\n",
+ "import keras_hub\n",
 "\n",
- "# Inspect first review\n",
- "# Format is (review text tensor, label tensor)\n",
- "print(imdb_train.unbatch().take(1).get_single_element())\n",
- ""
+ "image_classifier = keras_hub.models.ImageClassifier.from_preset(\n",
+ "    \"resnet_50_imagenet\",\n",
+ "    activation=\"softmax\",\n",
+ ")\n",
+ "batch = np.array([image])\n",
+ "image_classifier.preprocessor.image_size = (224, 224)\n",
+ "preds = image_classifier.predict(batch)\n",
+ "preds.shape"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "## Inference with a pretrained classifier\n",
- "\n",
- "\"drawing\"\n",
- "\n",
- "The highest level module in KerasHub is a **task**. A **task** is a `keras.Model`\n",
- "consisting of a (generally pretrained) **backbone** model and task-specific layers.\n",
- "Here's an example using `keras_hub.models.BertClassifier`.\n",
- "\n",
- "**Note**: Outputs are the logits per class (e.g., `[0, 0]` is 50% chance of positive). The output is\n",
- "[negative, positive] for binary classification."
+ "These ImageNet labels aren't particularly \"human readable,\" so we can use a built-in\n",
+ "utility function to decode the predictions to a set of class names."
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [
- "classifier = keras_hub.models.BertClassifier.from_preset(\"bert_tiny_en_uncased_sst2\")\n",
- "# Note: batched inputs expected so must wrap string in iterable\n",
- "classifier.predict([\"I love modular workflows in keras-hub!\"])"
+ "keras_hub.utils.decode_imagenet_predictions(preds)"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "All **tasks** have a `from_preset` method that constructs a `keras.Model` instance with\n",
- "preset preprocessing, architecture and weights. This means that we can pass raw strings\n",
- "in any format accepted by a `keras.Model` and get output specific to our task.\n",
- "\n",
- "This particular **preset** is a `\"bert_tiny_uncased_en\"` **backbone** fine-tuned on\n",
- "`sst2`, another movie review sentiment analysis (this time from Rotten Tomatoes).
We use\n",
- "the `tiny` architecture for demo purposes, but larger models are recommended for SoTA\n",
- "performance. For all the task-specific presets available for `BertClassifier`, see\n",
- "our keras.io [models page](https://keras.io/api/keras_hub/models/).\n",
- "\n",
- "Let's evaluate our classifier on the IMDB dataset. You will note we don't need to\n",
- "call `keras.Model.compile` here. All **task** models like `BertClassifier` ship with\n",
- "compilation defaults, meaning we can just call `keras.Model.evaluate` directly. You\n",
- "can always call compile as normal to override these defaults (e.g. to add new metrics).\n",
- "\n",
- "The output below is [loss, accuracy],"
+ "Looking good! The model weights successfully downloaded, and we predicted the\n",
+ "correct classification label for our quail image with near certainty.\n",
+ "\n",
+ "This was our first example of the high-level **task** API mentioned in the API quickstart\n",
+ "above. A `keras_hub.models.ImageClassifier` is a task for classifying images, and can be\n",
+ "used with a number of different model architectures (ResNet, VGG, MobileNet, etc.). You\n",
+ "can view the full list of models shipped directly by the Keras team on\n",
+ "[Kaggle](https://www.kaggle.com/organizations/keras/models).\n",
+ "\n",
+ "A task is just a subclass of `keras.Model` \u2014 you can use `fit()`, `compile()`, and\n",
+ "`save()` on our `classifier` object the same as with any other model. But tasks come with a few\n",
+ "extras provided by the KerasHub library. The first and most important is `from_preset()`,\n",
+ "a special constructor you will see on many classes in KerasHub.\n",
+ "\n",
+ "A **preset** is a directory of model state. It defines both the architecture we should\n",
+ "load and the pretrained weights that go with it. `from_preset()` allows us to load\n",
+ "**preset** directories from a number of different locations:\n",
+ "\n",
+ "- A local directory.\n",
+ "- The Kaggle Model hub.\n",
+ "- The HuggingFace model hub.\n",
+ "\n",
+ "You can take a look at the `keras_hub.models.ImageClassifier.from_preset` docs to better\n",
+ "understand all the options when constructing a Keras model from a preset.\n",
+ "\n",
+ "All tasks use two main sub-objects: a `keras_hub.models.Backbone` and a\n",
+ "`keras_hub.layers.Preprocessor`. You might be familiar already with the term **backbone**\n",
+ "from computer vision, where it is often used to describe a feature extractor network that\n",
+ "maps images to a latent space. A KerasHub backbone is this concept generalized; we use it\n",
+ "to refer to any pretrained model without a task-specific head. That is, a KerasHub\n",
+ "backbone maps raw images, audio and text (or a combination of these inputs) to a\n",
+ "pretrained model's latent space. We can then map this latent space to any number of\n",
+ "task-specific outputs, depending on what we are trying to do with the model.\n",
+ "\n",
+ "A **preprocessor** is just a Keras layer that does all the preprocessing for a specific\n",
+ "task. In our case, preprocessing will resize our input image and rescale it to the\n",
+ "range `[0, 1]` using some ImageNet-specific mean and variance data. Let's call our\n",
+ "task's preprocessor and backbone in succession to see what happens to our input shape."
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [
- "classifier.evaluate(imdb_test)"
+ "print(\"Raw input shape:\", batch.shape)\n",
+ "resized_batch = image_classifier.preprocessor(batch)\n",
+ "print(\"Preprocessed input shape:\", resized_batch.shape)\n",
+ "hidden_states = image_classifier.backbone(resized_batch)\n",
+ "print(\"Latent space shape:\", hidden_states.shape)"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "Our result is 78% accuracy without training anything. Not bad!"
+ "Our raw image is resized to `(224, 224)` during preprocessing and finally\n",
+ "downscaled to a `(7, 7)` image of 2048 feature vectors \u2014 the latent space of the\n",
+ "ResNet model. Note that ResNet can actually handle images of arbitrary sizes,\n",
+ "though performance will eventually fall off if your images are sized very\n",
+ "differently from the pretraining data. If you'd like to disable the resizing in the\n",
+ "preprocessing layer, you can run `image_classifier.preprocessor.image_size = None`.\n",
+ "\n",
+ "If you are ever wondering about the exact structure of the task you loaded, you can\n",
+ "use `model.summary()` the same as with any Keras model. The model summary for tasks will\n",
+ "include extra information on model preprocessing."
] }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [
+ "image_classifier.summary()"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "## Fine tuning a pretrained BERT backbone\n",
- "\n",
- "\"drawing\"\n",
- "\n",
- "When labeled text specific to our task is available, fine-tuning a custom classifier can\n",
- "improve performance. If we want to predict IMDB review sentiment, using IMDB data should\n",
- "perform better than Rotten Tomatoes data! And for many tasks, no relevant pretrained model\n",
- "will be available (e.g., categorizing customer reviews).\n",
- "\n",
- "The workflow for fine-tuning is almost identical to above, except that we request a\n",
- "**preset** for the **backbone**-only model rather than the entire classifier. When passed\n",
- "a **backbone** **preset**, a **task** `Model` will randomly initialize all task-specific\n",
- "layers in preparation for training. For all the **backbone** presets available for\n",
- "`BertClassifier`, see our keras.io [models page](https://keras.io/api/keras_hub/models/).\n",
+ "## Generate text with an LLM\n",
 "\n",
- "To train your classifier, use `keras.Model.fit` as with any other\n",
- "`keras.Model`. As with our inference example, we can rely on the compilation\n",
- "defaults for the **task** and skip `keras.Model.compile`. As preprocessing is\n",
- "included, we again pass the raw data."
+ "![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_intermediate.png)"
] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [
+ "Next up, let's try working with and generating text. The task we can use when generating\n",
+ "text is `keras_hub.models.CausalLM` (where LM is short for **L**anguage **M**odel). Let's\n",
+ "download the 2 billion parameter Gemma 2 model and try it out.\n",
+ "\n",
+ "Since this is about a 100x larger model than the ResNet model we just downloaded, we need to be\n",
+ "a little more careful about our GPU memory usage. We can use a half-precision type to\n",
+ "load each of our ~2.5 billion parameters as a two-byte float instead of four.
To do this\n",
+ "we can pass `dtype` to the `from_preset()` constructor. `from_preset()` will forward any\n",
+ "kwargs to the main constructor for the class, so you can pass kwargs that work on all\n",
+ "Keras layers like `dtype`, `trainable`, and `name`."
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [
- "classifier = keras_hub.models.BertClassifier.from_preset(\n",
- "    \"bert_tiny_en_uncased\",\n",
- "    num_classes=2,\n",
- ")\n",
- "classifier.fit(\n",
- "    imdb_train,\n",
- "    validation_data=imdb_test,\n",
- "    epochs=1,\n",
+ "causal_lm = keras_hub.models.CausalLM.from_preset(\n",
+ "    \"gemma2_instruct_2b_en\",\n",
+ "    dtype=\"bfloat16\",\n",
 ")"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "Here we see a significant lift in validation accuracy (0.78 -> 0.87) with a single epoch of\n",
- "training even though the IMDB dataset is much smaller than `sst2`."
+ "The model we just loaded was an instruction-tuned version of Gemma, which means the model\n",
+ "was further fine-tuned for chat. We can take advantage of these capabilities as long as\n",
+ "we stick to the particular template for text used when training the model. These special\n",
+ "tokens vary per model and can be hard to track; the [Kaggle model\n",
+ "page](https://www.kaggle.com/models/keras/gemma2/) will contain details such as this.\n",
+ "\n",
+ "`CausalLM` tasks come with an extra function called `generate()`, which can be used to\n",
+ "generate tokens in a loop and decode them as a string."
] }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [
+ "template = \"<start_of_turn>user\\n{question}<end_of_turn>\\n<start_of_turn>model\"\n",
+ "\n",
+ "question = \"\"\"Write a python program to generate the first 1000 prime numbers.\n",
+ "Just show the actual code.\"\"\"\n",
+ "print(causal_lm.generate(template.format(question=question), max_length=512))"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "## Fine tuning with user-controlled preprocessing\n",
- "\"drawing\"\n",
- "\n",
- "For some advanced training scenarios, users might prefer direct control over\n",
- "preprocessing. For large datasets, examples can be preprocessed in advance and saved to\n",
- "disk or preprocessed by a separate worker pool using `tf.data.experimental.service`. In\n",
- "other cases, custom preprocessing is needed to handle the inputs.\n",
- "\n",
- "Pass `preprocessor=None` to the constructor of a **task** `Model` to skip automatic\n",
- "preprocessing or pass a custom `BertPreprocessor` instead."
+ "Note that on the Jax and TensorFlow backends, this `generate()` function is compiled, so\n",
+ "the second time you call it with the same `max_length`, it will actually be much faster.\n",
+ "KerasHub will use Jax or TensorFlow to compute an optimized version of the generation\n",
+ "computational graph that can be reused."
] }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [
+ "question = \"Share a very simple brownie recipe.\"\n",
+ "print(causal_lm.generate(template.format(question=question), max_length=512))"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "### Separate preprocessing from the same preset\n",
- "\n",
- "Each model architecture has a parallel **preprocessor** `Layer` with its own\n",
- "`from_preset` constructor.
Using the same **preset** for this `Layer` will return the\n", - "matching **preprocessor** as the **task**.\n", - "\n", - "In this workflow we train the model over three epochs using `tf.data.Dataset.cache()`,\n", - "which computes the preprocessing once and caches the result before fitting begins.\n", - "\n", - "**Note:** we can use `tf.data` for preprocessing while running on the\n", - "Jax or PyTorch backend. The input dataset will automatically be converted to\n", - "backend native tensor types during fit. In fact, given the efficiency of `tf.data`\n", - "for running preprocessing, this is good practice on all backends." + "As with our image classifier, we can use model summary to see the details of our task\n", + "setup, including preprocessing." ] }, { @@ -371,33 +485,7 @@ }, "outputs": [], "source": [ - "import tensorflow as tf\n", - "\n", - "preprocessor = keras_hub.models.BertPreprocessor.from_preset(\n", - " \"bert_tiny_en_uncased\",\n", - " sequence_length=512,\n", - ")\n", - "\n", - "# Apply the preprocessor to every sample of train and test data using `map()`.\n", - "# `tf.data.AUTOTUNE` and `prefetch()` are options to tune performance, see\n", - "# https://www.tensorflow.org/guide/data_performance for details.\n", - "\n", - "# Note: only call `cache()` if you training data fits in CPU memory!\n", - "imdb_train_cached = (\n", - " imdb_train.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)\n", - ")\n", - "imdb_test_cached = (\n", - " imdb_test.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)\n", - ")\n", - "\n", - "classifier = keras_hub.models.BertClassifier.from_preset(\n", - " \"bert_tiny_en_uncased\", preprocessor=None, num_classes=2\n", - ")\n", - "classifier.fit(\n", - " imdb_train_cached,\n", - " validation_data=imdb_test_cached,\n", - " epochs=3,\n", - ")" + "causal_lm.summary()" ] }, { @@ -406,10 +494,26 @@ "colab_type": "text" }, "source": [ - "After three epochs, our validation accuracy has only increased to 0.88. This is both a\n", - "function of the small size of our dataset and our model. To exceed 90% accuracy, try\n", - "larger **presets** such as `\"bert_base_en_uncased\"`. For all the **backbone** presets\n", - "available for `BertClassifier`, see our keras.io [models page](https://keras.io/api/keras_hub/models/)." + "Our text preprocessing includes a tokenizer, which is how all KerasHub models handle\n", + "input text. Let's try using it directly to get a better sense of how it works. All\n", + "tokenizers include `tokenize()` and `detokenize()` methods, to map strings to integer\n", + "sequences and integer sequences to strings. Directly calling the layer with\n", + "`tokenizer(inputs)` is equivalent to calling `tokenizer.tokenize(inputs)`." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "tokenizer = causal_lm.preprocessor.tokenizer\n", + "tokens_ids = tokenizer.tokenize(\"The quick brown fox jumps over the lazy dog.\")\n", + "print(tokens_ids)\n", + "string = tokenizer.detokenize(tokens_ids)\n", + "print(string)" ] }, { @@ -418,16 +522,15 @@ "colab_type": "text" }, "source": [ - "### Custom preprocessing\n", - "\n", - "In cases where custom preprocessing is required, we offer direct access to the\n", - "`Tokenizer` class that maps raw strings to tokens. 
It also has a `from_preset()`\n", - "constructor to get the vocabulary matching pretraining.\n", - "\n", - "**Note:** `BertTokenizer` does not pad sequences by default, so the output is\n", - "ragged (each sequence has varying length). The `MultiSegmentPacker` below\n", - "handles padding these ragged sequences to dense tensor types (e.g. `tf.Tensor`\n", - "or `torch.Tensor`)." + "The `generate()` function for `CausalLM` models involved a sampling step. The Gemma model\n", + "will be called once for each token we want to generate, and return a probability\n", + "distribution over all tokens. This distribution is then sampled to choose the next token\n", + "in the sequence.\n", + "\n", + "For Gemma models, we default to greedy sampling, meaning we simply pick the most likely\n", + "output from the model at each step. But we can actually control this process with an\n", + "extra `sampler` argument to the standard `compile` function on all Keras models. Let's\n", + "try it out." ] }, { @@ -438,41 +541,12 @@ }, "outputs": [], "source": [ - "tokenizer = keras_hub.models.BertTokenizer.from_preset(\"bert_tiny_en_uncased\")\n", - "tokenizer([\"I love modular workflows!\", \"Libraries over frameworks!\"])\n", - "\n", - "# Write your own packer or use one of our `Layers`\n", - "packer = keras_hub.layers.MultiSegmentPacker(\n", - " start_value=tokenizer.cls_token_id,\n", - " end_value=tokenizer.sep_token_id,\n", - " # Note: This cannot be longer than the preset's `sequence_length`, and there\n", - " # is no check for a custom preprocessor!\n", - " sequence_length=64,\n", - ")\n", - "\n", - "\n", - "# This function that takes a text sample `x` and its\n", - "# corresponding label `y` as input and converts the\n", - "# text into a format suitable for input into a BERT model.\n", - "def preprocessor(x, y):\n", - " token_ids, segment_ids = packer(tokenizer(x))\n", - " x = {\n", - " \"token_ids\": token_ids,\n", - " \"segment_ids\": segment_ids,\n", - " \"padding_mask\": token_ids != 0,\n", - " }\n", - " return x, y\n", - "\n", - "\n", - "imdb_train_preprocessed = imdb_train.map(preprocessor, tf.data.AUTOTUNE).prefetch(\n", - " tf.data.AUTOTUNE\n", - ")\n", - "imdb_test_preprocessed = imdb_test.map(preprocessor, tf.data.AUTOTUNE).prefetch(\n", - " tf.data.AUTOTUNE\n", + "causal_lm.compile(\n", + " sampler=keras_hub.samplers.TopKSampler(k=10, temperature=2.0),\n", ")\n", "\n", - "# Preprocessed example\n", - "print(imdb_train_preprocessed.unbatch().take(1).get_single_element())" + "question = \"Share a very simple brownie recipe.\"\n", + "print(causal_lm.generate(template.format(question=question), max_length=512))" ] }, { @@ -481,22 +555,51 @@ "colab_type": "text" }, "source": [ - "## Fine tuning with a custom model\n", - "\"drawing\"\n", + "Here we used a Top-K sampler, meaning we will randomly sample the partial distribution formed\n", + "by looking at just the top 10 predicted tokens at each time step. We also pass a `temperature` of 2,\n", + "which flattens our predicted distribution before we sample.\n", "\n", - "For more advanced applications, an appropriate **task** `Model` may not be available. In\n", - "this case, we provide direct access to the **backbone** `Model`, which has its own\n", - "`from_preset` constructor and can be composed with custom `Layer`s. Detailed examples can\n", - "be found at our [transfer learning guide](https://keras.io/guides/transfer_learning/).\n", + "The net effect is that we will explore our model's distribution much more broadly each\n", + "time we generate output. 
Generation will now be a random process, each time we re-run\n", + "generate we will get a different result. We can note that the results feel \"looser\" than\n", + "greedy search \u2014 more minor mistakes and a less consistent tone.\n", "\n", - "A **backbone** `Model` does not include automatic preprocessing but can be paired with a\n", - "matching **preprocessor** using the same **preset** as shown in the previous workflow.\n", + "You can look at all the samplers Keras supports at [keras_hub.samplers](https://keras.io/api/keras_hub/samplers/).\n", "\n", - "In this workflow, we experiment with freezing our backbone model and adding two trainable\n", - "transformer layers to adapt to the new input.\n", + "Let's free up the memory from our large Gemma model before we jump to the next section." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "del causal_lm" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Fine-tune and publish an image classifier\n", "\n", - "**Note**: We can ignore the warning about gradients for the `pooled_dense` layer because\n", - "we are using BERT's sequence output." + "![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_advanced.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "Now that we've tried running inference for both images and text, let's try running\n", + "training. We will take our ResNet image classifier from earlier and fine-tune it on\n", + "simple cats vs dogs dataset. We can start by downloading and extracting the data." ] }, { @@ -507,40 +610,79 @@ }, "outputs": [], "source": [ - "preprocessor = keras_hub.models.BertPreprocessor.from_preset(\"bert_tiny_en_uncased\")\n", - "backbone = keras_hub.models.BertBackbone.from_preset(\"bert_tiny_en_uncased\")\n", + "import pathlib\n", "\n", - "imdb_train_preprocessed = (\n", - " imdb_train.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)\n", - ")\n", - "imdb_test_preprocessed = (\n", - " imdb_test.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE)\n", + "extract_dir = keras.utils.get_file(\n", + " \"cats_vs_dogs\",\n", + " \"https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip\",\n", + " extract=True,\n", ")\n", + "data_dir = pathlib.Path(extract_dir) / \"PetImages\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "When working with lots of real-world image data, corrupted images are a common occurrence.\n", + "Let's filter out badly-encoded images that do not feature the string \"JFIF\" in their\n", + "header." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "num_skipped = 0\n", "\n", - "backbone.trainable = False\n", - "inputs = backbone.input\n", - "sequence = backbone(inputs)[\"sequence_output\"]\n", - "for _ in range(2):\n", - " sequence = keras_hub.layers.TransformerEncoder(\n", - " num_heads=2,\n", - " intermediate_dim=512,\n", - " dropout=0.1,\n", - " )(sequence)\n", - "# Use [CLS] token output to classify\n", - "outputs = keras.layers.Dense(2)(sequence[:, backbone.cls_token_index, :])\n", - "\n", - "model = keras.Model(inputs, outputs)\n", - "model.compile(\n", - " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", - " optimizer=keras.optimizers.AdamW(5e-5),\n", - " metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", - " jit_compile=True,\n", - ")\n", - "model.summary()\n", - "model.fit(\n", - " imdb_train_preprocessed,\n", - " validation_data=imdb_test_preprocessed,\n", - " epochs=3,\n", + "for path in data_dir.rglob(\"*.jpg\"):\n", + " with open(path, \"rb\") as file:\n", + " is_jfif = b\"JFIF\" in file.peek(10)\n", + " if not is_jfif:\n", + " num_skipped += 1\n", + " os.remove(path)\n", + "\n", + "print(f\"Deleted {num_skipped} images.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "We can load the dataset with `keras.utils.image_dataset_from_directory`. One important\n", + "thing to note here is that the `train_ds` and `val_ds` will both be returned as\n", + "`tf.data.Dataset` objects, including on the `torch` and `jax` backends.\n", + "\n", + "KerasHub will use [tf.data](https://www.tensorflow.org/guide/data) as the default API for\n", + "running multi-threaded preprocessing on the CPU. `tf.data` is a powerful API for training\n", + "input pipelines that can scale up to complex, multi-host training jobs easily. Using it\n", + "does not restrict your choice of backend, a `tf.data.Dataset` can be as an iterator of\n", + "regular numpy data and passed to `fit()` on any Keras backend." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "train_ds, val_ds = keras.utils.image_dataset_from_directory(\n", + " data_dir,\n", + " validation_split=0.2,\n", + " subset=\"both\",\n", + " seed=1337,\n", + " image_size=(256, 256),\n", + " batch_size=32,\n", ")" ] }, @@ -550,9 +692,33 @@ "colab_type": "text" }, "source": [ - "This model achieves reasonable accuracy despite having only 10% of the trainable parameters\n", - "of our `BertClassifier` model. Each training step takes about 1/3 of the time---even\n", - "accounting for cached preprocessing." + "At its simplest, training our classifier could consist of simply calling `fit()` on our\n", + "model with our dataset. But to make this example a little more interesting, let's show\n", + "how to customize preprocessing within a task.\n", + "\n", + "In the first example, we saw how, by default, the preprocessing for our ResNet model resized\n", + "and rescaled our input. This preprocessing can be customized when we create our model. We\n", + "can use Keras' image preprocessing layers to create a `keras.layers.Pipeline` that will\n", + "rescale, randomly flip, and randomly rotate our input images. These random image\n", + "augmentations will allow our smaller dataset to function as a larger, more varied one.\n", + "Let's try it out." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "preprocessor = keras.layers.Pipeline(\n", + " [\n", + " keras.layers.Rescaling(1.0 / 255),\n", + " keras.layers.RandomFlip(\"horizontal\"),\n", + " keras.layers.RandomRotation(0.2),\n", + " ]\n", + ")" ] }, { @@ -561,27 +727,27 @@ "colab_type": "text" }, "source": [ - "## Pretraining a backbone model\n", - "\"drawing\"\n", - "\n", - "Do you have access to large unlabeled datasets in your domain? Are they around the\n", - "same size as used to train popular backbones such as BERT, RoBERTa, or GPT2 (XX+ GiB)? If\n", - "so, you might benefit from domain-specific pretraining of your own backbone models.\n", - "\n", - "NLP models are generally pretrained on a language modeling task, predicting masked words\n", - "given the visible words in an input sentence. For example, given the input\n", - "`\"The fox [MASK] over the [MASK] dog\"`, the model might be asked to predict `[\"jumped\", \"lazy\"]`.\n", - "The lower layers of this model are then packaged as a **backbone** to be combined with\n", - "layers relating to a new task.\n", - "\n", - "The KerasHub library offers SoTA **backbones** and **tokenizers** to be trained from\n", - "scratch without presets.\n", - "\n", - "In this workflow, we pretrain a BERT **backbone** using our IMDB review text. We skip the\n", - "\"next sentence prediction\" (NSP) loss because it adds significant complexity to the data\n", - "processing and was dropped by later models like RoBERTa. See our e2e\n", - "[Transformer pretraining](https://keras.io/guides/keras_hub/transformer_pretraining/#pretraining)\n", - "for step-by-step details on how to replicate the original paper." + "Now that we have created a new layer for preprocessing, we can simply pass it to the\n", + "`ImageClassifier` during the `from_preset()` constructor. We can also pass\n", + "`num_classes=2` to match our two labels for \"cat\" and \"dog.\" When `num_classes` is\n", + "specified like this, our head weights for the model will be randomly initialized\n", + "instead of containing the weights for our 1000 class image classification." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "image_classifier = keras_hub.models.ImageClassifier.from_preset(\n", + " \"resnet_50_imagenet\",\n", + " activation=\"softmax\",\n", + " num_classes=2,\n", + " preprocessor=preprocessor,\n", + ")" ] }, { @@ -590,7 +756,14 @@ "colab_type": "text" }, "source": [ - "### Preprocessing" + "Note that if you want to preprocess your input data outside of Keras, you can simply\n", + "pass `preprocessor=None` to the task `from_preset()` call. In this case, KerasHub will\n", + "apply no preprocessing at all, and you are free to preprocess your data with any library\n", + "or workflow before passing your data to `fit()`.\n", + "\n", + "Next, we can compile our model for fine-tuning. A KerasHub task is just a regular\n", + "`keras.Model` with some extra functionality, so we can `compile()` as normal for a\n", + "classification task." 
] }, { @@ -601,52 +774,93 @@ }, "outputs": [], "source": [ - "# All BERT `en` models have the same vocabulary, so reuse preprocessor from\n", - "# \"bert_tiny_en_uncased\"\n", - "preprocessor = keras_hub.models.BertPreprocessor.from_preset(\n", - " \"bert_tiny_en_uncased\",\n", - " sequence_length=256,\n", - ")\n", - "packer = preprocessor.packer\n", - "tokenizer = preprocessor.tokenizer\n", - "\n", - "# keras.Layer to replace some input tokens with the \"[MASK]\" token\n", - "masker = keras_hub.layers.MaskedLMMaskGenerator(\n", - " vocabulary_size=tokenizer.vocabulary_size(),\n", - " mask_selection_rate=0.25,\n", - " mask_selection_length=64,\n", - " mask_token_id=tokenizer.token_to_id(\"[MASK]\"),\n", - " unselectable_token_ids=[\n", - " tokenizer.token_to_id(x) for x in [\"[CLS]\", \"[PAD]\", \"[SEP]\"]\n", - " ],\n", - ")\n", - "\n", - "\n", - "def preprocess(inputs, label):\n", - " inputs = preprocessor(inputs)\n", - " masked_inputs = masker(inputs[\"token_ids\"])\n", - " # Split the masking layer outputs into a (features, labels, and weights)\n", - " # tuple that we can use with keras.Model.fit().\n", - " features = {\n", - " \"token_ids\": masked_inputs[\"token_ids\"],\n", - " \"segment_ids\": inputs[\"segment_ids\"],\n", - " \"padding_mask\": inputs[\"padding_mask\"],\n", - " \"mask_positions\": masked_inputs[\"mask_positions\"],\n", - " }\n", - " labels = masked_inputs[\"mask_ids\"]\n", - " weights = masked_inputs[\"mask_weights\"]\n", - " return features, labels, weights\n", + "image_classifier.compile(\n", + " optimizer=keras.optimizers.Adam(1e-4),\n", + " loss=\"sparse_categorical_crossentropy\",\n", + " metrics=[\"accuracy\"],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "With that, we can simply run `fit()`. The image classifier will automatically apply our\n", + "preprocessing to each batch when training the model." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "image_classifier.fit(\n", + " train_ds,\n", + " validation_data=val_ds,\n", + " epochs=3,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "After three epochs of data, we achieve 99% accuracy on our cats vs dogs\n", + "validation dataset. This is unsurprising, given that the ImageNet pretrained weights we began\n", + "with could already classify some breeds of cats and dogs individually.\n", "\n", + "Now that we have a fine-tuned model let's try saving it. You can create a new saved preset with a\n", + "fine-tuned model for any task simply by running `task.save_to_preset()`." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "image_classifier.save_to_preset(\"cats_vs_dogs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "One of the most powerful features of KerasHub is the ability upload models to Kaggle or\n", + "Huggingface models hub and share them with others. `keras_hub.upload_preset` allows you\n", + "to upload a saved preset.\n", "\n", - "pretrain_ds = imdb_train.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE).prefetch(\n", - " tf.data.AUTOTUNE\n", - ")\n", - "pretrain_val_ds = imdb_test.map(\n", - " preprocess, num_parallel_calls=tf.data.AUTOTUNE\n", - ").prefetch(tf.data.AUTOTUNE)\n", + "In this case, we will upload to Kaggle. 
We have already authenticated with Kaggle to,\n", + "download the Gemma model earlier. Running the following cell well upload a new model\n", + "to Kaggle." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "from google.colab import userdata\n", "\n", - "# Tokens with ID 103 are \"masked\"\n", - "print(pretrain_ds.unbatch().take(1).get_single_element())" + "username = userdata.get(\"KAGGLE_USERNAME\")\n", + "keras_hub.upload_preset(\n", + " f\"kaggle://{username}/resnet/keras/cats_vs_dogs\",\n", + " \"cats_vs_dogs\",\n", + ")" ] }, { @@ -655,7 +869,7 @@ "colab_type": "text" }, "source": [ - "### Pretraining model" + "Let's take a look at a test image from our dataset." ] }, { @@ -666,51 +880,58 @@ }, "outputs": [], "source": [ - "# BERT backbone\n", - "backbone = keras_hub.models.BertBackbone(\n", - " vocabulary_size=tokenizer.vocabulary_size(),\n", - " num_layers=2,\n", - " num_heads=2,\n", - " hidden_dim=128,\n", - " intermediate_dim=512,\n", - ")\n", - "\n", - "# Language modeling head\n", - "mlm_head = keras_hub.layers.MaskedLMHead(\n", - " token_embedding=backbone.token_embedding,\n", - ")\n", - "\n", - "inputs = {\n", - " \"token_ids\": keras.Input(shape=(None,), dtype=tf.int32, name=\"token_ids\"),\n", - " \"segment_ids\": keras.Input(shape=(None,), dtype=tf.int32, name=\"segment_ids\"),\n", - " \"padding_mask\": keras.Input(shape=(None,), dtype=tf.int32, name=\"padding_mask\"),\n", - " \"mask_positions\": keras.Input(shape=(None,), dtype=tf.int32, name=\"mask_positions\"),\n", - "}\n", - "\n", - "# Encoded token sequence\n", - "sequence = backbone(inputs)[\"sequence_output\"]\n", - "\n", - "# Predict an output word for each masked input token.\n", - "# We use the input token embedding to project from our encoded vectors to\n", - "# vocabulary logits, which has been shown to improve training efficiency.\n", - "outputs = mlm_head(sequence, mask_positions=inputs[\"mask_positions\"])\n", - "\n", - "# Define and compile our pretraining model.\n", - "pretraining_model = keras.Model(inputs, outputs)\n", - "pretraining_model.summary()\n", - "pretraining_model.compile(\n", - " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", - " optimizer=keras.optimizers.AdamW(learning_rate=5e-4),\n", - " weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", - " jit_compile=True,\n", + "image = keras.utils.load_img(data_dir / \"Cat\" / \"6779.jpg\")\n", + "plt.imshow(image)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "If we wait for a few minutes for our model upload to finish processing on the Kaggle\n", + "side, we can go ahead and download the model we just created and use it to classify this\n", + "test image." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "image_classifier = keras_hub.models.ImageClassifier.from_preset(\n", + " f\"kaggle://{username}/resnet/keras/cats_vs_dogs\",\n", ")\n", + "print(image_classifier.predict(np.array([image])))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "Congratulations on uploading your first model with KerasHub! 
If you want to share your\n",
+ "work with others, you can go to the model link printed out when we uploaded the model, and\n",
+ "make the model public in settings.\n",
+ "\n",
+ "Let's delete this model to free up memory before we move on to our final example for this\n",
+ "guide."
] }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [
+ "del image_classifier"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "After pretraining save your `backbone` submodel to use in a new task!"
+ "## Building a custom text classifier\n",
+ "\n",
+ "![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_expert.png)"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "## Build and train your own transformer from scratch\n",
- "\"drawing\"\n",
- "\n",
- "Want to implement a novel transformer architecture? The KerasHub library offers all the\n",
- "low-level modules used to build SoTA architectures in our `models` API. This includes the\n",
- "`keras_hub.tokenizers` API which allows you to train your own subword tokenizer using\n",
- "`WordPieceTokenizer`, `BytePairTokenizer`, or `SentencePieceTokenizer`.\n",
- "\n",
- "In this workflow, we train a custom tokenizer on the IMDB data and design a backbone with\n",
- "custom transformer architecture. For simplicity, we then train directly on the\n",
- "classification task. Interested in more details? We wrote an entire guide to pretraining\n",
- "and finetuning a custom transformer on\n",
- "[keras.io](https://keras.io/guides/keras_hub/transformer_pretraining/),"
+ "As a final example for this getting started guide, let's take a look at how we can build\n",
+ "custom models from lower-level Keras and KerasHub components. We will build a text\n",
+ "classifier to classify movie reviews in the IMDb dataset as either positive or negative.\n",
+ "\n",
+ "Let's download the dataset."
] }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [
+ "extract_dir = keras.utils.get_file(\n",
+ "    \"imdb_reviews\",\n",
+ "    origin=\"https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\",\n",
+ "    extract=True,\n",
+ ")\n",
+ "data_dir = pathlib.Path(extract_dir) / \"aclImdb\""
] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [
+ "The IMDb dataset contains a large number of unlabeled movie reviews. We don't need those\n",
+ "here, so we can simply delete them."
] }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [
+ "import shutil\n",
+ "\n",
+ "shutil.rmtree(data_dir / \"train\" / \"unsup\")"
] }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [
+ "Next up, we can load our data with `keras.utils.text_dataset_from_directory`. As with our\n",
+ "image dataset creation above, the returned datasets will be `tf.data.Dataset` objects."
+ ] }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [
+ "raw_train_ds = keras.utils.text_dataset_from_directory(\n",
+ "    data_dir / \"train\",\n",
+ "    batch_size=2,\n",
+ ")\n",
+ "raw_val_ds = keras.utils.text_dataset_from_directory(\n",
+ "    data_dir / \"test\",\n",
+ "    batch_size=2,\n",
+ ")"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "### Train custom vocabulary from IMDB data"
+ "KerasHub is designed to be a layered API. At the top-most level, tasks aim to make it\n",
+ "easy to quickly tackle a problem. We could keep using the task API here, and create a\n",
+ "`keras_hub.models.TextClassifier` for a text classification model like BERT, and fine-tune\n",
+ "it in 10 or so lines of code.\n",
+ "\n",
+ "Instead, to make our final example a little more interesting, let's show how we can use\n",
+ "lower-level API components to do something that isn't directly baked into the library.\n",
+ "We will take the Gemma 2 model we used earlier, which is usually used for generating text,\n",
+ "and modify it to output classification predictions.\n",
+ "\n",
+ "A common approach for classifying with a generative model would be to keep using it in a generative\n",
+ "context, by prompting it with the review and a question (`\"Is this review positive or negative?\"`).\n",
+ "But building an actual classifier is more useful if you want a probability score associated\n",
+ "with your labels.\n",
+ "\n",
+ "Instead of loading the Gemma 2 model through the `CausalLM` task, we can load two\n",
+ "lower-level components: a **backbone** and a **tokenizer**. Much like the task classes we have\n",
+ "used so far, `keras_hub.models.Backbone` and `keras_hub.tokenizers.Tokenizer` both have a\n",
+ "`from_preset()` constructor for loading pretrained models. If you are running this code,\n",
+ "you will notice that you don't have to wait for a download when we use the model a second time;\n",
+ "the weight files were cached locally the first time we used the model."
] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab_type": "code" }, "outputs": [], "source": [
- "vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(\n",
- "    imdb_train.map(lambda x, y: x),\n",
- "    vocabulary_size=20_000,\n",
- "    lowercase=True,\n",
- "    strip_accents=True,\n",
- "    reserved_tokens=[\"[PAD]\", \"[START]\", \"[END]\", \"[MASK]\", \"[UNK]\"],\n",
+ "tokenizer = keras_hub.tokenizers.Tokenizer.from_preset(\n",
+ "    \"gemma2_instruct_2b_en\",\n",
 ")\n",
- "tokenizer = keras_hub.tokenizers.WordPieceTokenizer(\n",
- "    vocabulary=vocab,\n",
- "    lowercase=True,\n",
- "    strip_accents=True,\n",
- "    oov_token=\"[UNK]\",\n",
+ "backbone = keras_hub.models.Backbone.from_preset(\n",
+ "    \"gemma2_instruct_2b_en\",\n",
 ")"
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [
- "### Preprocess data with a custom tokenizer"
+ "We saw what the tokenizer does in the second example of this guide. We can use it to map\n",
+ "from string inputs to token ids in a way that matches the pretrained weights of the Gemma\n",
+ "model.\n",
+ "\n",
+ "The backbone will map from a sequence of token ids to a sequence of embedded tokens in\n",
+ "the latent space of the model. We can use this rich representation to build a classifier.\n",
+ "\n",
+ "Let's start by defining a custom preprocessing routine. `keras_hub.layers` contains a\n",
+ "collection of modeling and preprocessing layers, including some layers for token\n",
+ "preprocessing.
We can use `keras_hub.layers.StartEndPacker`, which will append a special\n", + "start token to the beginning of each review, a special end token to the end, and finally\n", + "truncate or pad each review to a fixed length.\n", + "\n", + "If we combine this with our `tokenizer`, we can build a preprocessing function that will\n", + "output batches of token ids with shape `(batch_size, sequence_length)`. We should also\n", + "output a padding mask that marks which tokens are padding tokens, so we can later exclude\n", + "these positions from our Transformer's attention computation. Most Transformer backbones\n", + "in KerasNLP take in a `\"padding_mask\"` input." ] }, { @@ -793,26 +1105,109 @@ "outputs": [], "source": [ "packer = keras_hub.layers.StartEndPacker(\n", - " start_value=tokenizer.token_to_id(\"[START]\"),\n", - " end_value=tokenizer.token_to_id(\"[END]\"),\n", - " pad_value=tokenizer.token_to_id(\"[PAD]\"),\n", - " sequence_length=512,\n", + " start_value=tokenizer.start_token_id,\n", + " end_value=tokenizer.end_token_id,\n", + " pad_value=tokenizer.pad_token_id,\n", + " sequence_length=None,\n", ")\n", "\n", "\n", - "def preprocess(x, y):\n", - " token_ids = packer(tokenizer(x))\n", - " return token_ids, y\n", - "\n", + "def preprocess(x, y=None, sequence_length=256):\n", + " x = tokenizer(x)\n", + " x = packer(x, sequence_length=sequence_length)\n", + " x = {\n", + " \"token_ids\": x,\n", + " \"padding_mask\": x != tokenizer.pad_token_id,\n", + " }\n", + " return keras.utils.pack_x_y_sample_weight(x, y)\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "With our preprocessing defined, we can simply use `tf.data.Dataset.map` to apply our\n", + "preprocessing to our input data." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "train_ds = raw_train_ds.map(preprocess, num_parallel_calls=16)\n", + "val_ds = raw_val_ds.map(preprocess, num_parallel_calls=16)\n", + "next(iter(train_ds))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "Running fine-tuning on a 2.5 billion parameter model is quite expensive compared to the\n", + "image classifier we trained earlier, for the simple reason that this model is 100x the\n", + "size of ResNet! To speed things up a bit, let's reduce the size of our training data to a\n", + "tenth of the original size. Of course, this is leaving some performance on the table\n", + "compared to full training, but it will keep things running quickly for our guide." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "train_ds = train_ds.take(1000)\n", + "val_ds = val_ds.take(1000)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "Next, we need to attach a classification head to our backbone model. In general, text\n", + "transformer backbones will output a tensor with shape\n", + "`(batch_size, sequence_length, hidden_dim)`. The main thing we will need to\n", + "classify with this input is to pool on the sequence dimension so we have a single\n", + "feature vector per input example.\n", + "\n", + "Since the Gemma model is a generative model, information only passed from left to right\n", + "in the sequence. The only token representation that can \"see\" the entire movie review\n", + "input is the final token in each review. 
We can write a simple pooling layer to do this \u2014\n", + "we will simply grab the last non-padding position of each input sequence. There's no special\n", + "process to writing a layer like this, we can use Keras and `keras.ops` normally." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "from keras import ops\n", "\n", - "imdb_preproc_train_ds = imdb_train.map(\n", - " preprocess, num_parallel_calls=tf.data.AUTOTUNE\n", - ").prefetch(tf.data.AUTOTUNE)\n", - "imdb_preproc_val_ds = imdb_test.map(\n", - " preprocess, num_parallel_calls=tf.data.AUTOTUNE\n", - ").prefetch(tf.data.AUTOTUNE)\n", "\n", - "print(imdb_preproc_train_ds.unbatch().take(1).get_single_element())" + "class LastTokenPooler(keras.layers.Layer):\n", + " def call(self, inputs, padding_mask):\n", + " end_positions = ops.sum(padding_mask, axis=1, keepdims=True) - 1\n", + " end_positions = ops.cast(end_positions, \"int\")[:, :, None]\n", + " outputs = ops.take_along_axis(inputs, end_positions, axis=1)\n", + " return ops.squeeze(outputs, axis=1)\n", + "" ] }, { @@ -821,7 +1216,11 @@ "colab_type": "text" }, "source": [ - "### Design a tiny transformer" + "With this pooling layer, we are ready to write our Gemma classifier. All task and backbone\n", + "models in KerasHub are [functional](https://keras.io/guides/functional_api/) models, so\n", + "we can easily manipulate the model structure. We will call our backbone on our inputs, add\n", + "our new pooling layer, and finally add a small feedforward network with a `\"relu\"` activation\n", + "in the middle. Let's try it out." ] }, { @@ -832,29 +1231,58 @@ }, "outputs": [], "source": [ - "token_id_input = keras.Input(\n", - " shape=(None,),\n", - " dtype=\"int32\",\n", - " name=\"token_ids\",\n", - ")\n", - "outputs = keras_hub.layers.TokenAndPositionEmbedding(\n", - " vocabulary_size=len(vocab),\n", - " sequence_length=packer.sequence_length,\n", - " embedding_dim=64,\n", - ")(token_id_input)\n", - "outputs = keras_hub.layers.TransformerEncoder(\n", - " num_heads=2,\n", - " intermediate_dim=128,\n", - " dropout=0.1,\n", - ")(outputs)\n", - "# Use \"[START]\" token to classify\n", - "outputs = keras.layers.Dense(2)(outputs[:, 0, :])\n", - "model = keras.Model(\n", - " inputs=token_id_input,\n", - " outputs=outputs,\n", - ")\n", - "\n", - "model.summary()" + "inputs = backbone.input\n", + "x = backbone(inputs)\n", + "x = LastTokenPooler(\n", + " name=\"pooler\",\n", + ")(x, inputs[\"padding_mask\"])\n", + "x = keras.layers.Dense(\n", + " 2048,\n", + " activation=\"relu\",\n", + " name=\"pooled_dense\",\n", + ")(x)\n", + "x = keras.layers.Dropout(\n", + " 0.1,\n", + " name=\"output_dropout\",\n", + ")(x)\n", + "outputs = keras.layers.Dense(\n", + " 2,\n", + " activation=\"softmax\",\n", + " name=\"output_dense\",\n", + ")(x)\n", + "text_classifier = keras.Model(inputs, outputs)\n", + "text_classifier.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "Before we train, there is one last trick we should employ to make this code run on free\n", + "tier colab GPUs. We can see from our model summary our model takes up almost 10 gigabytes\n", + "of space. An optimizer will need to make multiple copies of each parameter during\n", + "training, taking the total space of our model during training close to 30 or 40\n", + "gigabytes.\n", + "\n", + "This would OOM many GPUs. A useful trick we can employ is to enable LoRA on our\n", + "backbone. 
LoRA is an approach which freezes the entire model, and only trains a low-parameter\n", + "decomposition of large weight matrices. You can read more about LoRA in this [Keras\n", + "example](https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/).\n", + "Let's try enabling it and re-printing our summary." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "backbone.enable_lora(4)\n", + "text_classifier.summary()" ] }, { @@ -863,7 +1291,10 @@ "colab_type": "text" }, "source": [ - "### Train the transformer directly on the classification objective" + "After enabling LoRA, our model goes from 10GB of traininable parameters to just 20MB.\n", + "That means the space used by optimizer variables will no longer be a concern.\n", + "\n", + "With all that set up, we can compile and train our model as normal." ] }, { @@ -874,16 +1305,14 @@ }, "outputs": [], "source": [ - "model.compile(\n", - " loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", - " optimizer=keras.optimizers.AdamW(5e-5),\n", - " metrics=[keras.metrics.SparseCategoricalAccuracy()],\n", - " jit_compile=True,\n", + "text_classifier.compile(\n", + " optimizer=keras.optimizers.Adam(5e-5),\n", + " loss=\"sparse_categorical_crossentropy\",\n", + " metrics=[\"accuracy\"],\n", ")\n", - "model.fit(\n", - " imdb_preproc_train_ds,\n", - " validation_data=imdb_preproc_val_ds,\n", - " epochs=3,\n", + "text_classifier.fit(\n", + " train_ds,\n", + " validation_data=val_ds,\n", ")" ] }, @@ -893,9 +1322,42 @@ "colab_type": "text" }, "source": [ - "Excitingly, our custom classifier is similar to the performance of fine-tuning\n", - "`\"bert_tiny_en_uncased\"`! To see the advantages of pretraining and exceed 90% accuracy we\n", - "would need to use larger **presets** such as `\"bert_base_en_uncased\"`." + "We are able to achieve over ~93% accuracy on the movie review sentiment\n", + "classification problem. This is not bad, given that we only used a 10th of our\n", + "original dataset to train.\n", + "\n", + "Taken together, the `backbone` and `tokenizer` we created in this example\n", + "allowed us access the full power of pretrained Gemma checkpoints, without\n", + "restricting what we could do with them. This is a central aim of the KerasHub\n", + "API. Simple workflows should be easy, and as you go deeper, you gain access to a\n", + "deeply customizable set of building blocks." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Going further\n", + "\n", + "This is just scratching the surface of what you can do with the KerasHub.\n", + "\n", + "This guide shows a few of the high-level tasks that we ship with the KerasHub library,\n", + "but there are many tasks we did not cover here. Try [generating images with Stable\n", + "Diffusion](https://keras.io/guides/keras_hub/stable_diffusion_3_in_keras_hub/), for\n", + "example.\n", + "\n", + "The most significant advantage of KerasHub is it gives you the flexibility to combine pre-trained\n", + "building blocks with the full power of Keras 3. You can train large LLMs on TPUs with model\n", + "parallelism with the [keras.distribution](https://keras.io/guides/distribution/) API. You can\n", + "quantize models with Keras' [quatize\n", + "method](https://keras.io/examples/keras_recipes/float8_training_and_inference_with_transfo\n", + "rmer/). 
You can write custom training loops and even mix in direct Jax, Torch, or\n", + "Tensorflow calls.\n", + "\n", + "See [keras.io/keras_hub](https://keras.io/keras_hub/) for a full list of guides and\n", + "examples to continue digging into the library." ] } ], diff --git a/guides/keras_hub/getting_started.py b/guides/keras_hub/getting_started.py index 6aaaeb0510..017ffd6324 100644 --- a/guides/keras_hub/getting_started.py +++ b/guides/keras_hub/getting_started.py @@ -1,633 +1,788 @@ """ Title: Getting Started with KerasHub -Author: [Jonathan Bischof](https://github.com/jbischof) +Author: [Matthew Watson](https://github.com/mattdangerw/), [Jonathan Bischof](https://github.com/jbischof) Date created: 2022/12/15 -Last modified: 2023/07/01 +Last modified: 2024/10/17 Description: An introduction to the KerasHub API. Accelerator: GPU """ """ -## Introduction +**KerasHub** is a pretrained modeling library that aims to be simple, flexible, and fast. +The library provides [Keras 3](https://keras.io/keras_3/) implementations of popular +model architectures, paired with a collection of pretrained checkpoints available on +[Kaggle](https://www.kaggle.com/organizations/keras/models). Models can be used for both +training and inference on any of the TensorFlow, Jax, and Torch backends. -KerasHub is a natural language processing library that supports users through -their entire development cycle. Our workflows are built from modular components -that have state-of-the-art preset weights and architectures when used -out-of-the-box and are easily customizable when more control is needed. +KerasHub is an extension of the core Keras API; KerasHub components are provided as +`keras.Layer`s and `keras.Model`s. If you are familiar with Keras, congratulations! You +already understand most of KerasHub. -This library is an extension of the core Keras API; all high-level modules are -[`Layers`](/api/layers/) or [`Models`](/api/models/). If you are familiar with Keras, -congratulations! You already understand most of KerasHub. +This guide is meant to be an accessible introduction to the entire library. We will start +by using high-level APIs to classify images and generate text, then progressively show +deeper customization of models and training. Throughout the guide, we use Professor Keras, +the official Keras mascot, as a visual reference for the complexity of the material: -KerasHub uses Keras 3 to work with any of TensorFlow, Pytorch and Jax. In the -guide below, we will use the `jax` backend for training our models, and -[tf.data](https://www.tensorflow.org/guide/data) for efficiently running our -input preprocessing. But feel free to mix things up! This guide runs in -TensorFlow or PyTorch backends with zero changes, simply update the -`KERAS_BACKEND` below. +![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_evolution.png) -This guide demonstrates our modular approach using a sentiment analysis example at six -levels of complexity: - -* Inference with a pretrained classifier -* Fine tuning a pretrained backbone -* Fine tuning with user-controlled preprocessing -* Fine tuning a custom model -* Pretraining a backbone model -* Build and train your own transformer from scratch +As always, we'll keep our Keras guides focused on real-world code examples. You can play +with the code here at any time by clicking the Colab link at the top of the guide. 
+""" -Throughout our guide, we use Professor Keras, the official Keras mascot, as a visual -reference for the complexity of the material: +""" +## Installation and Setup +""" -drawing +""" +To begin, let's install keras-hub. The library is available on PyPI, so we can simply +install it with pip. """ """shell -pip install -q --upgrade keras-hub -pip install -q --upgrade keras # Upgrade to Keras 3. +pip install --upgrade --quiet keras-hub-nightly keras-nightly +""" + +""" +Keras 3 was built to work on top of TensorFlow, Jax, and Torch backends. You should +specify the backend first thing when writing Keras code, before any library imports. We +will use the Jax backend for this guide, but you can use `torch` or `tensorflow` without +changing a single line in the rest of this guide. That's the power of Keras 3! + +We will also set `XLA_PYTHON_CLIENT_MEM_FRACTION`, which frees up the whole GPU for +Jax to use from the start. """ import os os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0" -import keras_hub -import keras +""" +Lastly, we need to do some extra setup to access the models used in this guide. Many +popular open LLMs, such as Gemma from Google and Llama from Meta, require accepting +a community license before accessing the model weights. We will be using Gemma in this +guide, so we can follow the following steps: -# Use mixed precision to speed up all training in this guide. -keras.mixed_precision.set_global_policy("mixed_float16") +1. Go to the [Gemma 2](https://www.kaggle.com/models/keras/gemma2) model page, and accept + the license at the banner at the top. +2. Generate an Kaggle API key by going to [Kaggle settings](https://www.kaggle.com/settings) + and clicking "Create New Token" button under the "API" section. +3. Inside your colab notebook, click on the key icon on the left hand toolbar. Add two + secrets: `KAGGLE_USERNAME` with your username, and `KAGGLE_KEY` with the API key you just + created. Make these secrets visible to the notebook you are running. +""" """ -## API quickstart +## API Quickstart -Our highest level API is `keras_hub.models`. These symbols cover the complete user -journey of converting strings to tokens, tokens to dense features, and dense features to -task-specific output. For each `XX` architecture (e.g., `Bert`), we offer the following -modules: +Before we begin, let's take a look at the key classes we will use in the KerasHub library. -* **Tokenizer**: `keras_hub.models.XXTokenizer` +* **Task**: e.g., `keras_hub.models.CausalLM`, `keras_hub.models.ImageClassifier`, and +`keras_hub.models.TextClassifier`. + * **What it does**: A task maps from raw image, audio, and text inputs to model + predictions. + * **Why it's important**: A task is the highest-level entry point to the KerasHub API. It + encapsulates both preprocessing and modeling into a single, easy-to-use class. Tasks can + be used both for fine-tuning and inference. + * **Has a**: `backbone` and `preprocessor`. + * **Inherits from**: `keras.Model`. +* **Backbone**: `keras_hub.models.Backbone`. + * **What it does**: Maps preprocessed tensor inputs to the latent space of the model. + * **Why it's important**: The backbone encapsulates the architecture and parameters of a + pretrained models in a way that is unspecialized to any particular task. A backbone can + be combined with arbitrary preprocessing and "head" layers mapping dense features to + predictions to accomplish any ML task. + * **Inherits from**: `keras.Model`. 
+* **Preprocessor**: e.g.,`keras_hub.models.CausalLMPreprocessor`, + `keras_hub.models.ImageClassifierPreprocessor`, and + `keras_hub.models.TextClassifierPreprocessor`. + * **What it does**: A preprocessor maps from raw image, audio and text inputs to + preprocessed tensor inputs. + * **Why it's important**: A preprocessing layer encapsulates all tasks specific + preprocessing, e.g. image resizing and text tokenization, in a way that can be used + standalone to precompute preprocessed inputs. Note that if you are using a high-level + task class, this preprocessing is already baked in by default. + * **Has a**: `tokenizer`, `audio_converter`, and/or `image_converter`. + * **Inherits from**: `keras.layers.Layer`. +* **Tokenizer**: `keras_hub.tokenizers.Tokenizer`. * **What it does**: Converts strings to sequences of token ids. - * **Why it's important**: The raw bytes of a string are too high dimensional to be useful - features so we first map them to a small number of tokens, for example `"The quick brown - fox"` to `["the", "qu", "##ick", "br", "##own", "fox"]`. + * **Why it's important**: The raw bytes of a string are an inefficient representation of + text input, so we first map string inputs to integer token ids. This class encapsulated + the mapping of strings to ints and the reverse (via the `detokenize()` method). * **Inherits from**: `keras.layers.Layer`. -* **Preprocessor**: `keras_hub.models.XXPreprocessor` - * **What it does**: Converts strings to a dictionary of preprocessed tensors consumed by - the backbone, starting with tokenization. - * **Why it's important**: Each model uses special tokens and extra tensors to understand - the input such as delimiting input segments and identifying padding tokens. Padding each - sequence to the same length improves computational efficiency. - * **Has a**: `XXTokenizer`. +* **ImageConverter**: `keras_hub.layers.ImageConverter`. + * **What it does**: Resizes and rescales image input. + * **Why it's important**: Image models often need to normalize image inputs to a specific + range, or resizing inputs to a specific size. This class encapsulates the image-specific + preprocessing. + * **Inherits from**: `keras.layers.Layer`. +* **AudioConveter**: `keras_hub.layers.AudioConveter`. + * **What it does**: Converts raw audio to model ready input. + * **Why it's important**: Audio models often need to preprocess raw audio input before + passing it to a model, e.g. by computing a spectrogram of the audio signal. This class + encapsulates the image specific preprocessing in an easy to use layer. * **Inherits from**: `keras.layers.Layer`. -* **Backbone**: `keras_hub.models.XXBackbone` - * **What it does**: Converts preprocessed tensors to dense features. *Does not handle - strings; call the preprocessor first.* - * **Why it's important**: The backbone distills the input tokens into dense features that - can be used in downstream tasks. It is generally pretrained on a language modeling task - using massive amounts of unlabeled data. Transferring this information to a new task is a - major breakthrough in modern NLP. - * **Inherits from**: `keras.Model`. -* **Task**: e.g., `keras_hub.models.XXClassifier` - * **What it does**: Converts strings to task-specific output (e.g., classification - probabilities). - * **Why it's important**: Task models combine string preprocessing and the backbone model - with task-specific `Layers` to solve a problem such as sentence classification, token - classification, or text generation. 
The additional `Layers` must be fine-tuned on labeled - data. - * **Has a**: `XXBackbone` and `XXPreprocessor`. - * **Inherits from**: `keras.Model`. -Here is the modular hierarchy for `BertClassifier` (all relationships are compositional): +All of the classes listed here have a `from_preset()` constructor, which will instantiate +the component with weights and state for the given pre-trained model identifier. E.g. +`keras_hub.tokenizers.Tokenizer.from_preset("gemma2_2b_en")` will create a layer that +tokenizes text using a Gemma2 tokenizer vocabulary. -drawing +The figure below shows how all these core classes interact. Arrow indicate composition +not inheritance (e.g., a task *has a* backbone). -All modules can be used independently and have a `from_preset()` method in addition to -the standard constructor that instantiates the class with **preset** architecture and -weights (see examples below). +![png](/img/guides/getting_started/class-diagram.png) """ """ -## Data +## Classify an image -We will use a running example of sentiment analysis of IMDB movie reviews. In this task, -we use the text to predict whether the review was positive (`label = 1`) or negative -(`label = 0`). +![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_beginner.png) +""" -We load the data using `keras.utils.text_dataset_from_directory`, which utilizes the -powerful `tf.data.Dataset` format for examples. +""" +Enough setup! Let's have some fun with pre-trained models. Let's load a test image of a +California Quail and classify it. """ -"""shell -curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz -tar -xf aclImdb_v1.tar.gz -# Remove unsupervised examples -rm -r aclImdb/train/unsup +import keras +import numpy as np +import matplotlib.pyplot as plt + +image_url = "https://upload.wikimedia.org/wikipedia/commons/a/aa/California_quail.jpg" +image_path = keras.utils.get_file(origin=image_url) +image = keras.utils.load_img(image_path) +plt.imshow(image) + +""" +We can use a ResNet vision model trained on the ImageNet-1k database. This model will +give each input sample and output label from `[0, 1000)`, where each label corresponds to +some real word entity, like a "milk can" or a "porcupine." The dataset actually has a +specific label for quail, at index 85. Let's download the model and predict a label. """ -BATCH_SIZE = 16 -imdb_train = keras.utils.text_dataset_from_directory( - "aclImdb/train", - batch_size=BATCH_SIZE, -) -imdb_test = keras.utils.text_dataset_from_directory( - "aclImdb/test", - batch_size=BATCH_SIZE, +import keras_hub + +image_classifier = keras_hub.models.ImageClassifier.from_preset( + "resnet_50_imagenet", + activation="softmax", ) +batch = np.array([image]) +image_classifier.preprocessor.image_size = (224, 224) +preds = image_classifier.predict(batch) +preds.shape -# Inspect first review -# Format is (review text tensor, label tensor) -print(imdb_train.unbatch().take(1).get_single_element()) +""" +These ImageNet labels aren't a particularly "human readable," so we can use a built-in +utility function to decode the predictions to a set of class names. +""" +keras_hub.utils.decode_imagenet_predictions(preds) """ -## Inference with a pretrained classifier +Looking good! The model weights successfully downloaded, and we predicted the +correct classification label for our quail image with near certainty. + +This was our first example of the high-level **task** API mentioned in the API quickstart +above. 
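+
+If you want to inspect the raw predictions yourself, note that `preds` is simply a NumPy
+array of class probabilities. As a small illustrative check (the quail index of 85 comes
+from the ImageNet label map mentioned above, and the variable names here are just for the
+example), we can look at the top entries directly:
+"""
+
+# Indices of the five most probable ImageNet classes, most likely first.
+top_5_ids = np.argsort(preds[0])[-5:][::-1]
+print("Top-5 class indices:", top_5_ids)
+# Probability the model assigns to class 85, the quail label.
+print("P(quail):", preds[0][85])
+
+"""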
An `keras_hub.models.ImageClassifier` is a task for classifying images, and can be +used with a number of different model architectures (ResNet, VGG, MobileNet, etc). You +can view the full list of models shipped directly by the Keras team on +[Kaggle](https://www.kaggle.com/organizations/keras/models). + +A task is just a subclass of `keras.Model` — you can use `fit()`, `compile()`, and +`save()` on our `classifier` object same as any other model. But tasks come with a few +extras provided by the KerasHub library. The first and most important is `from_preset()`, +a special constructor you will see on many classes in KerasHub. -drawing +A **preset** is a directory of model state. It defines both the architecture we should +load and the pretrained weights that go with it. `from_preset()` allows us to load +**preset** directories from a number of different locations: -The highest level module in KerasHub is a **task**. A **task** is a `keras.Model` -consisting of a (generally pretrained) **backbone** model and task-specific layers. -Here's an example using `keras_hub.models.BertClassifier`. +- A local directory. +- The Kaggle Model hub. +- The HuggingFace model hub. -**Note**: Outputs are the logits per class (e.g., `[0, 0]` is 50% chance of positive). The output is -[negative, positive] for binary classification. +You can take a look at the `keras_hub.models.ImageClassifier.from_preset` docs to better +understand all the options when constructing a Keras model from a preset. + +All tasks use two main sub-objects. A `keras_hub.models.Backbone` and a +`keras_hub.layers.Preprocessor`. You might be familiar already with the term **backbone** +from computer vision, where it is often used to describe a feature extractor network that +maps images to a latent space. A KerasHub backbone is this concept generalized, we use it +to refer to any pretrained model without a task-specific head. That is, a KerasHub +backbone maps raw images, audio and text (or a combination of these inputs) to a +pretrained model's latent space. We can then map this latent space to any number of task +specific outputs, depending on what we are trying to do with the model. + +A **preprocessor** is just a Keras layer that does all the preprocessing for a specific +task. In our case, preprocessing with will resize our input image and rescale it to the +range `[0, 1]` using some ImageNet specific mean and variance data. Let's call our +task's preprocessor and backbone in succession to see what happens to our input shape. """ -classifier = keras_hub.models.BertClassifier.from_preset("bert_tiny_en_uncased_sst2") -# Note: batched inputs expected so must wrap string in iterable -classifier.predict(["I love modular workflows in keras-hub!"]) +print("Raw input shape:", batch.shape) +resized_batch = image_classifier.preprocessor(batch) +print("Preprocessed input shape:", resized_batch.shape) +hidden_states = image_classifier.backbone(resized_batch) +print("Latent space shape:", hidden_states.shape) """ -All **tasks** have a `from_preset` method that constructs a `keras.Model` instance with -preset preprocessing, architecture and weights. This means that we can pass raw strings -in any format accepted by a `keras.Model` and get output specific to our task. +Our raw image is rescaled to `(224, 224)` during preprocessing and finally +downscaled to a `(7, 7)` image of 2048 feature vectors — the latent space of the +ResNet model. 
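+
+To make this concrete, here is a minimal sketch of how the latent space above could be
+mapped to a new task-specific output. The pooling choice, layer names, and the two-class
+head below are made up for illustration; we will do the same thing for real with a text
+model at the end of this guide.
+"""
+
+resnet_backbone = image_classifier.backbone
+inputs = resnet_backbone.input
+# A (batch_size, 7, 7, 2048) feature map from the pretrained ResNet backbone.
+features = resnet_backbone(inputs)
+# Average over the spatial dimensions to get a single 2048-d vector per image.
+pooled = keras.layers.GlobalAveragePooling2D(name="example_pool")(features)
+# A randomly initialized head we could now train on labels of our own.
+example_outputs = keras.layers.Dense(2, activation="softmax", name="example_head")(pooled)
+example_model = keras.Model(inputs, example_outputs)
+
+"""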
Note that ResNet can actually handle images of arbitrary sizes, +though performance will eventually fall off if your image is very different +sized than the pretrained data. If you'd like to disable the resizing in the +preprocessing layer, you can run `image_classifier.preprocessor.image_size = None`. -This particular **preset** is a `"bert_tiny_uncased_en"` **backbone** fine-tuned on -`sst2`, another movie review sentiment analysis (this time from Rotten Tomatoes). We use -the `tiny` architecture for demo purposes, but larger models are recommended for SoTA -performance. For all the task-specific presets available for `BertClassifier`, see -our keras.io [models page](https://keras.io/api/keras_hub/models/). +If you are ever wondering the exact structure of the task you loaded, you can +use `model.summary()` same as any Keras model. The model summary for tasks will +included extra information on model preprocessing. +""" -Let's evaluate our classifier on the IMDB dataset. You will note we don't need to -call `keras.Model.compile` here. All **task** models like `BertClassifier` ship with -compilation defaults, meaning we can just call `keras.Model.evaluate` directly. You -can always call compile as normal to override these defaults (e.g. to add new metrics). +image_classifier.summary() -The output below is [loss, accuracy], """ +## Generate text with an LLM -classifier.evaluate(imdb_test) +![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_intermediate.png) +""" """ -Our result is 78% accuracy without training anything. Not bad! +Next up, let's try working with and generating text. The task we can use when generating +text is `keras_hub.models.CausalLM` (where LM is short for **L**anguage **M**odel). Let's +download the 2 billion parameter Gemma 2 model and try it out. + +Since this is about 100x larger model than the ResNet model we just downloaded, we need to be +a little more careful about our GPU memory usage. We can use a half-precision type to +load each parameter of our ~2.5 billion as a two-byte float instead of four. To do this +we can pass `dtype` to the `from_preset()` constructor. `from_preset()` will forward any +kwargs to the main constructor for the class, so you can pass kwargs that work on all +Keras layers like `dtype`, `trainable`, and `name`. """ +causal_lm = keras_hub.models.CausalLM.from_preset( + "gemma2_instruct_2b_en", + dtype="bfloat16", +) + """ -## Fine tuning a pretrained BERT backbone +The model we just loaded was an instruction-tuned version of Gemma, which means the model +was further fine-tuned for chat. We can take advantage of these capabilities as long as +we stick to the particular template for text used when training the model. These special +tokens vary per model and can be hard to track, the [Kaggle model +page](https://www.kaggle.com/models/keras/gemma2/) will contain details such as this. -drawing +`CausalLM` come with an extra function called `generate()` which can be used generate +predict tokens in a loop and decode them as a string. +""" -When labeled text specific to our task is available, fine-tuning a custom classifier can -improve performance. If we want to predict IMDB review sentiment, using IMDB data should -perform better than Rotten Tomatoes data! And for many tasks, no relevant pretrained model -will be available (e.g., categorizing customer reviews). 
+template = "user\n{question}\nmodel" -The workflow for fine-tuning is almost identical to above, except that we request a -**preset** for the **backbone**-only model rather than the entire classifier. When passed -a **backbone** **preset**, a **task** `Model` will randomly initialize all task-specific -layers in preparation for training. For all the **backbone** presets available for -`BertClassifier`, see our keras.io [models page](https://keras.io/api/keras_hub/models/). +question = """Write a python program to generate the first 1000 prime numbers. +Just show the actual code.""" +print(causal_lm.generate(template.format(question=question), max_length=512)) -To train your classifier, use `keras.Model.fit` as with any other -`keras.Model`. As with our inference example, we can rely on the compilation -defaults for the **task** and skip `keras.Model.compile`. As preprocessing is -included, we again pass the raw data. +""" +Note that on the Jax and TensorFlow backends, this `generate()` function is compiled, so +the second time you call for the same `max_length`, it will actually be much faster. +KerasHub will use Jax an TensorFlow to compute an optimized version of the generation +computational graph that can be reused. """ -classifier = keras_hub.models.BertClassifier.from_preset( - "bert_tiny_en_uncased", - num_classes=2, -) -classifier.fit( - imdb_train, - validation_data=imdb_test, - epochs=1, -) +question = "Share a very simple brownie recipe." +print(causal_lm.generate(template.format(question=question), max_length=512)) """ -Here we see a significant lift in validation accuracy (0.78 -> 0.87) with a single epoch of -training even though the IMDB dataset is much smaller than `sst2`. +As with our image classifier, we can use model summary to see the details of our task +setup, including preprocessing. """ +causal_lm.summary() + +""" +Our text preprocessing includes a tokenizer, which is how all KerasHub models handle +input text. Let's try using it directly to get a better sense of how it works. All +tokenizers include `tokenize()` and `detokenize()` methods, to map strings to integer +sequences and integer sequences to strings. Directly calling the layer with +`tokenizer(inputs)` is equivalent to calling `tokenizer.tokenize(inputs)`. """ -## Fine tuning with user-controlled preprocessing -drawing -For some advanced training scenarios, users might prefer direct control over -preprocessing. For large datasets, examples can be preprocessed in advance and saved to -disk or preprocessed by a separate worker pool using `tf.data.experimental.service`. In -other cases, custom preprocessing is needed to handle the inputs. +tokenizer = causal_lm.preprocessor.tokenizer +tokens_ids = tokenizer.tokenize("The quick brown fox jumps over the lazy dog.") +print(tokens_ids) +string = tokenizer.detokenize(tokens_ids) +print(string) -Pass `preprocessor=None` to the constructor of a **task** `Model` to skip automatic -preprocessing or pass a custom `BertPreprocessor` instead. """ +The `generate()` function for `CausalLM` models involved a sampling step. The Gemma model +will be called once for each token we want to generate, and return a probability +distribution over all tokens. This distribution is then sampled to choose the next token +in the sequence. +For Gemma models, we default to greedy sampling, meaning we simply pick the most likely +output from the model at each step. 
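+
+To make the sampling step concrete, here is a tiny, self-contained sketch with a made-up
+distribution over a five-token vocabulary (the numbers are invented for illustration; the
+real samplers live in `keras_hub.samplers`). Greedy sampling takes the argmax, while
+top-k sampling renormalizes the k most likely tokens and draws from them:
+"""
+
+rng = np.random.default_rng(0)
+next_token_probs = np.array([0.50, 0.25, 0.15, 0.07, 0.03])
+
+# Greedy sampling: always pick the single most likely token.
+greedy_token = int(np.argmax(next_token_probs))
+
+# Top-k sampling with k=3: keep the 3 most likely tokens, renormalize, then sample.
+top_k_ids = np.argsort(next_token_probs)[-3:]
+top_k_probs = next_token_probs[top_k_ids] / next_token_probs[top_k_ids].sum()
+sampled_token = int(rng.choice(top_k_ids, p=top_k_probs))
+
+print("greedy:", greedy_token, "sampled:", sampled_token)
+
+"""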
But we can actually control this process with an +extra `sampler` argument to the standard `compile` function on all Keras models. Let's +try it out. """ -### Separate preprocessing from the same preset -Each model architecture has a parallel **preprocessor** `Layer` with its own -`from_preset` constructor. Using the same **preset** for this `Layer` will return the -matching **preprocessor** as the **task**. +causal_lm.compile( + sampler=keras_hub.samplers.TopKSampler(k=10, temperature=2.0), +) -In this workflow we train the model over three epochs using `tf.data.Dataset.cache()`, -which computes the preprocessing once and caches the result before fitting begins. +question = "Share a very simple brownie recipe." +print(causal_lm.generate(template.format(question=question), max_length=512)) -**Note:** we can use `tf.data` for preprocessing while running on the -Jax or PyTorch backend. The input dataset will automatically be converted to -backend native tensor types during fit. In fact, given the efficiency of `tf.data` -for running preprocessing, this is good practice on all backends. """ +Here we used a Top-K sampler, meaning we will randomly sample the partial distribution formed +by looking at just the top 10 predicted tokens at each time step. We also pass a `temperature` of 2, +which flattens our predicted distribution before we sample. -import tensorflow as tf - -preprocessor = keras_hub.models.BertPreprocessor.from_preset( - "bert_tiny_en_uncased", - sequence_length=512, -) +The net effect is that we will explore our model's distribution much more broadly each +time we generate output. Generation will now be a random process, each time we re-run +generate we will get a different result. We can note that the results feel "looser" than +greedy search — more minor mistakes and a less consistent tone. -# Apply the preprocessor to every sample of train and test data using `map()`. -# `tf.data.AUTOTUNE` and `prefetch()` are options to tune performance, see -# https://www.tensorflow.org/guide/data_performance for details. +You can look at all the samplers Keras supports at [keras_hub.samplers](https://keras.io/api/keras_hub/samplers/). -# Note: only call `cache()` if you training data fits in CPU memory! -imdb_train_cached = ( - imdb_train.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE) -) -imdb_test_cached = ( - imdb_test.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE) -) +Let's free up the memory from our large Gemma model before we jump to the next section. +""" -classifier = keras_hub.models.BertClassifier.from_preset( - "bert_tiny_en_uncased", preprocessor=None, num_classes=2 -) -classifier.fit( - imdb_train_cached, - validation_data=imdb_test_cached, - epochs=3, -) +del causal_lm """ -After three epochs, our validation accuracy has only increased to 0.88. This is both a -function of the small size of our dataset and our model. To exceed 90% accuracy, try -larger **presets** such as `"bert_base_en_uncased"`. For all the **backbone** presets -available for `BertClassifier`, see our keras.io [models page](https://keras.io/api/keras_hub/models/). +## Fine-tune and publish an image classifier + +![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_advanced.png) """ """ -### Custom preprocessing +Now that we've tried running inference for both images and text, let's try running +training. We will take our ResNet image classifier from earlier and fine-tune it on +simple cats vs dogs dataset. 
We can start by downloading and extracting the data. +""" + +import pathlib -In cases where custom preprocessing is required, we offer direct access to the -`Tokenizer` class that maps raw strings to tokens. It also has a `from_preset()` -constructor to get the vocabulary matching pretraining. +extract_dir = keras.utils.get_file( + "cats_vs_dogs", + "https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip", + extract=True, +) +data_dir = pathlib.Path(extract_dir) / "PetImages" -**Note:** `BertTokenizer` does not pad sequences by default, so the output is -ragged (each sequence has varying length). The `MultiSegmentPacker` below -handles padding these ragged sequences to dense tensor types (e.g. `tf.Tensor` -or `torch.Tensor`). +""" +When working with lots of real-world image data, corrupted images are a common occurrence. +Let's filter out badly-encoded images that do not feature the string "JFIF" in their +header. """ -tokenizer = keras_hub.models.BertTokenizer.from_preset("bert_tiny_en_uncased") -tokenizer(["I love modular workflows!", "Libraries over frameworks!"]) +num_skipped = 0 -# Write your own packer or use one of our `Layers` -packer = keras_hub.layers.MultiSegmentPacker( - start_value=tokenizer.cls_token_id, - end_value=tokenizer.sep_token_id, - # Note: This cannot be longer than the preset's `sequence_length`, and there - # is no check for a custom preprocessor! - sequence_length=64, -) +for path in data_dir.rglob("*.jpg"): + with open(path, "rb") as file: + is_jfif = b"JFIF" in file.peek(10) + if not is_jfif: + num_skipped += 1 + os.remove(path) +print(f"Deleted {num_skipped} images.") -# This function that takes a text sample `x` and its -# corresponding label `y` as input and converts the -# text into a format suitable for input into a BERT model. -def preprocessor(x, y): - token_ids, segment_ids = packer(tokenizer(x)) - x = { - "token_ids": token_ids, - "segment_ids": segment_ids, - "padding_mask": token_ids != 0, - } - return x, y +""" +We can load the dataset with `keras.utils.image_dataset_from_directory`. One important +thing to note here is that the `train_ds` and `val_ds` will both be returned as +`tf.data.Dataset` objects, including on the `torch` and `jax` backends. +KerasHub will use [tf.data](https://www.tensorflow.org/guide/data) as the default API for +running multi-threaded preprocessing on the CPU. `tf.data` is a powerful API for training +input pipelines that can scale up to complex, multi-host training jobs easily. Using it +does not restrict your choice of backend, a `tf.data.Dataset` can be as an iterator of +regular numpy data and passed to `fit()` on any Keras backend. +""" -imdb_train_preprocessed = imdb_train.map(preprocessor, tf.data.AUTOTUNE).prefetch( - tf.data.AUTOTUNE -) -imdb_test_preprocessed = imdb_test.map(preprocessor, tf.data.AUTOTUNE).prefetch( - tf.data.AUTOTUNE +train_ds, val_ds = keras.utils.image_dataset_from_directory( + data_dir, + validation_split=0.2, + subset="both", + seed=1337, + image_size=(256, 256), + batch_size=32, ) -# Preprocessed example -print(imdb_train_preprocessed.unbatch().take(1).get_single_element()) +""" +At its simplest, training our classifier could consist of simply calling `fit()` on our +model with our dataset. But to make this example a little more interesting, let's show +how to customize preprocessing within a task. +In the first example, we saw how, by default, the preprocessing for our ResNet model resized +and rescaled our input. 
This preprocessing can be customized when we create our model. We +can use Keras' image preprocessing layers to create a `keras.layers.Pipeline` that will +rescale, randomly flip, and randomly rotate our input images. These random image +augmentations will allow our smaller dataset to function as a larger, more varied one. +Let's try it out. """ -## Fine tuning with a custom model -drawing -For more advanced applications, an appropriate **task** `Model` may not be available. In -this case, we provide direct access to the **backbone** `Model`, which has its own -`from_preset` constructor and can be composed with custom `Layer`s. Detailed examples can -be found at our [transfer learning guide](https://keras.io/guides/transfer_learning/). +preprocessor = keras.layers.Pipeline( + [ + keras.layers.Rescaling(1.0 / 255), + keras.layers.RandomFlip("horizontal"), + keras.layers.RandomRotation(0.2), + ] +) -A **backbone** `Model` does not include automatic preprocessing but can be paired with a -matching **preprocessor** using the same **preset** as shown in the previous workflow. +""" +Now that we have created a new layer for preprocessing, we can simply pass it to the +`ImageClassifier` during the `from_preset()` constructor. We can also pass +`num_classes=2` to match our two labels for "cat" and "dog." When `num_classes` is +specified like this, our head weights for the model will be randomly initialized +instead of containing the weights for our 1000 class image classification. +""" -In this workflow, we experiment with freezing our backbone model and adding two trainable -transformer layers to adapt to the new input. +image_classifier = keras_hub.models.ImageClassifier.from_preset( + "resnet_50_imagenet", + activation="softmax", + num_classes=2, + preprocessor=preprocessor, +) -**Note**: We can ignore the warning about gradients for the `pooled_dense` layer because -we are using BERT's sequence output. """ +Note that if you want to preprocess your input data outside of Keras, you can simply +pass `preprocessor=None` to the task `from_preset()` call. In this case, KerasHub will +apply no preprocessing at all, and you are free to preprocess your data with any library +or workflow before passing your data to `fit()`. -preprocessor = keras_hub.models.BertPreprocessor.from_preset("bert_tiny_en_uncased") -backbone = keras_hub.models.BertBackbone.from_preset("bert_tiny_en_uncased") +Next, we can compile our model for fine-tuning. A KerasHub task is just a regular +`keras.Model` with some extra functionality, so we can `compile()` as normal for a +classification task. 
+""" -imdb_train_preprocessed = ( - imdb_train.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE) -) -imdb_test_preprocessed = ( - imdb_test.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE) +image_classifier.compile( + optimizer=keras.optimizers.Adam(1e-4), + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], ) -backbone.trainable = False -inputs = backbone.input -sequence = backbone(inputs)["sequence_output"] -for _ in range(2): - sequence = keras_hub.layers.TransformerEncoder( - num_heads=2, - intermediate_dim=512, - dropout=0.1, - )(sequence) -# Use [CLS] token output to classify -outputs = keras.layers.Dense(2)(sequence[:, backbone.cls_token_index, :]) - -model = keras.Model(inputs, outputs) -model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.AdamW(5e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, -) -model.summary() -model.fit( - imdb_train_preprocessed, - validation_data=imdb_test_preprocessed, +""" +With that, we can simply run `fit()`. The image classifier will automatically apply our +preprocessing to each batch when training the model. +""" + +image_classifier.fit( + train_ds, + validation_data=val_ds, epochs=3, ) """ -This model achieves reasonable accuracy despite having only 10% of the trainable parameters -of our `BertClassifier` model. Each training step takes about 1/3 of the time---even -accounting for cached preprocessing. +After three epochs of data, we achieve 99% accuracy on our cats vs dogs +validation dataset. This is unsurprising, given that the ImageNet pretrained weights we began +with could already classify some breeds of cats and dogs individually. + +Now that we have a fine-tuned model let's try saving it. You can create a new saved preset with a +fine-tuned model for any task simply by running `task.save_to_preset()`. """ +image_classifier.save_to_preset("cats_vs_dogs") + """ -## Pretraining a backbone model -drawing +One of the most powerful features of KerasHub is the ability upload models to Kaggle or +Huggingface models hub and share them with others. `keras_hub.upload_preset` allows you +to upload a saved preset. -Do you have access to large unlabeled datasets in your domain? Are they around the -same size as used to train popular backbones such as BERT, RoBERTa, or GPT2 (XX+ GiB)? If -so, you might benefit from domain-specific pretraining of your own backbone models. +In this case, we will upload to Kaggle. We have already authenticated with Kaggle to, +download the Gemma model earlier. Running the following cell well upload a new model +to Kaggle. +""" -NLP models are generally pretrained on a language modeling task, predicting masked words -given the visible words in an input sentence. For example, given the input -`"The fox [MASK] over the [MASK] dog"`, the model might be asked to predict `["jumped", "lazy"]`. -The lower layers of this model are then packaged as a **backbone** to be combined with -layers relating to a new task. +from google.colab import userdata -The KerasHub library offers SoTA **backbones** and **tokenizers** to be trained from -scratch without presets. +username = userdata.get("KAGGLE_USERNAME") +keras_hub.upload_preset( + f"kaggle://{username}/resnet/keras/cats_vs_dogs", + "cats_vs_dogs", +) -In this workflow, we pretrain a BERT **backbone** using our IMDB review text. 
We skip the -"next sentence prediction" (NSP) loss because it adds significant complexity to the data -processing and was dropped by later models like RoBERTa. See our e2e -[Transformer pretraining](https://keras.io/guides/keras_hub/transformer_pretraining/#pretraining) -for step-by-step details on how to replicate the original paper. +""" +Let's take a look at a test image from our dataset. """ +image = keras.utils.load_img(data_dir / "Cat" / "6779.jpg") +plt.imshow(image) + """ -### Preprocessing +If we wait for a few minutes for our model upload to finish processing on the Kaggle +side, we can go ahead and download the model we just created and use it to classify this +test image. """ -# All BERT `en` models have the same vocabulary, so reuse preprocessor from -# "bert_tiny_en_uncased" -preprocessor = keras_hub.models.BertPreprocessor.from_preset( - "bert_tiny_en_uncased", - sequence_length=256, -) -packer = preprocessor.packer -tokenizer = preprocessor.tokenizer - -# keras.Layer to replace some input tokens with the "[MASK]" token -masker = keras_hub.layers.MaskedLMMaskGenerator( - vocabulary_size=tokenizer.vocabulary_size(), - mask_selection_rate=0.25, - mask_selection_length=64, - mask_token_id=tokenizer.token_to_id("[MASK]"), - unselectable_token_ids=[ - tokenizer.token_to_id(x) for x in ["[CLS]", "[PAD]", "[SEP]"] - ], +image_classifier = keras_hub.models.ImageClassifier.from_preset( + f"kaggle://{username}/resnet/keras/cats_vs_dogs", ) +print(image_classifier.predict(np.array([image]))) +""" +Congratulations on uploading your first model with KerasHub! If you want to share your +work with others, you can go to the model link printed out when we uploaded the model, and +turn the model public in settings. -def preprocess(inputs, label): - inputs = preprocessor(inputs) - masked_inputs = masker(inputs["token_ids"]) - # Split the masking layer outputs into a (features, labels, and weights) - # tuple that we can use with keras.Model.fit(). - features = { - "token_ids": masked_inputs["token_ids"], - "segment_ids": inputs["segment_ids"], - "padding_mask": inputs["padding_mask"], - "mask_positions": masked_inputs["mask_positions"], - } - labels = masked_inputs["mask_ids"] - weights = masked_inputs["mask_weights"] - return features, labels, weights +Let's delete this model to free up memory before we move on to our final example for this +guide. +""" +del image_classifier -pretrain_ds = imdb_train.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE).prefetch( - tf.data.AUTOTUNE -) -pretrain_val_ds = imdb_test.map( - preprocess, num_parallel_calls=tf.data.AUTOTUNE -).prefetch(tf.data.AUTOTUNE) +""" +## Building a custom text classifier -# Tokens with ID 103 are "masked" -print(pretrain_ds.unbatch().take(1).get_single_element()) +![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_expert.png) +""" """ -### Pretraining model +As a final example for this getting started guide, let's take a look at how we can build +custom models from lower-level Keras and KerasHub components. We will build a text +classifier to classify movie reviews in the IMDb dataset as either positive or negative. + +Let's download the dataset. 
""" -# BERT backbone -backbone = keras_hub.models.BertBackbone( - vocabulary_size=tokenizer.vocabulary_size(), - num_layers=2, - num_heads=2, - hidden_dim=128, - intermediate_dim=512, +extract_dir = keras.utils.get_file( + "imdb_reviews", + origin="https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz", + extract=True, ) +data_dir = pathlib.Path(extract_dir) / "aclImdb" -# Language modeling head -mlm_head = keras_hub.layers.MaskedLMHead( - token_embedding=backbone.token_embedding, -) +""" +The IMDb dataset contrains a large amount of unlabeled movie reviews. We don't need those +here, we can simply delete them. +""" -inputs = { - "token_ids": keras.Input(shape=(None,), dtype=tf.int32, name="token_ids"), - "segment_ids": keras.Input(shape=(None,), dtype=tf.int32, name="segment_ids"), - "padding_mask": keras.Input(shape=(None,), dtype=tf.int32, name="padding_mask"), - "mask_positions": keras.Input(shape=(None,), dtype=tf.int32, name="mask_positions"), -} - -# Encoded token sequence -sequence = backbone(inputs)["sequence_output"] - -# Predict an output word for each masked input token. -# We use the input token embedding to project from our encoded vectors to -# vocabulary logits, which has been shown to improve training efficiency. -outputs = mlm_head(sequence, mask_positions=inputs["mask_positions"]) - -# Define and compile our pretraining model. -pretraining_model = keras.Model(inputs, outputs) -pretraining_model.summary() -pretraining_model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.AdamW(learning_rate=5e-4), - weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, -) +import shutil -# Pretrain on IMDB dataset -pretraining_model.fit( - pretrain_ds, - validation_data=pretrain_val_ds, - epochs=3, # Increase to 6 for higher accuracy -) +shutil.rmtree(data_dir / "train" / "unsup") """ -After pretraining save your `backbone` submodel to use in a new task! +Next up, we can load our data with `keras.utils.text_dataset_from_directory`. As with our +image dataset creation above, the returned datasets will be `tf.data.Dataset` objects. """ +raw_train_ds = keras.utils.text_dataset_from_directory( + data_dir / "train", + batch_size=2, +) +raw_val_ds = keras.utils.text_dataset_from_directory( + data_dir / "test", + batch_size=2, +) + """ -## Build and train your own transformer from scratch -drawing +KerasHub is designed to be a layered API. At the top-most level, tasks aim to make it +easy to quickly tackle a problem. We could keep using the task API here, and create a +`keras_hub.models.TextClassifer` for a text classification model like BERT, and fine-tune +it in 10 or so lines of code. -Want to implement a novel transformer architecture? The KerasHub library offers all the -low-level modules used to build SoTA architectures in our `models` API. This includes the -`keras_hub.tokenizers` API which allows you to train your own subword tokenizer using -`WordPieceTokenizer`, `BytePairTokenizer`, or `SentencePieceTokenizer`. +Instead, to make our final example a little more interesting, let's show how we can use +lower-level API components to do something that isn't directly baked in to the library. +We will take the Gemma 2 model we used earlier, which is usually used for generating text, +and modify it to output classification predictions. -In this workflow, we train a custom tokenizer on the IMDB data and design a backbone with -custom transformer architecture. 
For simplicity, we then train directly on the -classification task. Interested in more details? We wrote an entire guide to pretraining -and finetuning a custom transformer on -[keras.io](https://keras.io/guides/keras_hub/transformer_pretraining/), -""" +A common approach for classifying with a generative model would keep using it in a generative +context, by prompting it with the review and a question (`"Is this review positive or negative?"`). +But making an actual classifier is more useful if you want an actual probability score associated +with your labels. -""" -### Train custom vocabulary from IMDB data +Instead of loading the Gemma 2 model through the `CausalLM` task, we can load two +lower-level components: a **backbone** and a **tokenizer**. Much like the task classes we have +used so far, `keras_hub.models.Backbone` and `keras_hub.tokenizers.Tokenizer` both have a +`from_preset()` constructor for loading pretrained models. If you are running this code, +you will note you don't have to wait for a download as we use the model a second time, +the weights files are cached locally the first time we use the model. """ -vocab = keras_hub.tokenizers.compute_word_piece_vocabulary( - imdb_train.map(lambda x, y: x), - vocabulary_size=20_000, - lowercase=True, - strip_accents=True, - reserved_tokens=["[PAD]", "[START]", "[END]", "[MASK]", "[UNK]"], +tokenizer = keras_hub.tokenizers.Tokenizer.from_preset( + "gemma2_instruct_2b_en", ) -tokenizer = keras_hub.tokenizers.WordPieceTokenizer( - vocabulary=vocab, - lowercase=True, - strip_accents=True, - oov_token="[UNK]", +backbone = keras_hub.models.Backbone.from_preset( + "gemma2_instruct_2b_en", ) """ -### Preprocess data with a custom tokenizer +We saw what the tokenizer does in the second example of this guide. We can use it to map +from string inputs to token ids in a way that matches the pretrained weights of the Gemma +model. + +The backbone will map from a sequence of token ids to a sequence of embedded tokens in +the latent space of the model. We can use this rich representation to build a classifier. + +Let's start by defining a custom preprocessing routine. `keras_hub.layers` contains a +collection of modeling and preprocessing layers, included some layers for token +preprocessing. We can use `keras_hub.layers.StartEndPacker`, which will append a special +start token to the beginning of each review, a special end token to the end, and finally +truncate or pad each review to a fixed length. + +If we combine this with our `tokenizer`, we can build a preprocessing function that will +output batches of token ids with shape `(batch_size, sequence_length)`. We should also +output a padding mask that marks which tokens are padding tokens, so we can later exclude +these positions from our Transformer's attention computation. Most Transformer backbones +in KerasNLP take in a `"padding_mask"` input. 
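+
+Before wiring the packer into our real preprocessing below, here is a quick illustration
+of what it does on a couple of made-up token id sequences (the ids, special token values,
+and sequence length are invented for the example):
+"""
+
+toy_packer = keras_hub.layers.StartEndPacker(
+    start_value=1,
+    end_value=2,
+    pad_value=0,
+    sequence_length=8,
+)
+# Roughly: [[1, 11, 12, 13, 2, 0, 0, 0], [1, 21, 22, 23, 24, 25, 2, 0]]
+print(toy_packer([[11, 12, 13], [21, 22, 23, 24, 25]]))
+
+"""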
""" packer = keras_hub.layers.StartEndPacker( - start_value=tokenizer.token_to_id("[START]"), - end_value=tokenizer.token_to_id("[END]"), - pad_value=tokenizer.token_to_id("[PAD]"), - sequence_length=512, + start_value=tokenizer.start_token_id, + end_value=tokenizer.end_token_id, + pad_value=tokenizer.pad_token_id, + sequence_length=None, ) -def preprocess(x, y): - token_ids = packer(tokenizer(x)) - return token_ids, y +def preprocess(x, y=None, sequence_length=256): + x = tokenizer(x) + x = packer(x, sequence_length=sequence_length) + x = { + "token_ids": x, + "padding_mask": x != tokenizer.pad_token_id, + } + return keras.utils.pack_x_y_sample_weight(x, y) + + +""" +With our preprocessing defined, we can simply use `tf.data.Dataset.map` to apply our +preprocessing to our input data. +""" +train_ds = raw_train_ds.map(preprocess, num_parallel_calls=16) +val_ds = raw_val_ds.map(preprocess, num_parallel_calls=16) +next(iter(train_ds)) -imdb_preproc_train_ds = imdb_train.map( - preprocess, num_parallel_calls=tf.data.AUTOTUNE -).prefetch(tf.data.AUTOTUNE) -imdb_preproc_val_ds = imdb_test.map( - preprocess, num_parallel_calls=tf.data.AUTOTUNE -).prefetch(tf.data.AUTOTUNE) +""" +Running fine-tuning on a 2.5 billion parameter model is quite expensive compared to the +image classifier we trained earlier, for the simple reason that this model is 100x the +size of ResNet! To speed things up a bit, let's reduce the size of our training data to a +tenth of the original size. Of course, this is leaving some performance on the table +compared to full training, but it will keep things running quickly for our guide. +""" -print(imdb_preproc_train_ds.unbatch().take(1).get_single_element()) +train_ds = train_ds.take(1000) +val_ds = val_ds.take(1000) """ +Next, we need to attach a classification head to our backbone model. In general, text +transformer backbones will output a tensor with shape +`(batch_size, sequence_length, hidden_dim)`. The main thing we will need to +classify with this input is to pool on the sequence dimension so we have a single +feature vector per input example. -### Design a tiny transformer +Since the Gemma model is a generative model, information only passed from left to right +in the sequence. The only token representation that can "see" the entire movie review +input is the final token in each review. We can write a simple pooling layer to do this — +we will simply grab the last non-padding position of each input sequence. There's no special +process to writing a layer like this, we can use Keras and `keras.ops` normally. 
""" -token_id_input = keras.Input( - shape=(None,), - dtype="int32", - name="token_ids", -) -outputs = keras_hub.layers.TokenAndPositionEmbedding( - vocabulary_size=len(vocab), - sequence_length=packer.sequence_length, - embedding_dim=64, -)(token_id_input) -outputs = keras_hub.layers.TransformerEncoder( - num_heads=2, - intermediate_dim=128, - dropout=0.1, -)(outputs) -# Use "[START]" token to classify -outputs = keras.layers.Dense(2)(outputs[:, 0, :]) -model = keras.Model( - inputs=token_id_input, - outputs=outputs, -) +from keras import ops + + +class LastTokenPooler(keras.layers.Layer): + def call(self, inputs, padding_mask): + end_positions = ops.sum(padding_mask, axis=1, keepdims=True) - 1 + end_positions = ops.cast(end_positions, "int")[:, :, None] + outputs = ops.take_along_axis(inputs, end_positions, axis=1) + return ops.squeeze(outputs, axis=1) -model.summary() """ -### Train the transformer directly on the classification objective +With this pooling layer, we are ready to write our Gemma classifier. All task and backbone +models in KerasHub are [functional](https://keras.io/guides/functional_api/) models, so +we can easily manipulate the model structure. We will call our backbone on our inputs, add +our new pooling layer, and finally add a small feedforward network with a `"relu"` activation +in the middle. Let's try it out. """ -model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.AdamW(5e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, +inputs = backbone.input +x = backbone(inputs) +x = LastTokenPooler( + name="pooler", +)(x, inputs["padding_mask"]) +x = keras.layers.Dense( + 2048, + activation="relu", + name="pooled_dense", +)(x) +x = keras.layers.Dropout( + 0.1, + name="output_dropout", +)(x) +outputs = keras.layers.Dense( + 2, + activation="softmax", + name="output_dense", +)(x) +text_classifier = keras.Model(inputs, outputs) +text_classifier.summary() + +""" +Before we train, there is one last trick we should employ to make this code run on free +tier colab GPUs. We can see from our model summary our model takes up almost 10 gigabytes +of space. An optimizer will need to make multiple copies of each parameter during +training, taking the total space of our model during training close to 30 or 40 +gigabytes. + +This would OOM many GPUs. A useful trick we can employ is to enable LoRA on our +backbone. LoRA is an approach which freezes the entire model, and only trains a low-parameter +decomposition of large weight matrices. You can read more about LoRA in this [Keras +example](https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/). +Let's try enabling it and re-printing our summary. +""" + +backbone.enable_lora(4) +text_classifier.summary() + +""" +After enabling LoRA, our model goes from 10GB of traininable parameters to just 20MB. +That means the space used by optimizer variables will no longer be a concern. + +With all that set up, we can compile and train our model as normal. +""" + +text_classifier.compile( + optimizer=keras.optimizers.Adam(5e-5), + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], ) -model.fit( - imdb_preproc_train_ds, - validation_data=imdb_preproc_val_ds, - epochs=3, +text_classifier.fit( + train_ds, + validation_data=val_ds, ) """ -Excitingly, our custom classifier is similar to the performance of fine-tuning -`"bert_tiny_en_uncased"`! 
To see the advantages of pretraining and exceed 90% accuracy we
-would need to use larger **presets** such as `"bert_base_en_uncased"`.
+We are able to achieve roughly 93% accuracy on the movie review sentiment
+classification problem. This is not bad, given that we only used a tenth of our
+original dataset to train.
+
+Taken together, the `backbone` and `tokenizer` we created in this example
+allowed us to access the full power of pretrained Gemma checkpoints, without
+restricting what we could do with them. This is a central aim of the KerasHub
+API. Simple workflows should be easy, and as you go deeper, you gain access to a
+deeply customizable set of building blocks.
+"""
+
+"""
+## Going further
+
+This is just scratching the surface of what you can do with KerasHub.
+
+This guide shows a few of the high-level tasks that we ship with the KerasHub library,
+but there are many tasks we did not cover here. Try [generating images with Stable
+Diffusion](https://keras.io/guides/keras_hub/stable_diffusion_3_in_keras_hub/), for
+example.
+
+The most significant advantage of KerasHub is that it gives you the flexibility to combine
+pre-trained building blocks with the full power of Keras 3. You can train large LLMs on TPUs
+with model parallelism using the [keras.distribution](https://keras.io/guides/distribution/) API.
+You can quantize models with Keras'
+[quantize method](https://keras.io/examples/keras_recipes/float8_training_and_inference_with_transformer/).
+You can write custom training loops and even mix in direct Jax, Torch, or
+TensorFlow calls.
+
+See [keras.io/keras_hub](https://keras.io/keras_hub/) for a full list of guides and
+examples to continue digging into the library.
"""
diff --git a/guides/md/keras_hub/getting_started.md b/guides/md/keras_hub/getting_started.md
index 51afbb54ac..cd41cc5ce8 100644
--- a/guides/md/keras_hub/getting_started.md
+++ b/guides/md/keras_hub/getting_started.md
@@ -1,8 +1,8 @@
# Getting Started with KerasHub

-**Author:** [Jonathan Bischof](https://github.com/jbischof)<br>
+**Author:** [Matthew Watson](https://github.com/mattdangerw/), [Jonathan Bischof](https://github.com/jbischof)
**Date created:** 2022/12/15
-**Last modified:** 2023/07/01
+**Last modified:** 2024/10/17
**Description:** An introduction to the KerasHub API. @@ -10,1016 +10,1138 @@ ---- -## Introduction - -KerasHub is a natural language processing library that supports users through -their entire development cycle. Our workflows are built from modular components -that have state-of-the-art preset weights and architectures when used -out-of-the-box and are easily customizable when more control is needed. +**KerasHub** is a pretrained modeling library that aims to be simple, flexible, and fast. +The library provides [Keras 3](https://keras.io/keras_3/) implementations of popular +model architectures, paired with a collection of pretrained checkpoints available on +[Kaggle](https://www.kaggle.com/organizations/keras/models). Models can be used for both +training and inference on any of the TensorFlow, Jax, and Torch backends. -This library is an extension of the core Keras API; all high-level modules are -[`Layers`](/api/layers/) or [`Models`](/api/models/). If you are familiar with Keras, -congratulations! You already understand most of KerasHub. +KerasHub is an extension of the core Keras API; KerasHub components are provided as +`keras.Layer`s and `keras.Model`s. If you are familiar with Keras, congratulations! You +already understand most of KerasHub. -KerasHub uses Keras 3 to work with any of TensorFlow, Pytorch and Jax. In the -guide below, we will use the `jax` backend for training our models, and -[tf.data](https://www.tensorflow.org/guide/data) for efficiently running our -input preprocessing. But feel free to mix things up! This guide runs in -TensorFlow or PyTorch backends with zero changes, simply update the -`KERAS_BACKEND` below. +This guide is meant to be an accessible introduction to the entire library. We will start +by using high-level APIs to classify images and generate text, then progressively show +deeper customization of models and training. Throughout the guide, we use Professor Keras, +the official Keras mascot, as a visual reference for the complexity of the material: -This guide demonstrates our modular approach using a sentiment analysis example at six -levels of complexity: +![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_evolution.png) -* Inference with a pretrained classifier -* Fine tuning a pretrained backbone -* Fine tuning with user-controlled preprocessing -* Fine tuning a custom model -* Pretraining a backbone model -* Build and train your own transformer from scratch +As always, we'll keep our Keras guides focused on real-world code examples. You can play +with the code here at any time by clicking the Colab link at the top of the guide. -Throughout our guide, we use Professor Keras, the official Keras mascot, as a visual -reference for the complexity of the material: +--- +## Installation and Setup -drawing +To begin, let's install keras-hub. The library is available on PyPI, so we can simply +install it with pip. ```python -!pip install -q --upgrade keras-hub -!pip install -q --upgrade keras # Upgrade to Keras 3. +!pip install --upgrade --quiet keras-hub-nightly keras-nightly ``` +Keras 3 was built to work on top of TensorFlow, Jax, and Torch backends. You should +specify the backend first thing when writing Keras code, before any library imports. We +will use the Jax backend for this guide, but you can use `torch` or `tensorflow` without +changing a single line in the rest of this guide. That's the power of Keras 3! 
+ +We will also set `XLA_PYTHON_CLIENT_MEM_FRACTION`, which frees up the whole GPU for +Jax to use from the start. + + ```python import os os.environ["KERAS_BACKEND"] = "jax" # or "tensorflow" or "torch" - -import keras_hub -import keras - -# Use mixed precision to speed up all training in this guide. -keras.mixed_precision.set_global_policy("mixed_float16") +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0" ``` -
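If you want to confirm which backend was picked up before running anything heavy, Keras 3
exposes it at runtime. A quick check (nothing here is KerasHub-specific, just core Keras):

```python
import keras

# Confirm the backend selected via the KERAS_BACKEND environment variable.
print("Active backend:", keras.backend.backend())  # e.g. "jax"
```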
-``` - -``` -
---- -## API quickstart +Lastly, we need to do some extra setup to access the models used in this guide. Many +popular open LLMs, such as Gemma from Google and Llama from Meta, require accepting +a community license before accessing the model weights. We will be using Gemma in this +guide, so we can follow the following steps: -Our highest level API is `keras_hub.models`. These symbols cover the complete user -journey of converting strings to tokens, tokens to dense features, and dense features to -task-specific output. For each `XX` architecture (e.g., `Bert`), we offer the following -modules: +1. Go to the [Gemma 2](https://www.kaggle.com/models/keras/gemma2) model page, and accept + the license at the banner at the top. +2. Generate an Kaggle API key by going to [Kaggle settings](https://www.kaggle.com/settings) + and clicking "Create New Token" button under the "API" section. +3. Inside your colab notebook, click on the key icon on the left hand toolbar. Add two + secrets: `KAGGLE_USERNAME` with your username, and `KAGGLE_KEY` with the API key you just + created. Make these secrets visible to the notebook you are running. -* **Tokenizer**: `keras_hub.models.XXTokenizer` +--- +## API Quickstart + +Before we begin, let's take a look at the key classes we will use in the KerasHub library. + +* **Task**: e.g., `keras_hub.models.CausalLM`, `keras_hub.models.ImageClassifier`, and +`keras_hub.models.TextClassifier`. + * **What it does**: A task maps from raw image, audio, and text inputs to model + predictions. + * **Why it's important**: A task is the highest-level entry point to the KerasHub API. It + encapsulates both preprocessing and modeling into a single, easy-to-use class. Tasks can + be used both for fine-tuning and inference. + * **Has a**: `backbone` and `preprocessor`. + * **Inherits from**: `keras.Model`. +* **Backbone**: `keras_hub.models.Backbone`. + * **What it does**: Maps preprocessed tensor inputs to the latent space of the model. + * **Why it's important**: The backbone encapsulates the architecture and parameters of a + pretrained models in a way that is unspecialized to any particular task. A backbone can + be combined with arbitrary preprocessing and "head" layers mapping dense features to + predictions to accomplish any ML task. + * **Inherits from**: `keras.Model`. +* **Preprocessor**: e.g.,`keras_hub.models.CausalLMPreprocessor`, + `keras_hub.models.ImageClassifierPreprocessor`, and + `keras_hub.models.TextClassifierPreprocessor`. + * **What it does**: A preprocessor maps from raw image, audio and text inputs to + preprocessed tensor inputs. + * **Why it's important**: A preprocessing layer encapsulates all tasks specific + preprocessing, e.g. image resizing and text tokenization, in a way that can be used + standalone to precompute preprocessed inputs. Note that if you are using a high-level + task class, this preprocessing is already baked in by default. + * **Has a**: `tokenizer`, `audio_converter`, and/or `image_converter`. + * **Inherits from**: `keras.layers.Layer`. +* **Tokenizer**: `keras_hub.tokenizers.Tokenizer`. * **What it does**: Converts strings to sequences of token ids. - * **Why it's important**: The raw bytes of a string are too high dimensional to be useful - features so we first map them to a small number of tokens, for example `"The quick brown - fox"` to `["the", "qu", "##ick", "br", "##own", "fox"]`. 
+ * **Why it's important**: The raw bytes of a string are an inefficient representation of + text input, so we first map string inputs to integer token ids. This class encapsulated + the mapping of strings to ints and the reverse (via the `detokenize()` method). * **Inherits from**: `keras.layers.Layer`. -* **Preprocessor**: `keras_hub.models.XXPreprocessor` - * **What it does**: Converts strings to a dictionary of preprocessed tensors consumed by - the backbone, starting with tokenization. - * **Why it's important**: Each model uses special tokens and extra tensors to understand - the input such as delimiting input segments and identifying padding tokens. Padding each - sequence to the same length improves computational efficiency. - * **Has a**: `XXTokenizer`. +* **ImageConverter**: `keras_hub.layers.ImageConverter`. + * **What it does**: Resizes and rescales image input. + * **Why it's important**: Image models often need to normalize image inputs to a specific + range, or resizing inputs to a specific size. This class encapsulates the image-specific + preprocessing. + * **Inherits from**: `keras.layers.Layer`. +* **AudioConveter**: `keras_hub.layers.AudioConveter`. + * **What it does**: Converts raw audio to model ready input. + * **Why it's important**: Audio models often need to preprocess raw audio input before + passing it to a model, e.g. by computing a spectrogram of the audio signal. This class + encapsulates the image specific preprocessing in an easy to use layer. * **Inherits from**: `keras.layers.Layer`. -* **Backbone**: `keras_hub.models.XXBackbone` - * **What it does**: Converts preprocessed tensors to dense features. *Does not handle - strings; call the preprocessor first.* - * **Why it's important**: The backbone distills the input tokens into dense features that - can be used in downstream tasks. It is generally pretrained on a language modeling task - using massive amounts of unlabeled data. Transferring this information to a new task is a - major breakthrough in modern NLP. - * **Inherits from**: `keras.Model`. -* **Task**: e.g., `keras_hub.models.XXClassifier` - * **What it does**: Converts strings to task-specific output (e.g., classification - probabilities). - * **Why it's important**: Task models combine string preprocessing and the backbone model - with task-specific `Layers` to solve a problem such as sentence classification, token - classification, or text generation. The additional `Layers` must be fine-tuned on labeled - data. - * **Has a**: `XXBackbone` and `XXPreprocessor`. - * **Inherits from**: `keras.Model`. -Here is the modular hierarchy for `BertClassifier` (all relationships are compositional): +All of the classes listed here have a `from_preset()` constructor, which will instantiate +the component with weights and state for the given pre-trained model identifier. E.g. +`keras_hub.tokenizers.Tokenizer.from_preset("gemma2_2b_en")` will create a layer that +tokenizes text using a Gemma2 tokenizer vocabulary. -drawing +The figure below shows how all these core classes interact. Arrow indicate composition +not inheritance (e.g., a task *has a* backbone). -All modules can be used independently and have a `from_preset()` method in addition to -the standard constructor that instantiates the class with **preset** architecture and -weights (see examples below). +![png](/img/guides/getting_started/class-diagram.png) --- -## Data +## Classify an image -We will use a running example of sentiment analysis of IMDB movie reviews. 
In this task, -we use the text to predict whether the review was positive (`label = 1`) or negative -(`label = 0`). +![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_beginner.png) -We load the data using `keras.utils.text_dataset_from_directory`, which utilizes the -powerful `tf.data.Dataset` format for examples. +Enough setup! Let's have some fun with pre-trained models. Let's load a test image of a +California Quail and classify it. ```python -!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz -!tar -xf aclImdb_v1.tar.gz -!# Remove unsupervised examples -!rm -r aclImdb/train/unsup +import keras +import numpy as np +import matplotlib.pyplot as plt + +image_url = "https://upload.wikimedia.org/wikipedia/commons/a/aa/California_quail.jpg" +image_path = keras.utils.get_file(origin=image_url) +image = keras.utils.load_img(image_path) +plt.imshow(image) ``` + +![png](/img/guides/getting_started/getting_started_11_1.png) + + + +We can use a ResNet vision model trained on the ImageNet-1k database. This model will +give each input sample and output label from `[0, 1000)`, where each label corresponds to +some real word entity, like a "milk can" or a "porcupine." The dataset actually has a +specific label for quail, at index 85. Let's download the model and predict a label. -```python -BATCH_SIZE = 16 -imdb_train = keras.utils.text_dataset_from_directory( - "aclImdb/train", - batch_size=BATCH_SIZE, -) -imdb_test = keras.utils.text_dataset_from_directory( - "aclImdb/test", - batch_size=BATCH_SIZE, -) -# Inspect first review -# Format is (review text tensor, label tensor) -print(imdb_train.unbatch().take(1).get_single_element()) +```python +import keras_hub +image_classifier = keras_hub.models.ImageClassifier.from_preset( + "resnet_50_imagenet", + activation="softmax", +) +batch = np.array([image]) +image_classifier.preprocessor.image_size = (224, 224) +preds = image_classifier.predict(batch) +preds.shape ``` +
``` - % Total % Received % Xferd Average Speed Time Time Time Current - Dload Upload Total Spent Left Speed -100 80.2M 100 80.2M 0 0 88.0M 0 --:--:-- --:--:-- --:--:-- 87.9M + 1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step +(1, 1000) +``` +
+These ImageNet labels aren't a particularly "human readable," so we can use a built-in +utility function to decode the predictions to a set of class names. -Found 25000 files belonging to 2 classes. -Found 25000 files belonging to 2 classes. -(, ) +```python +keras_hub.utils.decode_imagenet_predictions(preds) ``` - ---- -## Inference with a pretrained classifier -drawing -The highest level module in KerasHub is a **task**. A **task** is a `keras.Model` -consisting of a (generally pretrained) **backbone** model and task-specific layers. -Here's an example using `keras_hub.models.BertClassifier`. -**Note**: Outputs are the logits per class (e.g., `[0, 0]` is 50% chance of positive). The output is -[negative, positive] for binary classification. + +
+``` +[[('quail', 0.9996534585952759), + ('prairie_chicken', 8.45497488626279e-05), + ('partridge', 1.4000976079842076e-05), + ('black_grouse', 7.407367775158491e-06), + ('bullet_train', 7.323932550207246e-06)]] + +``` +
+Looking good! The model weights successfully downloaded, and we predicted the
+correct classification label for our quail image with near certainty.
+
+This was our first example of the high-level **task** API mentioned in the API quickstart
+above. A `keras_hub.models.ImageClassifier` is a task for classifying images, and can be
+used with a number of different model architectures (ResNet, VGG, MobileNet, etc.). You
+can view the full list of models shipped directly by the Keras team on
+[Kaggle](https://www.kaggle.com/organizations/keras/models).
+
+A task is just a subclass of `keras.Model` — you can use `fit()`, `compile()`, and
+`save()` on our `classifier` object the same as with any other model. But tasks come with a few
+extras provided by the KerasHub library. The first and most important is `from_preset()`,
+a special constructor you will see on many classes in KerasHub.
+
+A **preset** is a directory of model state. It defines both the architecture we should
+load and the pretrained weights that go with it. `from_preset()` allows us to load
+**preset** directories from a number of different locations:
+
+- A local directory.
+- The Kaggle Model hub.
+- The HuggingFace model hub.
+
+You can take a look at the `keras_hub.models.ImageClassifier.from_preset` docs to better
+understand all the options when constructing a Keras model from a preset.
+
+All tasks use two main sub-objects. A `keras_hub.models.Backbone` and a
+`keras_hub.layers.Preprocessor`. You might be familiar already with the term **backbone**
+from computer vision, where it is often used to describe a feature extractor network that
+maps images to a latent space. A KerasHub backbone is this concept generalized: we use it
+to refer to any pretrained model without a task-specific head. That is, a KerasHub
+backbone maps raw images, audio and text (or a combination of these inputs) to a
+pretrained model's latent space. We can then map this latent space to any number of task
+specific outputs, depending on what we are trying to do with the model.
+
+A **preprocessor** is just a Keras layer that does all the preprocessing for a specific
+task. In our case, preprocessing will resize our input image and rescale it to the
+range `[0, 1]` using some ImageNet specific mean and variance data. Let's call our
+task's preprocessor and backbone in succession to see what happens to our input shape.


```python
print("Raw input shape:", batch.shape)
resized_batch = image_classifier.preprocessor(batch)
print("Preprocessed input shape:", resized_batch.shape)
hidden_states = image_classifier.backbone(resized_batch)
print("Latent space shape:", hidden_states.shape)
```
``` - 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 689ms/step +Raw input shape: (1, 557, 707, 3) + +Preprocessed input shape: (1, 224, 224, 3) -array([[-1.539, 1.543]], dtype=float16) +Latent space shape: (1, 7, 7, 2048) ```
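To make that `(1, 7, 7, 2048)` latent space a little more concrete: averaging over the two
spatial dimensions collapses it into a single 2048-dimensional embedding per image, which is
essentially what the task's own pooling layer does before the final dense classifier. A small
illustrative sketch using `keras.ops`:

```python
from keras import ops

# Average over the 7x7 spatial grid to get one feature vector per image.
pooled = ops.mean(hidden_states, axis=(1, 2))
print("Pooled embedding shape:", pooled.shape)  # (1, 2048)
```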
-All **tasks** have a `from_preset` method that constructs a `keras.Model` instance with
-preset preprocessing, architecture and weights. This means that we can pass raw strings
-in any format accepted by a `keras.Model` and get output specific to our task.
+Our raw image is rescaled to `(224, 224)` during preprocessing and finally
+downscaled to a `(7, 7)` image of 2048 feature vectors — the latent space of the
+ResNet model. Note that ResNet can actually handle images of arbitrary sizes,
+though performance will eventually fall off if your images are sized very
+differently from the pretraining data. If you'd like to disable the resizing in the
+preprocessing layer, you can run `image_classifier.preprocessor.image_size = None`.

-This particular **preset** is a `"bert_tiny_uncased_en"` **backbone** fine-tuned on
-`sst2`, another movie review sentiment analysis (this time from Rotten Tomatoes). We use
-the `tiny` architecture for demo purposes, but larger models are recommended for SoTA
-performance. For all the task-specific presets available for `BertClassifier`, see
-our keras.io [models page](https://keras.io/api/keras_hub/models/).
-
-Let's evaluate our classifier on the IMDB dataset. You will note we don't need to
-call `keras.Model.compile` here. All **task** models like `BertClassifier` ship with
-compilation defaults, meaning we can just call `keras.Model.evaluate` directly. You
-can always call compile as normal to override these defaults (e.g. to add new metrics).
-
-The output below is [loss, accuracy],
+If you are ever wondering about the exact structure of the task you loaded, you can
+use `model.summary()` just as with any Keras model. The model summary for tasks will
+include extra information on model preprocessing.


```python
-classifier.evaluate(imdb_test)
+image_classifier.summary()
```
-``` - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 4s 2ms/step - loss: 0.4610 - sparse_categorical_accuracy: 0.7882 -[0.4630218744277954, 0.783519983291626] +
Preprocessor: "res_net_image_classifier_preprocessor"
+
-``` -
-Our result is 78% accuracy without training anything. Not bad! ---- -## Fine tuning a pretrained BERT backbone -drawing -When labeled text specific to our task is available, fine-tuning a custom classifier can -improve performance. If we want to predict IMDB review sentiment, using IMDB data should -perform better than Rotten Tomatoes data! And for many tasks, no relevant pretrained model -will be available (e.g., categorizing customer reviews). +
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ Layer (type)                                                         Config ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ res_net_image_converter                      │        Image size: (224, 224) │
+│ (ResNetImageConverter)                       │                               │
+└──────────────────────────────────────────────┴───────────────────────────────┘
+
-The workflow for fine-tuning is almost identical to above, except that we request a -**preset** for the **backbone**-only model rather than the entire classifier. When passed -a **backbone** **preset**, a **task** `Model` will randomly initialize all task-specific -layers in preparation for training. For all the **backbone** presets available for -`BertClassifier`, see our keras.io [models page](https://keras.io/api/keras_hub/models/). -To train your classifier, use `keras.Model.fit` as with any other -`keras.Model`. As with our inference example, we can rely on the compilation -defaults for the **task** and skip `keras.Model.compile`. As preprocessing is -included, we again pass the raw data. -```python -classifier = keras_hub.models.BertClassifier.from_preset( - "bert_tiny_en_uncased", - num_classes=2, -) -classifier.fit( - imdb_train, - validation_data=imdb_test, - epochs=1, -) -``` +
Model: "res_net_image_classifier"
+
-
-``` - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 16s 9ms/step - loss: 0.5202 - sparse_categorical_accuracy: 0.7281 - val_loss: 0.3254 - val_sparse_categorical_accuracy: 0.8621 - -``` -
-Here we see a significant lift in validation accuracy (0.78 -> 0.87) with a single epoch of -training even though the IMDB dataset is much smaller than `sst2`. ---- -## Fine tuning with user-controlled preprocessing -drawing +
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
+┃ Layer (type)                       Output Shape                    Param # ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
+│ input_layer (InputLayer)          │ (None, None, None, 3)    │             0 │
+├───────────────────────────────────┼──────────────────────────┼───────────────┤
+│ res_net_backbone (ResNetBackbone) │ (None, None, None, 2048) │    23,561,152 │
+├───────────────────────────────────┼──────────────────────────┼───────────────┤
+│ pooler (GlobalAveragePooling2D)   │ (None, 2048)             │             0 │
+├───────────────────────────────────┼──────────────────────────┼───────────────┤
+│ output_dropout (Dropout)          │ (None, 2048)             │             0 │
+├───────────────────────────────────┼──────────────────────────┼───────────────┤
+│ predictions (Dense)               │ (None, 1000)             │     2,049,000 │
+└───────────────────────────────────┴──────────────────────────┴───────────────┘
+
-For some advanced training scenarios, users might prefer direct control over -preprocessing. For large datasets, examples can be preprocessed in advance and saved to -disk or preprocessed by a separate worker pool using `tf.data.experimental.service`. In -other cases, custom preprocessing is needed to handle the inputs. -Pass `preprocessor=None` to the constructor of a **task** `Model` to skip automatic -preprocessing or pass a custom `BertPreprocessor` instead. -### Separate preprocessing from the same preset -Each model architecture has a parallel **preprocessor** `Layer` with its own -`from_preset` constructor. Using the same **preset** for this `Layer` will return the -matching **preprocessor** as the **task**. +
 Total params: 25,610,152 (97.69 MB)
+
-In this workflow we train the model over three epochs using `tf.data.Dataset.cache()`, -which computes the preprocessing once and caches the result before fitting begins. -**Note:** we can use `tf.data` for preprocessing while running on the -Jax or PyTorch backend. The input dataset will automatically be converted to -backend native tensor types during fit. In fact, given the efficiency of `tf.data` -for running preprocessing, this is good practice on all backends. -```python -import tensorflow as tf +
 Trainable params: 25,557,032 (97.49 MB)
+
-preprocessor = keras_hub.models.BertPreprocessor.from_preset( - "bert_tiny_en_uncased", - sequence_length=512, -) -# Apply the preprocessor to every sample of train and test data using `map()`. -# `tf.data.AUTOTUNE` and `prefetch()` are options to tune performance, see -# https://www.tensorflow.org/guide/data_performance for details. -# Note: only call `cache()` if you training data fits in CPU memory! -imdb_train_cached = ( - imdb_train.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE) -) -imdb_test_cached = ( - imdb_test.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE) -) -classifier = keras_hub.models.BertClassifier.from_preset( - "bert_tiny_en_uncased", preprocessor=None, num_classes=2 -) -classifier.fit( - imdb_train_cached, - validation_data=imdb_test_cached, - epochs=3, -) -``` +
 Non-trainable params: 53,120 (207.50 KB)
+
-
-``` -Epoch 1/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 15s 8ms/step - loss: 0.5194 - sparse_categorical_accuracy: 0.7272 - val_loss: 0.3032 - val_sparse_categorical_accuracy: 0.8728 -Epoch 2/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 10s 7ms/step - loss: 0.2871 - sparse_categorical_accuracy: 0.8805 - val_loss: 0.2809 - val_sparse_categorical_accuracy: 0.8818 -Epoch 3/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 10s 7ms/step - loss: 0.2134 - sparse_categorical_accuracy: 0.9178 - val_loss: 0.3043 - val_sparse_categorical_accuracy: 0.8790 - -``` -
-After three epochs, our validation accuracy has only increased to 0.88. This is both a -function of the small size of our dataset and our model. To exceed 90% accuracy, try -larger **presets** such as `"bert_base_en_uncased"`. For all the **backbone** presets -available for `BertClassifier`, see our keras.io [models page](https://keras.io/api/keras_hub/models/). +--- +## Generate text with an LLM -### Custom preprocessing +![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_intermediate.png) -In cases where custom preprocessing is required, we offer direct access to the -`Tokenizer` class that maps raw strings to tokens. It also has a `from_preset()` -constructor to get the vocabulary matching pretraining. +Next up, let's try working with and generating text. The task we can use when generating +text is `keras_hub.models.CausalLM` (where LM is short for **L**anguage **M**odel). Let's +download the 2 billion parameter Gemma 2 model and try it out. -**Note:** `BertTokenizer` does not pad sequences by default, so the output is -ragged (each sequence has varying length). The `MultiSegmentPacker` below -handles padding these ragged sequences to dense tensor types (e.g. `tf.Tensor` -or `torch.Tensor`). +Since this is about 100x larger model than the ResNet model we just downloaded, we need to be +a little more careful about our GPU memory usage. We can use a half-precision type to +load each parameter of our ~2.5 billion as a two-byte float instead of four. To do this +we can pass `dtype` to the `from_preset()` constructor. `from_preset()` will forward any +kwargs to the main constructor for the class, so you can pass kwargs that work on all +Keras layers like `dtype`, `trainable`, and `name`. ```python -tokenizer = keras_hub.models.BertTokenizer.from_preset("bert_tiny_en_uncased") -tokenizer(["I love modular workflows!", "Libraries over frameworks!"]) - -# Write your own packer or use one of our `Layers` -packer = keras_hub.layers.MultiSegmentPacker( - start_value=tokenizer.cls_token_id, - end_value=tokenizer.sep_token_id, - # Note: This cannot be longer than the preset's `sequence_length`, and there - # is no check for a custom preprocessor! - sequence_length=64, +causal_lm = keras_hub.models.CausalLM.from_preset( + "gemma2_instruct_2b_en", + dtype="bfloat16", ) +``` +The model we just loaded was an instruction-tuned version of Gemma, which means the model +was further fine-tuned for chat. We can take advantage of these capabilities as long as +we stick to the particular template for text used when training the model. These special +tokens vary per model and can be hard to track, the [Kaggle model +page](https://www.kaggle.com/models/keras/gemma2/) will contain details such as this. -# This function that takes a text sample `x` and its -# corresponding label `y` as input and converts the -# text into a format suitable for input into a BERT model. -def preprocessor(x, y): - token_ids, segment_ids = packer(tokenizer(x)) - x = { - "token_ids": token_ids, - "segment_ids": segment_ids, - "padding_mask": token_ids != 0, - } - return x, y +`CausalLM` come with an extra function called `generate()` which can be used generate +predict tokens in a loop and decode them as a string. 
-

```python
template = "<start_of_turn>user\n{question}<end_of_turn>\n<start_of_turn>model"

question = """Write a python program to generate the first 1000 prime numbers.
Just show the actual code."""
print(causal_lm.generate(template.format(question=question), max_length=512))
```
``` -({'token_ids': , 'segment_ids': , 'padding_mask': }, ) +user +Write a python program to generate the first 1000 prime numbers. +Just show the actual code. +model +def is_prime(n): + if n <= 1: + return False + for i in range(2, int(n**0.5) + 1): + if n % i == 0: + return False + return True + +count = 0 +number = 2 +primes = [] +while count < 1000: + if is_prime(number): + primes.append(number) + count += 1 + number += 1 +print(primes) + ```
+Note that on the Jax and TensorFlow backends, this `generate()` function is compiled, so
the second time you call it with the same `max_length`, it will actually be much faster.
KerasHub will use Jax and TensorFlow to compute an optimized version of the generation
computational graph that can be reused.


```python
question = "Share a very simple brownie recipe."
print(causal_lm.generate(template.format(question=question), max_length=512))
```
+``` +user +Share a very simple brownie recipe. +model + --- -## Fine tuning with a custom model -drawing +## Super Simple Brownies -For more advanced applications, an appropriate **task** `Model` may not be available. In -this case, we provide direct access to the **backbone** `Model`, which has its own -`from_preset` constructor and can be composed with custom `Layer`s. Detailed examples can -be found at our [transfer learning guide](https://keras.io/guides/transfer_learning/). +**Ingredients:** -A **backbone** `Model` does not include automatic preprocessing but can be paired with a -matching **preprocessor** using the same **preset** as shown in the previous workflow. +* 1 cup (2 sticks) unsalted butter, melted +* 2 cups granulated sugar +* 4 large eggs +* 1 teaspoon vanilla extract +* 1 cup all-purpose flour +* 1/2 cup unsweetened cocoa powder +* 1/4 teaspoon salt -In this workflow, we experiment with freezing our backbone model and adding two trainable -transformer layers to adapt to the new input. +**Instructions:** -**Note**: We can ignore the warning about gradients for the `pooled_dense` layer because -we are using BERT's sequence output. +1. Preheat oven to 350°F (175°C). Grease and flour a 9x13 inch baking pan. +2. In a large bowl, whisk together the melted butter and sugar until smooth. +3. Beat in the eggs one at a time, then stir in the vanilla extract. +4. In a separate bowl, whisk together the flour, cocoa powder, and salt. +5. Gradually add the dry ingredients to the wet ingredients, mixing until just combined. Do not overmix. +6. Pour the batter into the prepared pan and spread evenly. +7. Bake for 25-30 minutes, or until a toothpick inserted into the center comes out with a few moist crumbs attached. +8. Let cool completely before cutting and serving. +**Tips:** -```python -preprocessor = keras_hub.models.BertPreprocessor.from_preset("bert_tiny_en_uncased") -backbone = keras_hub.models.BertBackbone.from_preset("bert_tiny_en_uncased") +* For extra fudgy brownies, underbake them slightly. +* Add chocolate chips, nuts, or other mix-ins to the batter for a personalized touch. +* Serve with a scoop of ice cream or whipped cream for a decadent treat. -imdb_train_preprocessed = ( - imdb_train.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE) -) -imdb_test_preprocessed = ( - imdb_test.map(preprocessor, tf.data.AUTOTUNE).cache().prefetch(tf.data.AUTOTUNE) -) +Enjoy! + -backbone.trainable = False -inputs = backbone.input -sequence = backbone(inputs)["sequence_output"] -for _ in range(2): - sequence = keras_hub.layers.TransformerEncoder( - num_heads=2, - intermediate_dim=512, - dropout=0.1, - )(sequence) -# Use [CLS] token output to classify -outputs = keras.layers.Dense(2)(sequence[:, backbone.cls_token_index, :]) - -model = keras.Model(inputs, outputs) -model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.AdamW(5e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, -) -model.summary() -model.fit( - imdb_train_preprocessed, - validation_data=imdb_test_preprocessed, - epochs=3, -) ``` +
+As with our image classifier, we can use model summary to see the details of our task +setup, including preprocessing. -
Model: "functional_1"
+```python
+causal_lm.summary()
+```
+
+
+
Preprocessor: "gemma_causal_lm_preprocessor"
 
-
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
-┃ Layer (type)         Output Shape       Param #  Connected to         ┃
-┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
-│ padding_mask        │ (None, None)      │       0 │ -                    │
-│ (InputLayer)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ segment_ids         │ (None, None)      │       0 │ -                    │
-│ (InputLayer)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ token_ids           │ (None, None)      │       0 │ -                    │
-│ (InputLayer)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ bert_backbone_3     │ [(None, 128),     │ 4,385,… │ padding_mask[0][0],  │
-│ (BertBackbone)      │ (None, None,      │         │ segment_ids[0][0],   │
-│                     │ 128)]             │         │ token_ids[0][0]      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ transformer_encoder │ (None, None, 128) │ 198,272 │ bert_backbone_3[0][ │
-│ (TransformerEncode… │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ transformer_encode… │ (None, None, 128) │ 198,272 │ transformer_encoder… │
-│ (TransformerEncode… │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ get_item_4          │ (None, 128)       │       0 │ transformer_encoder… │
-│ (GetItem)           │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ dense (Dense)       │ (None, 2)         │     258 │ get_item_4[0][0]     │
-└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
+
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+┃ Layer (type)                                                         Config ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+│ gemma_tokenizer (GemmaTokenizer)             │           Vocab size: 256,000 │
+└──────────────────────────────────────────────┴───────────────────────────────┘
 
-
 Total params: 4,782,722 (18.24 MB)
+
Model: "gemma_causal_lm"
 
-
 Trainable params: 396,802 (1.51 MB)
+
┏━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓
+┃ Layer (type)           Output Shape           Param #  Connected to       ┃
+┡━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩
+│ padding_mask          │ (None, None)      │           0 │ -                  │
+│ (InputLayer)          │                   │             │                    │
+├───────────────────────┼───────────────────┼─────────────┼────────────────────┤
+│ token_ids             │ (None, None)      │           0 │ -                  │
+│ (InputLayer)          │                   │             │                    │
+├───────────────────────┼───────────────────┼─────────────┼────────────────────┤
+│ gemma_backbone        │ (None, None,      │ 2,614,341,… │ padding_mask[0][0… │
+│ (GemmaBackbone)       │ 2304)             │             │ token_ids[0][0]    │
+├───────────────────────┼───────────────────┼─────────────┼────────────────────┤
+│ token_embedding       │ (None, None,      │ 589,824,000 │ gemma_backbone[0]… │
+│ (ReversibleEmbedding) │ 256000)           │             │                    │
+└───────────────────────┴───────────────────┴─────────────┴────────────────────┘
 
-
 Non-trainable params: 4,385,920 (16.73 MB)
+
 Total params: 2,614,341,888 (4.87 GB)
 
+ +
 Trainable params: 2,614,341,888 (4.87 GB)
+
+ + + + +
 Non-trainable params: 0 (0.00 B)
+
+ + + +Our text preprocessing includes a tokenizer, which is how all KerasHub models handle +input text. Let's try using it directly to get a better sense of how it works. All +tokenizers include `tokenize()` and `detokenize()` methods, to map strings to integer +sequences and integer sequences to strings. Directly calling the layer with +`tokenizer(inputs)` is equivalent to calling `tokenizer.tokenize(inputs)`. + + +```python +tokenizer = causal_lm.preprocessor.tokenizer +tokens_ids = tokenizer.tokenize("The quick brown fox jumps over the lazy dog.") +print(tokens_ids) +string = tokenizer.detokenize(tokens_ids) +print(string) +``` +
``` -Epoch 1/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 17s 10ms/step - loss: 0.6208 - sparse_categorical_accuracy: 0.6612 - val_loss: 0.6119 - val_sparse_categorical_accuracy: 0.6758 -Epoch 2/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 8ms/step - loss: 0.5324 - sparse_categorical_accuracy: 0.7347 - val_loss: 0.5484 - val_sparse_categorical_accuracy: 0.7320 -Epoch 3/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 8ms/step - loss: 0.4735 - sparse_categorical_accuracy: 0.7723 - val_loss: 0.4874 - val_sparse_categorical_accuracy: 0.7742 +[ 651 4320 8426 25341 36271 1163 573 27894 5929 235265] +The quick brown fox jumps over the lazy dog. + +``` +
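A practical standalone use of the tokenizer is checking how many tokens a prompt will
consume before you call `generate()`. A short sketch reusing the `template` and `tokenizer`
defined above:

```python
prompt = template.format(question="Share a very simple brownie recipe.")

# The token count tells you how much of the model's context window the prompt uses.
prompt_tokens = tokenizer.tokenize(prompt)
print("Prompt length in tokens:", prompt_tokens.shape[0])

# detokenize() maps the ids back to the original string.
print(tokenizer.detokenize(prompt_tokens))
```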
+The `generate()` function for `CausalLM` models involved a sampling step. The Gemma model +will be called once for each token we want to generate, and return a probability +distribution over all tokens. This distribution is then sampled to choose the next token +in the sequence. + +For Gemma models, we default to greedy sampling, meaning we simply pick the most likely +output from the model at each step. But we can actually control this process with an +extra `sampler` argument to the standard `compile` function on all Keras models. Let's +try it out. - + +```python +causal_lm.compile( + sampler=keras_hub.samplers.TopKSampler(k=10, temperature=2.0), +) + +question = "Share a very simple brownie recipe." +print(causal_lm.generate(template.format(question=question), max_length=512)) +``` + +
+``` +user +Share a very simple brownie recipe. +model ## Ultimate Simple Brownies + +This recipe requires NO oven or special equipment! Just microwave, mixing, and a few moments! + +**Yields:** 6 large brownies +**Prep time:** 7 minutes +**Cook time:** 6-9 minutes, depending on your microwave + +**What you need:** +* 3 ounces (about 2-3 tablespoons) chocolate chips +* 1/4 cup butter +* 1 large egg +* 1/2 cup granulated sugar +* 9 tablespoons all-purpose flour + +**Optional Add-Ins (for extra fun):** +* 1/2 teaspoon vanilla +* 1/4 cup chopped walnuts or pecans + +**Instructions:** + +1. Place all microwave-safe mixing bowl ingredients: + - Chocolate Chips 🍫 + - Butter 🧈 + - Flour 🗲 + - Egg (beaten!) + (You can add the optional add-INS like chopped nuts/extra vanilla, now is the good place to!) + + + 2. Put all that in your microwave (microwave-safe dish or a heat-safe mug is fine!) + +3. **Cook on:** Medium-high, stirring halfway. + * Time depends on your microwave, so keep checking, but aim for 6-9 minutes (if no stirring at least 8 mins). You want a thick, almost chewy-texture. + + + + **To serve:** Cut up your brownies immediately and savor this classic treat. You'd also need a tall glass of cold milk or coffee (or both, if you've really enjoyed it). + + Let me know if you want to experiment with a different chocolate or add-ins to make it even sweeter. Enjoy! 😉 + + ```
-This model achieves reasonable accuracy despite having only 10% of the trainable parameters -of our `BertClassifier` model. Each training step takes about 1/3 of the time---even -accounting for cached preprocessing. +Here we used a Top-K sampler, meaning we will randomly sample the partial distribution formed +by looking at just the top 10 predicted tokens at each time step. We also pass a `temperature` of 2, +which flattens our predicted distribution before we sample. ---- -## Pretraining a backbone model -drawing +The net effect is that we will explore our model's distribution much more broadly each +time we generate output. Generation will now be a random process, each time we re-run +generate we will get a different result. We can note that the results feel "looser" than +greedy search — more minor mistakes, a less consistent one, and the dubious recommendation to +microwave brownies. + +You can look at all the samplers Keras supports at [keras_hub.samplers](https://keras.io/api/keras_hub/samplers/). + +Let's free up the memory from our large Gemma model before we jump to the next section. -Do you have access to large unlabeled datasets in your domain? Are they around the -same size as used to train popular backbones such as BERT, RoBERTa, or GPT2 (XX+ GiB)? If -so, you might benefit from domain-specific pretraining of your own backbone models. -NLP models are generally pretrained on a language modeling task, predicting masked words -given the visible words in an input sentence. For example, given the input -`"The fox [MASK] over the [MASK] dog"`, the model might be asked to predict `["jumped", "lazy"]`. -The lower layers of this model are then packaged as a **backbone** to be combined with -layers relating to a new task. +```python +del causal_lm +``` -The KerasHub library offers SoTA **backbones** and **tokenizers** to be trained from -scratch without presets. +--- +## Fine-tune and publish an image classifier -In this workflow, we pretrain a BERT **backbone** using our IMDB review text. We skip the -"next sentence prediction" (NSP) loss because it adds significant complexity to the data -processing and was dropped by later models like RoBERTa. See our e2e -[Transformer pretraining](https://keras.io/guides/keras_hub/transformer_pretraining/#pretraining) -for step-by-step details on how to replicate the original paper. +![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_advanced.png) -### Preprocessing +Now that we've tried running inference for both images and text, let's try running +training. We will take our ResNet image classifier from earlier and fine-tune it on +simple cats vs dogs dataset. We can start by downloading and extracting the data. 
```python -# All BERT `en` models have the same vocabulary, so reuse preprocessor from -# "bert_tiny_en_uncased" -preprocessor = keras_hub.models.BertPreprocessor.from_preset( - "bert_tiny_en_uncased", - sequence_length=256, -) -packer = preprocessor.packer -tokenizer = preprocessor.tokenizer - -# keras.Layer to replace some input tokens with the "[MASK]" token -masker = keras_hub.layers.MaskedLMMaskGenerator( - vocabulary_size=tokenizer.vocabulary_size(), - mask_selection_rate=0.25, - mask_selection_length=64, - mask_token_id=tokenizer.token_to_id("[MASK]"), - unselectable_token_ids=[ - tokenizer.token_to_id(x) for x in ["[CLS]", "[PAD]", "[SEP]"] - ], +import pathlib + +extract_dir = keras.utils.get_file( + "cats_vs_dogs", + "https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip", + extract=True, ) +data_dir = pathlib.Path(extract_dir) / "PetImages" +``` +When working with lots of real-world image data, corrupted images are a common occurrence. +Let's filter out badly-encoded images that do not feature the string "JFIF" in their +header. -def preprocess(inputs, label): - inputs = preprocessor(inputs) - masked_inputs = masker(inputs["token_ids"]) - # Split the masking layer outputs into a (features, labels, and weights) - # tuple that we can use with keras.Model.fit(). - features = { - "token_ids": masked_inputs["token_ids"], - "segment_ids": inputs["segment_ids"], - "padding_mask": inputs["padding_mask"], - "mask_positions": masked_inputs["mask_positions"], - } - labels = masked_inputs["mask_ids"] - weights = masked_inputs["mask_weights"] - return features, labels, weights +```python +num_skipped = 0 -pretrain_ds = imdb_train.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE).prefetch( - tf.data.AUTOTUNE -) -pretrain_val_ds = imdb_test.map( - preprocess, num_parallel_calls=tf.data.AUTOTUNE -).prefetch(tf.data.AUTOTUNE) +for path in data_dir.rglob("*.jpg"): + with open(path, "rb") as file: + is_jfif = b"JFIF" in file.peek(10) + if not is_jfif: + num_skipped += 1 + os.remove(path) -# Tokens with ID 103 are "masked" -print(pretrain_ds.unbatch().take(1).get_single_element()) +print(f"Deleted {num_skipped} images.") ```
``` -({'token_ids': , 'segment_ids': , 'padding_mask': , 'mask_positions': }, , ) +Deleted 1590 images. ```
-### Pretraining model +We can load the dataset with `keras.utils.image_dataset_from_directory`. One important +thing to note here is that the `train_ds` and `val_ds` will both be returned as +`tf.data.Dataset` objects, including on the `torch` and `jax` backends. + +KerasHub will use [tf.data](https://www.tensorflow.org/guide/data) as the default API for +running multi-threaded preprocessing on the CPU. `tf.data` is a powerful API for training +input pipelines that can scale up to complex, multi-host training jobs easily. Using it +does not restrict your choice of backend, a `tf.data.Dataset` can be as an iterator of +regular numpy data and passed to `fit()` on any Keras backend. ```python -# BERT backbone -backbone = keras_hub.models.BertBackbone( - vocabulary_size=tokenizer.vocabulary_size(), - num_layers=2, - num_heads=2, - hidden_dim=128, - intermediate_dim=512, +train_ds, val_ds = keras.utils.image_dataset_from_directory( + data_dir, + validation_split=0.2, + subset="both", + seed=1337, + image_size=(256, 256), + batch_size=32, ) +``` -# Language modeling head -mlm_head = keras_hub.layers.MaskedLMHead( - token_embedding=backbone.token_embedding, +
+``` +Found 23410 files belonging to 2 classes. + +Using 18728 files for training. + +Using 4682 files for validation. + +``` +
+At its simplest, training our classifier could consist of simply calling `fit()` on our +model with our dataset. But to make this example a little more interesting, let's show +how to customize preprocessing within a task. + +In the first example, we saw how, by default, the preprocessing for our ResNet model resized +and rescaled our input. This preprocessing can be customized when we create our model. We +can use Keras' image preprocessing layers to create a `keras.layers.Pipeline` that will +rescale, randomly flip, and randomly rotate our input images. These random image +augmentations will allow our smaller dataset to function as a larger, more varied one. +Let's try it out. + + +```python +preprocessor = keras.layers.Pipeline( + [ + keras.layers.Rescaling(1.0 / 255), + keras.layers.RandomFlip("horizontal"), + keras.layers.RandomRotation(0.2), + ] ) +``` -inputs = { - "token_ids": keras.Input(shape=(None,), dtype=tf.int32, name="token_ids"), - "segment_ids": keras.Input(shape=(None,), dtype=tf.int32, name="segment_ids"), - "padding_mask": keras.Input(shape=(None,), dtype=tf.int32, name="padding_mask"), - "mask_positions": keras.Input(shape=(None,), dtype=tf.int32, name="mask_positions"), -} - -# Encoded token sequence -sequence = backbone(inputs)["sequence_output"] - -# Predict an output word for each masked input token. -# We use the input token embedding to project from our encoded vectors to -# vocabulary logits, which has been shown to improve training efficiency. -outputs = mlm_head(sequence, mask_positions=inputs["mask_positions"]) - -# Define and compile our pretraining model. -pretraining_model = keras.Model(inputs, outputs) -pretraining_model.summary() -pretraining_model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.AdamW(learning_rate=5e-4), - weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, +Now that we have created a new layer for preprocessing, we can simply pass it to the +`ImageClassifier` during the `from_preset()` constructor. We can also pass +`num_classes=2` to match our two labels for "cat" and "dog." When `num_classes` is +specified like this, our head weights for the model will be randomly initialized +instead of containing the weights for our 1000 class image classification. + + +```python +image_classifier = keras_hub.models.ImageClassifier.from_preset( + "resnet_50_imagenet", + activation="softmax", + num_classes=2, + preprocessor=preprocessor, ) +``` -# Pretrain on IMDB dataset -pretraining_model.fit( - pretrain_ds, - validation_data=pretrain_val_ds, - epochs=3, # Increase to 6 for higher accuracy +Note that if you want to preprocess your input data outside of Keras, you can simply +pass `preprocessor=None` to the task `from_preset()` call. In this case, KerasHub will +apply no preprocessing at all, and you are free to preprocess your data with any library +or workflow before passing your data to `fit()`. + +Next, we can compile our model for fine-tuning. A KerasHub task is just a regular +`keras.Model` with some extra functionality, so we can `compile()` as normal for a +classification task. + + +```python +image_classifier.compile( + optimizer=keras.optimizers.Adam(1e-4), + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], ) ``` +With that, we can simply run `fit()`. The image classifier will automatically apply our +preprocessing to each batch when training the model. -
Model: "functional_3"
-
+```python +image_classifier.fit( + train_ds, + validation_data=val_ds, + epochs=3, +) +``` +
+``` +Epoch 1/3 + 586/586 ━━━━━━━━━━━━━━━━━━━━ 0s 122ms/step - accuracy: 0.8869 - loss: 0.2921 +Epoch 2/3 + 586/586 ━━━━━━━━━━━━━━━━━━━━ 65s 105ms/step - accuracy: 0.9858 - loss: 0.0393 - val_accuracy: 0.9912 - val_loss: 0.0234 +Epoch 3/3 + 586/586 ━━━━━━━━━━━━━━━━━━━━ 57s 96ms/step - accuracy: 0.9897 - loss: 0.0289 - val_accuracy: 0.9930 - val_loss: 0.0206 + + +``` +
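As an aside before we look at the results: here is a minimal sketch of the `preprocessor=None` path described above. It is not something we run in this guide; it assumes you want to apply the same `Pipeline` yourself with `tf.data`, and it reuses the `preprocessor` and `train_ds` objects defined earlier.

```python
# Sketch only: skip the task's built-in preprocessing and handle it in tf.data instead.
bare_classifier = keras_hub.models.ImageClassifier.from_preset(
    "resnet_50_imagenet",
    activation="softmax",
    num_classes=2,
    preprocessor=None,
)
# Apply the same pipeline manually; `training=True` keeps the random augmentations active.
manual_train_ds = train_ds.map(
    lambda images, labels: (preprocessor(images, training=True), labels),
    num_parallel_calls=16,
)
# `bare_classifier` could then be compiled and fit on `manual_train_ds` exactly as above.
```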
+After three epochs of data, we achieve 99% accuracy on our cats vs dogs +validation dataset. This is unsurprising, given that the ImageNet pretrained weights we began +with could already classify some breeds of cats and dogs individually. +Now that we have a fine-tuned model, let's try saving it. You can create a new saved preset from a +fine-tuned model for any task simply by running `task.save_to_preset()`. -
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
-┃ Layer (type)         Output Shape       Param #  Connected to         ┃
-┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
-│ mask_positions      │ (None, None)      │       0 │ -                    │
-│ (InputLayer)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ padding_mask        │ (None, None)      │       0 │ -                    │
-│ (InputLayer)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ segment_ids         │ (None, None)      │       0 │ -                    │
-│ (InputLayer)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ token_ids           │ (None, None)      │       0 │ -                    │
-│ (InputLayer)        │                   │         │                      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ bert_backbone_4     │ [(None, 128),     │ 4,385,… │ mask_positions[0][0… │
-│ (BertBackbone)      │ (None, None,      │         │ padding_mask[0][0],  │
-│                     │ 128)]             │         │ segment_ids[0][0],   │
-│                     │                   │         │ token_ids[0][0]      │
-├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
-│ masked_lm_head      │ (None, None,      │ 3,954,… │ bert_backbone_4[0][ │
-│ (MaskedLMHead)      │ 30522)            │         │ mask_positions[0][0] │
-└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
-
+```python +image_classifier.save_to_preset("cats_vs_dogs") +``` +One of the most powerful features of KerasHub is the ability to upload models to Kaggle or the +Hugging Face Hub and share them with others. `keras_hub.upload_preset` allows you +to upload a saved preset. +In this case, we will upload to Kaggle. We already authenticated with Kaggle to +download the Gemma model earlier. Running the following cell will upload a new model +to Kaggle. -
 Total params: 4,433,210 (16.91 MB)
-
+```python +from google.colab import userdata +username = userdata.get("KAGGLE_USERNAME") +keras_hub.upload_preset( + f"kaggle://{username}/resnet/keras/cats_vs_dogs", + "cats_vs_dogs", +) +``` +
+``` +Uploading Model https://www.kaggle.com/models/matthewdwatson/resnet/keras/cats_vs_dogs ... +Upload successful: cats_vs_dogs/task.json (5KB) +Upload successful: cats_vs_dogs/task.weights.h5 (270MB) +Upload successful: cats_vs_dogs/metadata.json (157B) +Upload successful: cats_vs_dogs/model.weights.h5 (90MB) +Upload successful: cats_vs_dogs/config.json (841B) +Upload successful: cats_vs_dogs/preprocessor.json (3KB) + +Your model instance version has been created. +Files are being processed... +See at: https://www.kaggle.com/models/matthewdwatson/resnet/keras/cats_vs_dogs -
 Trainable params: 4,433,210 (16.91 MB)
-
+``` +
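As noted above, presets can also be shared on the Hugging Face Hub. The snippet below is a sketch rather than something we ran for this guide; it assumes you have a Hugging Face account with a write token configured (for example via `huggingface_hub.login()`), and `hf_username` is a placeholder for your account name.

```python
# Sketch (not run in this guide): upload the same saved preset to the Hugging Face Hub.
hf_username = "your-hf-username"  # placeholder
keras_hub.upload_preset(
    f"hf://{hf_username}/cats_vs_dogs",
    "cats_vs_dogs",
)
```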
+Let's take a look at a test image from our dataset. +```python +image = keras.utils.load_img(data_dir / "Cat" / "6779.jpg") +plt.imshow(image) +``` +![png](/img/guides/getting_started/getting_started_55_1.png) + -
 Non-trainable params: 0 (0.00 B)
-
+ +If we wait for a few minutes for our model upload to finish processing on the Kaggle +side, we can go ahead and download the model we just created and use it to classify this +test image. +```python +image_classifier = keras_hub.models.ImageClassifier.from_preset( + f"kaggle://{username}/resnet/keras/cats_vs_dogs", +) +print(image_classifier.predict(np.array([image]))) +```
``` -Epoch 1/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 22s 12ms/step - loss: 5.7032 - sparse_categorical_accuracy: 0.0566 - val_loss: 5.0685 - val_sparse_categorical_accuracy: 0.1044 -Epoch 2/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 13s 8ms/step - loss: 5.0701 - sparse_categorical_accuracy: 0.1096 - val_loss: 4.9363 - val_sparse_categorical_accuracy: 0.1239 -Epoch 3/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 13s 8ms/step - loss: 4.9607 - sparse_categorical_accuracy: 0.1240 - val_loss: 4.7913 - val_sparse_categorical_accuracy: 0.1417 + 1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step - +[[9.999286e-01 7.135461e-05]] ```
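The two softmax scores correspond to the integer labels assigned by `keras.utils.image_dataset_from_directory`, which sorts class folders alphabetically, so index 0 is "Cat" and index 1 is "Dog". As a small illustration (not part of the original guide), we can map the prediction back to a label name:

```python
# Map the softmax scores back to folder names; class order follows the
# alphabetical folder order used by `image_dataset_from_directory` ("Cat" = 0, "Dog" = 1).
class_names = ["Cat", "Dog"]
scores = image_classifier.predict(np.array([image]))
print(class_names[int(np.argmax(scores[0]))])
```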
-After pretraining save your `backbone` submodel to use in a new task! +Congratulations on uploading your first model with KerasHub! If you want to share your +work with others, you can go to the model link printed out when we uploaded the model, and +make the model public in the settings. + +Let's delete this model to free up memory before we move on to our final example for this +guide. + + +```python +del image_classifier +``` --- -## Build and train your own transformer from scratch -drawing +## Building a custom text classifier -Want to implement a novel transformer architecture? The KerasHub library offers all the -low-level modules used to build SoTA architectures in our `models` API. This includes the -`keras_hub.tokenizers` API which allows you to train your own subword tokenizer using -`WordPieceTokenizer`, `BytePairTokenizer`, or `SentencePieceTokenizer`. +![](https://storage.googleapis.com/keras-nlp/getting_started_guide/prof_keras_expert.png) -In this workflow, we train a custom tokenizer on the IMDB data and design a backbone with -custom transformer architecture. For simplicity, we then train directly on the -classification task. Interested in more details? We wrote an entire guide to pretraining -and finetuning a custom transformer on -[keras.io](https://keras.io/guides/keras_hub/transformer_pretraining/), +As a final example for this getting started guide, let's take a look at how we can build +custom models from lower-level Keras and KerasHub components. We will build a text +classifier to classify movie reviews in the IMDb dataset as either positive or negative. -### Train custom vocabulary from IMDB data +Let's download the dataset. ```python -vocab = keras_hub.tokenizers.compute_word_piece_vocabulary( - imdb_train.map(lambda x, y: x), - vocabulary_size=20_000, - lowercase=True, - strip_accents=True, - reserved_tokens=["[PAD]", "[START]", "[END]", "[MASK]", "[UNK]"], +extract_dir = keras.utils.get_file( + "imdb_reviews", + origin="https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz", + extract=True, ) -tokenizer = keras_hub.tokenizers.WordPieceTokenizer( - vocabulary=vocab, - lowercase=True, - strip_accents=True, - oov_token="[UNK]", +data_dir = pathlib.Path(extract_dir) / "aclImdb" +``` + +The IMDb dataset contains a large number of unlabeled movie reviews. We don't need those +here, so we can simply delete them. + + +```python +import shutil + +shutil.rmtree(data_dir / "train" / "unsup") +``` + +Next up, we can load our data with `keras.utils.text_dataset_from_directory`. As with our +image dataset creation above, the returned datasets will be `tf.data.Dataset` objects. + + +```python +raw_train_ds = keras.utils.text_dataset_from_directory( + data_dir / "train", + batch_size=2, +) +raw_val_ds = keras.utils.text_dataset_from_directory( + data_dir / "test", + batch_size=2, ) ``` -### Preprocess data with a custom tokenizer
+``` +Found 25000 files belonging to 2 classes. + +Found 25000 files belonging to 2 classes. + +``` +
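Before preprocessing, it can help to look at one raw example. This quick peek is not in the original guide; it prints the start of the first review in a batch along with its label (labels follow the alphabetical folder order, so 0 is "neg" and 1 is "pos").

```python
# Optional peek (not in the original guide): one raw review and its integer label.
# Labels follow alphabetical folder order: 0 = "neg", 1 = "pos".
for texts, labels in raw_train_ds.take(1):
    print(texts[0].numpy()[:200])
    print("label:", int(labels[0]))
```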
+KerasHub is designed to be a layered API. At the top-most level, tasks aim to make it +easy to quickly tackle a problem. We could keep using the task API here, and create a +`keras_hub.models.TextClassifier` for a text classification model like BERT, and fine-tune +it in 10 or so lines of code. + +Instead, to make our final example a little more interesting, let's show how we can use +lower-level API components to do something that isn't directly baked into the library. +We will take the Gemma 2 model we used earlier, which is usually used for generating text, +and modify it to output classification predictions. + +A common approach for classifying with a generative model would be to keep using it in a +generative context, prompting it with the review and a question (`"Is this review positive or negative?"`). +But building an actual classifier is more useful if you want a probability score associated +with your labels. + +Instead of loading the Gemma 2 model through the `CausalLM` task, we can load two +lower-level components: a **backbone** and a **tokenizer**. Much like the task classes we have +used so far, `keras_hub.models.Backbone` and `keras_hub.tokenizers.Tokenizer` both have a +`from_preset()` constructor for loading pretrained models. If you are running this code, +you will notice that you don't have to wait for a download the second time we use the model; +the weights files were cached locally the first time we used them. + + +```python +tokenizer = keras_hub.tokenizers.Tokenizer.from_preset( + "gemma2_instruct_2b_en", +) +backbone = keras_hub.models.Backbone.from_preset( + "gemma2_instruct_2b_en", +) +``` + +We saw what the tokenizer does in the second example of this guide. We can use it to map +from string inputs to token ids in a way that matches the pretrained weights of the Gemma +model. + +The backbone will map from a sequence of token ids to a sequence of embedded tokens in +the latent space of the model. We can use this rich representation to build a classifier. + +Let's start by defining a custom preprocessing routine. `keras_hub.layers` contains a +collection of modeling and preprocessing layers, including some layers for token +preprocessing. We can use `keras_hub.layers.StartEndPacker`, which will prepend a special +start token to each review, append a special end token, and finally +truncate or pad each review to a fixed length. + +If we combine this with our `tokenizer`, we can build a preprocessing function that will +output batches of token ids with shape `(batch_size, sequence_length)`. We should also +output a padding mask that marks which tokens are padding tokens, so we can later exclude +these positions from our Transformer's attention computation. Most Transformer backbones +in KerasHub take in a `"padding_mask"` input.
```python packer = keras_hub.layers.StartEndPacker( - start_value=tokenizer.token_to_id("[START]"), - end_value=tokenizer.token_to_id("[END]"), - pad_value=tokenizer.token_to_id("[PAD]"), - sequence_length=512, + start_value=tokenizer.start_token_id, + end_value=tokenizer.end_token_id, + pad_value=tokenizer.pad_token_id, + sequence_length=None, ) -def preprocess(x, y): - token_ids = packer(tokenizer(x)) - return token_ids, y +def preprocess(x, y=None, sequence_length=256): + x = tokenizer(x) + x = packer(x, sequence_length=sequence_length) + x = { + "token_ids": x, + "padding_mask": x != tokenizer.pad_token_id, + } + return keras.utils.pack_x_y_sample_weight(x, y) + +``` +With our preprocessing defined, we can simply use `tf.data.Dataset.map` to apply our +preprocessing to our input data. -imdb_preproc_train_ds = imdb_train.map( - preprocess, num_parallel_calls=tf.data.AUTOTUNE -).prefetch(tf.data.AUTOTUNE) -imdb_preproc_val_ds = imdb_test.map( - preprocess, num_parallel_calls=tf.data.AUTOTUNE -).prefetch(tf.data.AUTOTUNE) -print(imdb_preproc_train_ds.unbatch().take(1).get_single_element()) +```python +train_ds = raw_train_ds.map(preprocess, num_parallel_calls=16) +val_ds = raw_val_ds.map(preprocess, num_parallel_calls=16) +next(iter(train_ds)) ``` + + +
``` -(, ) +({'token_ids': , + 'padding_mask': }, + ) ```
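As an optional check that is not part of the original guide, we can decode one preprocessed example back to text with the tokenizer, dropping the padding positions first via the padding mask. This assumes `np` (NumPy) is available, as it is elsewhere in this guide.

```python
# Optional round-trip check (not in the original guide): decode one example back to text.
x, y = next(iter(train_ds))
ids = np.asarray(x["token_ids"][0])
mask = np.asarray(x["padding_mask"][0])
# Keep only non-padding positions, then detokenize (the special start token is included).
print(tokenizer.detokenize(ids[mask]))
print("label:", int(y[0]))
```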
-### Design a tiny transformer +Running fine-tuning on a 2.6 billion parameter model is quite expensive compared to the +image classifier we trained earlier, for the simple reason that this model is 100x the +size of ResNet! To speed things up a bit, let's reduce the size of our training data to a +tenth of the original size. Of course, this is leaving some performance on the table +compared to full training, but it will keep things running quickly for our guide. ```python -token_id_input = keras.Input( - shape=(None,), - dtype="int32", - name="token_ids", -) -outputs = keras_hub.layers.TokenAndPositionEmbedding( - vocabulary_size=len(vocab), - sequence_length=packer.sequence_length, - embedding_dim=64, -)(token_id_input) -outputs = keras_hub.layers.TransformerEncoder( - num_heads=2, - intermediate_dim=128, - dropout=0.1, -)(outputs) -# Use "[START]" token to classify -outputs = keras.layers.Dense(2)(outputs[:, 0, :]) -model = keras.Model( - inputs=token_id_input, - outputs=outputs, -) +train_ds = train_ds.take(1000) +val_ds = val_ds.take(1000) +``` + +Next, we need to attach a classification head to our backbone model. In general, text +transformer backbones will output a tensor with shape +`(batch_size, sequence_length, hidden_dim)`. The main thing we need to do to +classify with this input is to pool over the sequence dimension so we have a single +feature vector per input example. + +Since the Gemma model is a generative model, information only passes from left to right +in the sequence. The only token representation that can "see" the entire movie review +input is the final token in each review. We can write a simple pooling layer to do this: +we will simply grab the output at the last non-padding position of each input sequence. There's no special +process for writing a layer like this; we can use Keras and `keras.ops` as normal. + + +```python +from keras import ops + -model.summary() +class LastTokenPooler(keras.layers.Layer): + def call(self, inputs, padding_mask): + end_positions = ops.sum(padding_mask, axis=1, keepdims=True) - 1 + end_positions = ops.cast(end_positions, "int")[:, :, None] + outputs = ops.take_along_axis(inputs, end_positions, axis=1) + return ops.squeeze(outputs, axis=1) + +``` + +With this pooling layer, we are ready to write our Gemma classifier. All task and backbone +models in KerasHub are [functional](https://keras.io/guides/functional_api/) models, so +we can easily manipulate the model structure. We will call our backbone on our inputs, add +our new pooling layer, and finally add a small feedforward network with a `"relu"` activation +in the middle. Let's try it out. + + +```python +inputs = backbone.input +x = backbone(inputs) +x = LastTokenPooler( + name="pooler", +)(x, inputs["padding_mask"]) +x = keras.layers.Dense( + 2048, + activation="relu", + name="pooled_dense", +)(x) +x = keras.layers.Dropout( + 0.1, + name="output_dropout", +)(x) +outputs = keras.layers.Dense( + 2, + activation="softmax", + name="output_dense", +)(x) +text_classifier = keras.Model(inputs, outputs) +text_classifier.summary() ```
Model: "functional_5"
+
Model: "functional"
 
-
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
-┃ Layer (type)                     Output Shape                  Param # ┃
-┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
-│ token_ids (InputLayer)          │ (None, None)              │          0 │
-├─────────────────────────────────┼───────────────────────────┼────────────┤
-│ token_and_position_embedding    │ (None, None, 64)          │  1,259,648 │
-│ (TokenAndPositionEmbedding)     │                           │            │
-├─────────────────────────────────┼───────────────────────────┼────────────┤
-│ transformer_encoder_2           │ (None, None, 64)          │     33,472 │
-│ (TransformerEncoder)            │                           │            │
-├─────────────────────────────────┼───────────────────────────┼────────────┤
-│ get_item_6 (GetItem)            │ (None, 64)                │          0 │
-├─────────────────────────────────┼───────────────────────────┼────────────┤
-│ dense_1 (Dense)                 │ (None, 2)                 │        130 │
-└─────────────────────────────────┴───────────────────────────┴────────────┘
+
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
+┃ Layer (type)         Output Shape          Param #  Connected to      ┃
+┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
+│ padding_mask        │ (None, None)      │          0 │ -                 │
+│ (InputLayer)        │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ token_ids           │ (None, None)      │          0 │ -                 │
+│ (InputLayer)        │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ gemma_backbone      │ (None, None,      │ 2,614,341… │ padding_mask[0][ │
+│ (GemmaBackbone)     │ 2304)             │            │ token_ids[0][0]   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ pooler              │ (None, 2304)      │          0 │ gemma_backbone[0… │
+│ (LastTokenPooler)   │                   │            │ padding_mask[0][ │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ pooled_dense        │ (None, 2048)      │  4,720,640 │ pooler[0][0]      │
+│ (Dense)             │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ output_dropout      │ (None, 2048)      │          0 │ pooled_dense[0][ │
+│ (Dropout)           │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ output_dense        │ (None, 2)         │      4,098 │ output_dropout[0… │
+│ (Dense)             │                   │            │                   │
+└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 
-
 Total params: 1,293,250 (4.93 MB)
+
 Total params: 2,619,066,626 (9.76 GB)
 
-
 Trainable params: 1,293,250 (4.93 MB)
+
 Trainable params: 2,619,066,626 (9.76 GB)
 
@@ -1030,36 +1152,129 @@ model.summary() -### Train the transformer directly on the classification objective +Before we train, there is one last trick we should employ to make this code run on +free-tier Colab GPUs. We can see from our model summary that our model takes up almost 10 gigabytes +of space. An optimizer will need to make multiple copies of each parameter during +training, bringing the total memory footprint during training closer to 30 or 40 +gigabytes. + +This would OOM many GPUs. A useful trick we can employ is to enable LoRA on our +backbone. LoRA is an approach that freezes the entire model and only trains a low-parameter +decomposition of the large weight matrices. You can read more about LoRA in this [Keras +example](https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/). +Let's try enabling it and re-printing our summary. ```python -model.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.AdamW(5e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, +backbone.enable_lora(4) +text_classifier.summary() +``` + +
Model: "functional"
+
+ + + + +
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
+┃ Layer (type)         Output Shape          Param #  Connected to      ┃
+┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
+│ padding_mask        │ (None, None)      │          0 │ -                 │
+│ (InputLayer)        │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ token_ids           │ (None, None)      │          0 │ -                 │
+│ (InputLayer)        │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ gemma_backbone      │ (None, None,      │ 2,617,270… │ padding_mask[0][ │
+│ (GemmaBackbone)     │ 2304)             │            │ token_ids[0][0]   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ pooler              │ (None, 2304)      │          0 │ gemma_backbone[0… │
+│ (LastTokenPooler)   │                   │            │ padding_mask[0][ │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ pooled_dense        │ (None, 2048)      │  4,720,640 │ pooler[0][0]      │
+│ (Dense)             │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ output_dropout      │ (None, 2048)      │          0 │ pooled_dense[0][ │
+│ (Dropout)           │                   │            │                   │
+├─────────────────────┼───────────────────┼────────────┼───────────────────┤
+│ output_dense        │ (None, 2)         │      4,098 │ output_dropout[0… │
+│ (Dense)             │                   │            │                   │
+└─────────────────────┴───────────────────┴────────────┴───────────────────┘
+
+ + + + +
 Total params: 2,621,995,266 (9.77 GB)
+
+ + + + +
 Trainable params: 7,653,378 (29.20 MB)
+
+ + + + +
 Non-trainable params: 2,614,341,888 (9.74 GB)
+
+ + + +After enabling LoRA, our model goes from nearly 10GB of trainable parameters to just under 30MB. +That means the space used by optimizer variables will no longer be a concern. + +With all that set up, we can compile and train our model as normal. + + +```python +text_classifier.compile( + optimizer=keras.optimizers.Adam(5e-5), + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], ) -model.fit( - imdb_preproc_train_ds, - validation_data=imdb_preproc_val_ds, - epochs=3, +text_classifier.fit( + train_ds, + validation_data=val_ds, ) ```
``` -Epoch 1/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - loss: 0.7790 - sparse_categorical_accuracy: 0.5367 - val_loss: 0.4420 - val_sparse_categorical_accuracy: 0.8120 -Epoch 2/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.3654 - sparse_categorical_accuracy: 0.8443 - val_loss: 0.3046 - val_sparse_categorical_accuracy: 0.8752 -Epoch 3/3 - 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - loss: 0.2471 - sparse_categorical_accuracy: 0.9019 - val_loss: 0.3060 - val_sparse_categorical_accuracy: 0.8748 - - + 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 295s 285ms/step - accuracy: 0.7733 - loss: 0.6511 - val_accuracy: 0.9370 - val_loss: 0.2814 + ```
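To sanity-check the fine-tuned classifier, we can run a quick prediction on a couple of made-up reviews. This is a small sketch that is not part of the original guide; it reuses the `preprocess` function defined above and assumes label 0 is "negative" and label 1 is "positive", matching the alphabetical `neg`/`pos` folder order used by `text_dataset_from_directory`.

```python
# Sketch (not in the original guide): classify a couple of hand-written reviews.
# Assumes class 0 = "neg" and class 1 = "pos", per the alphabetical folder order.
sample_reviews = [
    "This movie was an absolute delight from start to finish!",
    "A tedious, poorly acted mess. I want those two hours back.",
]
# With y=None, `preprocess` returns just the {"token_ids", "padding_mask"} dict.
x = preprocess(sample_reviews)
print(text_classifier.predict(x))  # one [negative, positive] softmax pair per review
```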
-Excitingly, our custom classifier is similar to the performance of fine-tuning -`"bert_tiny_en_uncased"`! To see the advantages of pretraining and exceed 90% accuracy we -would need to use larger **presets** such as `"bert_base_en_uncased"`. +We are able to achieve over 93% accuracy on the movie review sentiment +classification problem. This is not bad, given that we only used a tenth of our +original dataset to train. + +Taken together, the `backbone` and `tokenizer` we created in this example +allowed us to access the full power of pretrained Gemma checkpoints, without +restricting what we could do with them. This is a central aim of the KerasHub +API. Simple workflows should be easy, and as you go deeper, you gain access to a +deeply customizable set of building blocks. + +--- +## Going further + +This is just scratching the surface of what you can do with KerasHub. + +This guide shows a few of the high-level tasks that we ship with the KerasHub library, +but there are many tasks we did not cover here. Try [generating images with Stable +Diffusion](https://keras.io/guides/keras_hub/stable_diffusion_3_in_keras_hub/), for +example. + +The most significant advantage of KerasHub is that it gives you the flexibility to combine pre-trained +building blocks with the full power of Keras 3. You can train large LLMs on TPUs with model +parallelism with the [keras.distribution](https://keras.io/guides/distribution/) API. You can +quantize models with Keras' [quantize +method](https://keras.io/examples/keras_recipes/float8_training_and_inference_with_transformer/). +You can write custom training loops and even mix in direct JAX, Torch, or +TensorFlow calls. + +See [keras.io/keras_hub](https://keras.io/keras_hub/) for a full list of guides and +examples to continue digging into the library.