diff --git a/source/conf.py b/source/conf.py index 0fff7535..06cb1400 100644 --- a/source/conf.py +++ b/source/conf.py @@ -72,7 +72,7 @@ "rapids_admonitions", ] -myst_enable_extensions = ["colon_fence"] +myst_enable_extensions = ["colon_fence", "dollarmath"] # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] diff --git a/source/examples/index.md b/source/examples/index.md index b89d250f..450b3955 100644 --- a/source/examples/index.md +++ b/source/examples/index.md @@ -14,5 +14,6 @@ rapids-ec2-mnmg/notebook rapids-autoscaling-multi-tenant-kubernetes/notebook xgboost-randomforest-gpu-hpo-dask/notebook rapids-azureml-hpo/notebook +time-series-forecasting-with-hpo/notebook xgboost-rf-gpu-cpu-benchmark/notebook ``` diff --git a/source/examples/time-series-forecasting-with-hpo/notebook.ipynb b/source/examples/time-series-forecasting-with-hpo/notebook.ipynb new file mode 100644 index 00000000..129de3f9 --- /dev/null +++ b/source/examples/time-series-forecasting-with-hpo/notebook.ipynb @@ -0,0 +1,8032 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "671dd603-6b51-46b2-98b3-2b05c7c92c38", + "metadata": { + "tags": [ + "platform/kubernetes", + "cloud/gcp/gke", + "tools/dask-operator", + "workflow/hpo", + "workflow/xgboost", + "library/dask", + "library/dask-cuda", + "library/xgboost", + "library/optuna", + "data-storage/gcs" + ] + }, + "source": [ + "# Perform time series forecasting on Google Kubernetes Engine with NVIDIA GPUs" + ] + }, + { + "cell_type": "markdown", + "id": "e18814f8-5374-468f-9877-28275dbf20d6", + "metadata": {}, + "source": [ + "In this example, we will be looking at a real-world example of **time series forecasting** with data from [the M5 Forecasting Competition](https://www.kaggle.com/competitions/m5-forecasting-accuracy). Walmart provides historical sales data from multiple stores in three states, and our job is to predict the sales in a future 28-day period." + ] + }, + { + "cell_type": "markdown", + "id": "0af8dc16-70a9-41cc-b846-88e83a2aa660", + "metadata": {}, + "source": [ + "## Prerequisites" + ] + }, + { + "cell_type": "markdown", + "id": "9168f224-8549-41a7-a6a7-2023c83bb466", + "metadata": {}, + "source": [ + "### Prepare GKE cluster" + ] + }, + { + "cell_type": "markdown", + "id": "fde94ab8-123c-4a72-95db-86a2098247bb", + "metadata": {}, + "source": [ + "To run the example, you will need a working Google Kubernetes Engine (GKE) cluster with access to NVIDIA GPUs.\n", + "\n", + "````{docref} /cloud/gcp/gke\n", + "Set up a Google Kubernetes Engine (GKE) cluster with access to NVIDIA GPUs. Follow instructions in [Google Kubernetes Engine](../../cloud/gcp/gke).\n", + "````\n", + "\n", + "1. To ensure that the example runs smoothly, ensure that you have ample memory in your GPUs. This notebook has been tested with NVIDIA A100.\n", + "\n", + "2. Set up Dask-Kubernetes integration by following instructions in the following guides:\n", + "\n", + " * [Install the Dask-Kubernetes operator](https://kubernetes.dask.org/en/latest/operator_installation.html)\n", + " * [Install Kubeflow](https://www.kubeflow.org/docs/started/installing-kubeflow/)\n", + "\n", + "Kubeflow is not strictly necessary, but we highly recommend it, as Kubeflow gives you a nice notebook environment to run this notebook within the k8s cluster. (You may choose any method; we tested this example after installing Kubeflow from manifests.) When creating the notebook environment, use the following configuration:\n", + "\n", + "* 2 CPUs, 16 GiB of memory\n", + "* 1 NVIDIA GPU\n", + "* 40 GiB disk volume\n", + "\n", + "After uploading all the notebooks in the example, run this notebook (`notebook.ipynb`) in the notebook environment.\n", + "\n", + "Note: We will use the worker pods to speed up the training stage. The preprocessing steps will run solely on the scheduler node." + ] + }, + { + "cell_type": "markdown", + "id": "79b63586-119e-47ee-ba74-063cfca71fe0", + "metadata": {}, + "source": [ + "### Prepare a bucket in Google Cloud Storage" + ] + }, + { + "cell_type": "markdown", + "id": "f1c4ada1-60c5-4456-b82d-953ce499a2fa", + "metadata": {}, + "source": [ + "Create a new bucket in Google Cloud Storage. Make sure that the worker pods in the k8s cluster has read/write access to this bucket. This can be done in one of the following methods:\n", + "\n", + "1. Option 1: Specify an additional scope when provisioning the GKE cluster.\n", + "\n", + " When you are provisioning a new GKE cluster, add the `storage-rw` scope.\n", + " This option is only available if you are creating a new cluster from scratch. If you are using an exising GKE cluster, see Option 2.\n", + "\n", + " Example:\n", + "```\n", + "gcloud container clusters create my_new_cluster --accelerator type=nvidia-tesla-t4 \\\n", + " --machine-type n1-standard-32 --zone us-central1-c --release-channel stable \\\n", + " --num-nodes 5 --scopes=gke-default,storage-rw\n", + "```\n", + "\n", + "2. Option 2: Grant bucket access to the associated service account.\n", + "\n", + " Find out which service account is associated with your GKE cluster. You can grant the bucket access to the service account as follows: Nagivate to the Cloud Storage console, open the Bucket Details page for the bucket, open the Permissions tab, and click on Grant Access.\n", + " \n", + "Enter the name of the bucket that your cluster has read-write access to:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "45bc19b6-f5f9-4f55-827b-db9484c125ce", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "bucket_name = \"\"" + ] + }, + { + "cell_type": "markdown", + "id": "704898b3-e2b0-4b40-bcd3-178a0bdeee6f", + "metadata": {}, + "source": [ + "### Install Python packages in the notebook environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cdf7b111-3aba-4fae-b805-fb3063d5a621", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install kaggle gcsfs dask-kubernetes optuna" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "13737fe2-3df5-4614-8675-fb20bccf7a19", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Test if the bucket is accessible\n", + "import gcsfs\n", + "\n", + "fs = gcsfs.GCSFileSystem()\n", + "fs.ls(f\"{bucket_name}/\")" + ] + }, + { + "cell_type": "markdown", + "id": "42f2274c-90a0-484f-ad71-33ddba170f09", + "metadata": {}, + "source": [ + "## Obtain the time series data set from Kaggle" + ] + }, + { + "cell_type": "markdown", + "id": "21cde5ee-dad0-4ee2-9c26-7a36b6bcee49", + "metadata": {}, + "source": [ + "If you do not yet have an account with Kaggle, create one now. Then follow instructions in [Public API Documentation of Kaggle](https://www.kaggle.com/docs/api) to obtain the API key. This step is needed to obtain the training data from the M5 Forecasting Competition. Once you obtained the API key, fill in the following:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8c2eb62a-a0f5-4801-9254-367fcef05e05", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "kaggle_username = \"\"\n", + "kaggle_api_key = \"\"" + ] + }, + { + "cell_type": "markdown", + "id": "64fd487f-0de1-4c41-ba02-7189493739f3", + "metadata": {}, + "source": [ + "Now we are ready to download the data set:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c6fd3fe-b2d3-4bff-88f5-b86de68878a6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%env KAGGLE_USERNAME=$kaggle_username\n", + "%env KAGGLE_KEY=$kaggle_api_key\n", + "\n", + "!kaggle competitions download -c m5-forecasting-accuracy" + ] + }, + { + "cell_type": "markdown", + "id": "b7141384-e825-4387-a9f5-42b932e27e65", + "metadata": {}, + "source": [ + "Let's unzip the ZIP archive and see what's inside." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b7bf2248-cd3b-47c5-8923-ec9a6e868a49", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import zipfile\n", + "\n", + "with zipfile.ZipFile(\"m5-forecasting-accuracy.zip\", \"r\") as zf:\n", + " zf.extractall(path=\"./data\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "673468ae-9f6d-499f-86c4-230021bbf1b0", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-rw-r--r-- 1 rapids conda 102K Sep 28 18:59 data/calendar.csv\n", + "-rw-r--r-- 1 rapids conda 117M Sep 28 18:59 data/sales_train_evaluation.csv\n", + "-rw-r--r-- 1 rapids conda 115M Sep 28 18:59 data/sales_train_validation.csv\n", + "-rw-r--r-- 1 rapids conda 5.0M Sep 28 18:59 data/sample_submission.csv\n", + "-rw-r--r-- 1 rapids conda 194M Sep 28 18:59 data/sell_prices.csv\n" + ] + } + ], + "source": [ + "!ls -lh data/*.csv" + ] + }, + { + "cell_type": "markdown", + "id": "f304ea68-381f-45b4-9e27-201a35e31239", + "metadata": {}, + "source": [ + "## Data preprocessing" + ] + }, + { + "cell_type": "markdown", + "id": "d9903e47-2a83-40d1-b65b-d5818e9f0647", + "metadata": {}, + "source": [ + "We are now ready to run the preprocessing steps." + ] + }, + { + "cell_type": "markdown", + "id": "be6740ef-538d-4624-aa00-95d0029ee90b", + "metadata": {}, + "source": [ + "### Import modules and define utility functions" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2de6774b-9641-437e-8a03-f578c977a6ac", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import cudf\n", + "import numpy as np\n", + "import gc\n", + "import pathlib\n", + "import gcsfs\n", + "\n", + "\n", + "def sizeof_fmt(num, suffix=\"B\"):\n", + " for unit in [\"\", \"Ki\", \"Mi\", \"Gi\", \"Ti\", \"Pi\", \"Ei\", \"Zi\"]:\n", + " if abs(num) < 1024.0:\n", + " return f\"{num:3.1f}{unit}{suffix}\"\n", + " num /= 1024.0\n", + " return \"%.1f%s%s\" % (num, \"Yi\", suffix)\n", + "\n", + "\n", + "def report_dataframe_size(df, name):\n", + " print(\n", + " \"{} takes up {} memory on GPU\".format(\n", + " name, sizeof_fmt(grid_df.memory_usage(index=True).sum())\n", + " )\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "f87688b3-49e9-4d2b-a18c-3d4185c96b49", + "metadata": {}, + "source": [ + "### Load Data" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a0724160-5394-4976-b5e4-2cb5bd672d59", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "TARGET = \"sales\" # Our main target\n", + "END_TRAIN = 1941 # Last day in train set" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "aebe04ff-8a3d-451d-a9ca-d7e7d7ecd739", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "raw_data_dir = pathlib.Path(\"./data/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "914cf6bd-62b9-4bbc-aad2-1029e701d0cd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "train_df = cudf.read_csv(raw_data_dir / \"sales_train_evaluation.csv\")\n", + "prices_df = cudf.read_csv(raw_data_dir / \"sell_prices.csv\")\n", + "calendar_df = cudf.read_csv(raw_data_dir / \"calendar.csv\").rename(\n", + " columns={\"d\": \"day_id\"}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "4d9f912b-31e1-44a7-8613-92881e9e88f9", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idd_1d_2d_3d_4...d_1932d_1933d_1934d_1935d_1936d_1937d_1938d_1939d_1940d_1941
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CA0000...2400003301
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CA0000...0121100000
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CA0000...1020002301
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CA0000...1104013026
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CA0000...0002100210
..................................................................
30485FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WI0022...1030110011
30486FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WI0000...0000001010
30487FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WI0602...0012010102
30488FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WI0000...1114601110
30489FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WI0000...1205402251
\n", + "

30490 rows × 1947 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id cat_id \\\n", + "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", + "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", + "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", + "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", + "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", + "... ... ... ... ... \n", + "30485 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", + "30486 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", + "30487 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", + "30488 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", + "30489 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", + "\n", + " store_id state_id d_1 d_2 d_3 d_4 ... d_1932 d_1933 d_1934 \\\n", + "0 CA_1 CA 0 0 0 0 ... 2 4 0 \n", + "1 CA_1 CA 0 0 0 0 ... 0 1 2 \n", + "2 CA_1 CA 0 0 0 0 ... 1 0 2 \n", + "3 CA_1 CA 0 0 0 0 ... 1 1 0 \n", + "4 CA_1 CA 0 0 0 0 ... 0 0 0 \n", + "... ... ... ... ... ... ... ... ... ... ... \n", + "30485 WI_3 WI 0 0 2 2 ... 1 0 3 \n", + "30486 WI_3 WI 0 0 0 0 ... 0 0 0 \n", + "30487 WI_3 WI 0 6 0 2 ... 0 0 1 \n", + "30488 WI_3 WI 0 0 0 0 ... 1 1 1 \n", + "30489 WI_3 WI 0 0 0 0 ... 1 2 0 \n", + "\n", + " d_1935 d_1936 d_1937 d_1938 d_1939 d_1940 d_1941 \n", + "0 0 0 0 3 3 0 1 \n", + "1 1 1 0 0 0 0 0 \n", + "2 0 0 0 2 3 0 1 \n", + "3 4 0 1 3 0 2 6 \n", + "4 2 1 0 0 2 1 0 \n", + "... ... ... ... ... ... ... ... \n", + "30485 0 1 1 0 0 1 1 \n", + "30486 0 0 0 1 0 1 0 \n", + "30487 2 0 1 0 1 0 2 \n", + "30488 4 6 0 1 1 1 0 \n", + "30489 5 4 0 2 2 5 1 \n", + "\n", + "[30490 rows x 1947 columns]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_df" + ] + }, + { + "cell_type": "markdown", + "id": "7f834bdb-0c99-4d9f-b207-ac1f423756c9", + "metadata": {}, + "source": [ + "The columns `d_1`, `d_2`, ..., `d_1941` indicate the sales data at days 1, 2, ..., 1941 from 2011-01-29." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "2af6b2d9-14f9-42e4-81f9-1ad2fe3b92fe", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idwm_yr_wksell_price
0CA_1HOBBIES_1_001113259.58
1CA_1HOBBIES_1_001113269.58
2CA_1HOBBIES_1_001113278.26
3CA_1HOBBIES_1_001113288.26
4CA_1HOBBIES_1_001113298.26
...............
6841116WI_3FOODS_3_827116171.00
6841117WI_3FOODS_3_827116181.00
6841118WI_3FOODS_3_827116191.00
6841119WI_3FOODS_3_827116201.00
6841120WI_3FOODS_3_827116211.00
\n", + "

6841121 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id wm_yr_wk sell_price\n", + "0 CA_1 HOBBIES_1_001 11325 9.58\n", + "1 CA_1 HOBBIES_1_001 11326 9.58\n", + "2 CA_1 HOBBIES_1_001 11327 8.26\n", + "3 CA_1 HOBBIES_1_001 11328 8.26\n", + "4 CA_1 HOBBIES_1_001 11329 8.26\n", + "... ... ... ... ...\n", + "6841116 WI_3 FOODS_3_827 11617 1.00\n", + "6841117 WI_3 FOODS_3_827 11618 1.00\n", + "6841118 WI_3 FOODS_3_827 11619 1.00\n", + "6841119 WI_3 FOODS_3_827 11620 1.00\n", + "6841120 WI_3 FOODS_3_827 11621 1.00\n", + "\n", + "[6841121 rows x 4 columns]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "34eb6624-3960-4b80-a92a-82c26ea4d974", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
datewm_yr_wkweekdaywdaymonthyearday_idevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WI
02011-01-2911101Saturday112011d_1<NA><NA><NA><NA>000
12011-01-3011101Sunday212011d_2<NA><NA><NA><NA>000
22011-01-3111101Monday312011d_3<NA><NA><NA><NA>000
32011-02-0111101Tuesday422011d_4<NA><NA><NA><NA>110
42011-02-0211101Wednesday522011d_5<NA><NA><NA><NA>101
.............................................
19642016-06-1511620Wednesday562016d_1965<NA><NA><NA><NA>011
19652016-06-1611620Thursday662016d_1966<NA><NA><NA><NA>000
19662016-06-1711620Friday762016d_1967<NA><NA><NA><NA>000
19672016-06-1811621Saturday162016d_1968<NA><NA><NA><NA>000
19682016-06-1911621Sunday262016d_1969NBAFinalsEndSportingFather's dayCultural000
\n", + "

1969 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " date wm_yr_wk weekday wday month year day_id \\\n", + "0 2011-01-29 11101 Saturday 1 1 2011 d_1 \n", + "1 2011-01-30 11101 Sunday 2 1 2011 d_2 \n", + "2 2011-01-31 11101 Monday 3 1 2011 d_3 \n", + "3 2011-02-01 11101 Tuesday 4 2 2011 d_4 \n", + "4 2011-02-02 11101 Wednesday 5 2 2011 d_5 \n", + "... ... ... ... ... ... ... ... \n", + "1964 2016-06-15 11620 Wednesday 5 6 2016 d_1965 \n", + "1965 2016-06-16 11620 Thursday 6 6 2016 d_1966 \n", + "1966 2016-06-17 11620 Friday 7 6 2016 d_1967 \n", + "1967 2016-06-18 11621 Saturday 1 6 2016 d_1968 \n", + "1968 2016-06-19 11621 Sunday 2 6 2016 d_1969 \n", + "\n", + " event_name_1 event_type_1 event_name_2 event_type_2 snap_CA snap_TX \\\n", + "0 0 0 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 1 1 \n", + "4 1 0 \n", + "... ... ... ... ... ... ... \n", + "1964 0 1 \n", + "1965 0 0 \n", + "1966 0 0 \n", + "1967 0 0 \n", + "1968 NBAFinalsEnd Sporting Father's day Cultural 0 0 \n", + "\n", + " snap_WI \n", + "0 0 \n", + "1 0 \n", + "2 0 \n", + "3 0 \n", + "4 1 \n", + "... ... \n", + "1964 1 \n", + "1965 0 \n", + "1966 0 \n", + "1967 0 \n", + "1968 0 \n", + "\n", + "[1969 rows x 14 columns]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "calendar_df" + ] + }, + { + "cell_type": "markdown", + "id": "c957eb3a-93a5-4800-bc9d-dfd3722c3ccd", + "metadata": {}, + "source": [ + "### Reformat sales times series data" + ] + }, + { + "cell_type": "markdown", + "id": "0db2e815-7de8-4496-b785-0f7d046f6fc1", + "metadata": {}, + "source": [ + "Pivot the columns `d_1`, `d_2`, ..., `d_1941` into separate rows using `cudf.melt`." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "5113af62-bea8-469c-bb3e-3bfc2e8805ee", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsales
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CAd_10
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CAd_10
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CAd_10
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CAd_10
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CAd_10
...........................
59181085FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WId_19411
59181086FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WId_19410
59181087FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WId_19412
59181088FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WId_19410
59181089FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WId_19411
\n", + "

59181090 rows × 8 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id cat_id \\\n", + "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", + "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", + "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", + "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", + "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", + "... ... ... ... ... \n", + "59181085 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", + "59181086 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", + "59181087 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", + "59181088 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", + "59181089 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", + "\n", + " store_id state_id day_id sales \n", + "0 CA_1 CA d_1 0 \n", + "1 CA_1 CA d_1 0 \n", + "2 CA_1 CA d_1 0 \n", + "3 CA_1 CA d_1 0 \n", + "4 CA_1 CA d_1 0 \n", + "... ... ... ... ... \n", + "59181085 WI_3 WI d_1941 1 \n", + "59181086 WI_3 WI d_1941 0 \n", + "59181087 WI_3 WI d_1941 2 \n", + "59181088 WI_3 WI d_1941 0 \n", + "59181089 WI_3 WI d_1941 1 \n", + "\n", + "[59181090 rows x 8 columns]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index_columns = [\"id\", \"item_id\", \"dept_id\", \"cat_id\", \"store_id\", \"state_id\"]\n", + "grid_df = cudf.melt(\n", + " train_df, id_vars=index_columns, var_name=\"day_id\", value_name=TARGET\n", + ")\n", + "grid_df" + ] + }, + { + "cell_type": "markdown", + "id": "6b0fc482-3d30-4d8a-be62-caf7a3e3b9a1", + "metadata": {}, + "source": [ + "For each time series, add 28 rows that corresponds to the future forecast horizon:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e7c95c1e-1f33-4a53-b390-f37cffd398d1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsales
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CAd_10.0
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CAd_10.0
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CAd_10.0
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CAd_10.0
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CAd_10.0
...........................
60034805FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WId_1969NaN
60034806FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WId_1969NaN
60034807FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WId_1969NaN
60034808FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WId_1969NaN
60034809FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WId_1969NaN
\n", + "

60034810 rows × 8 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id cat_id \\\n", + "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", + "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", + "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", + "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", + "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", + "... ... ... ... ... \n", + "60034805 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", + "60034806 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", + "60034807 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", + "60034808 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", + "60034809 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", + "\n", + " store_id state_id day_id sales \n", + "0 CA_1 CA d_1 0.0 \n", + "1 CA_1 CA d_1 0.0 \n", + "2 CA_1 CA d_1 0.0 \n", + "3 CA_1 CA d_1 0.0 \n", + "4 CA_1 CA d_1 0.0 \n", + "... ... ... ... ... \n", + "60034805 WI_3 WI d_1969 NaN \n", + "60034806 WI_3 WI d_1969 NaN \n", + "60034807 WI_3 WI d_1969 NaN \n", + "60034808 WI_3 WI d_1969 NaN \n", + "60034809 WI_3 WI d_1969 NaN \n", + "\n", + "[60034810 rows x 8 columns]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "add_grid = cudf.DataFrame()\n", + "for i in range(1, 29):\n", + " temp_df = train_df[index_columns]\n", + " temp_df = temp_df.drop_duplicates()\n", + " temp_df[\"day_id\"] = \"d_\" + str(END_TRAIN + i)\n", + " temp_df[TARGET] = np.nan # Sales amount at time (n + i) is unknown\n", + " add_grid = cudf.concat([add_grid, temp_df])\n", + "add_grid[\"day_id\"] = add_grid[\"day_id\"].astype(\n", + " \"category\"\n", + ") # The day_id column is categorical, after cudf.melt\n", + "\n", + "grid_df = cudf.concat([grid_df, add_grid])\n", + "grid_df = grid_df.reset_index(drop=True)\n", + "grid_df[\"sales\"] = grid_df[\"sales\"].astype(\n", + " np.float32\n", + ") # Use float32 type for sales column, to conserve memory\n", + "grid_df" + ] + }, + { + "cell_type": "markdown", + "id": "e250945a-e251-4b79-9fa8-25fffb96815b", + "metadata": {}, + "source": [ + "### Free up GPU memory\n", + "\n", + "GPU memory is a precious resource, so let's try to free up some memory. First, delete temporary variables we no longer need:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "bae6a99d-1674-4937-aab4-7098eec87dff", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "8136" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Use xdel magic to scrub extra references from Jupyter notebook\n", + "%xdel temp_df\n", + "%xdel add_grid\n", + "%xdel train_df\n", + "\n", + "# Invoke the garbage collector explicitly to free up memory\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "b126fa0f-18fe-4ba2-9dad-ee278851f781", + "metadata": {}, + "source": [ + "Second, let's reduce the footprint of `grid_df` by converting strings into categoricals:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "5906f5c1-54ba-47fc-977a-e6516076e333", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 5.2GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "0c6db032-8798-4cf0-a9db-d483c836ead5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "id object\n", + "item_id object\n", + "dept_id object\n", + "cat_id object\n", + "store_id object\n", + "state_id object\n", + "day_id category\n", + "sales float32\n", + "dtype: object" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df.dtypes" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "82e650e3-c001-4c33-b75d-77897c7c9a2c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 802.6MiB memory on GPU\n" + ] + } + ], + "source": [ + "for col in index_columns:\n", + " grid_df[col] = grid_df[col].astype(\"category\")\n", + " gc.collect()\n", + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "1f4f4421-75fa-471a-aa02-8521d2497b02", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "id category\n", + "item_id category\n", + "dept_id category\n", + "cat_id category\n", + "store_id category\n", + "state_id category\n", + "day_id category\n", + "sales float32\n", + "dtype: object" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df.dtypes" + ] + }, + { + "cell_type": "markdown", + "id": "b62205cf-3e18-4f4c-ae98-e818e510e931", + "metadata": {}, + "source": [ + "### Identify the release week of each product\n", + "\n", + "Each row in the `prices_df` table contains the price of a product sold at a store for a given week." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "774b69c6-b3f5-44ee-b769-df78397c4b37", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idwm_yr_wksell_price
0CA_1HOBBIES_1_001113259.58
1CA_1HOBBIES_1_001113269.58
2CA_1HOBBIES_1_001113278.26
3CA_1HOBBIES_1_001113288.26
4CA_1HOBBIES_1_001113298.26
...............
6841116WI_3FOODS_3_827116171.00
6841117WI_3FOODS_3_827116181.00
6841118WI_3FOODS_3_827116191.00
6841119WI_3FOODS_3_827116201.00
6841120WI_3FOODS_3_827116211.00
\n", + "

6841121 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id wm_yr_wk sell_price\n", + "0 CA_1 HOBBIES_1_001 11325 9.58\n", + "1 CA_1 HOBBIES_1_001 11326 9.58\n", + "2 CA_1 HOBBIES_1_001 11327 8.26\n", + "3 CA_1 HOBBIES_1_001 11328 8.26\n", + "4 CA_1 HOBBIES_1_001 11329 8.26\n", + "... ... ... ... ...\n", + "6841116 WI_3 FOODS_3_827 11617 1.00\n", + "6841117 WI_3 FOODS_3_827 11618 1.00\n", + "6841118 WI_3 FOODS_3_827 11619 1.00\n", + "6841119 WI_3 FOODS_3_827 11620 1.00\n", + "6841120 WI_3 FOODS_3_827 11621 1.00\n", + "\n", + "[6841121 rows x 4 columns]" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df" + ] + }, + { + "cell_type": "markdown", + "id": "47df7d5b-5697-4fa9-8ab8-b68ac3264259", + "metadata": {}, + "source": [ + "Notice that not all products were sold over every week. Some products were sold only during some weeks. Let's use the groupby operation to identify the first week in which each product went on the shelf." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "c5d1d2e6-6f8c-4150-9c11-e5519c41a461", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idrelease_week
0CA_4FOODS_3_52911421
1TX_1HOUSEHOLD_1_40911230
2WI_2FOODS_3_14511214
3CA_4HOUSEHOLD_1_49411106
4WI_3HOBBIES_1_09311223
............
30485CA_3HOUSEHOLD_1_36911205
30486CA_2FOODS_3_10911101
30487CA_4FOODS_2_11911101
30488CA_4HOUSEHOLD_2_38411110
30489WI_3HOBBIES_1_13511328
\n", + "

30490 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id release_week\n", + "0 CA_4 FOODS_3_529 11421\n", + "1 TX_1 HOUSEHOLD_1_409 11230\n", + "2 WI_2 FOODS_3_145 11214\n", + "3 CA_4 HOUSEHOLD_1_494 11106\n", + "4 WI_3 HOBBIES_1_093 11223\n", + "... ... ... ...\n", + "30485 CA_3 HOUSEHOLD_1_369 11205\n", + "30486 CA_2 FOODS_3_109 11101\n", + "30487 CA_4 FOODS_2_119 11101\n", + "30488 CA_4 HOUSEHOLD_2_384 11110\n", + "30489 WI_3 HOBBIES_1_135 11328\n", + "\n", + "[30490 rows x 3 columns]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "release_df = (\n", + " prices_df.groupby([\"store_id\", \"item_id\"])[\"wm_yr_wk\"].agg(\"min\").reset_index()\n", + ")\n", + "release_df.columns = [\"store_id\", \"item_id\", \"release_week\"]\n", + "release_df" + ] + }, + { + "cell_type": "markdown", + "id": "f4c4a98b-20fe-494a-8b42-b292c8c35f53", + "metadata": {}, + "source": [ + "Now that we've computed the release week for each product, let's merge it back to `grid_df`:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "96439c30-64fe-4f8e-a872-0f2f1982a48f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_week
0FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_13.011101
1FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_20.011101
2FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_30.011101
3FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_41.011101
4FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_54.011101
..............................
60034805HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1965NaN11101
60034806HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1966NaN11101
60034807HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1967NaN11101
60034808HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1968NaN11101
60034809HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1969NaN11101
\n", + "

60034810 rows × 9 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "0 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "1 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "2 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "3 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "4 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "60034805 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034806 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034807 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034808 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034809 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week \n", + "0 FOODS CA_1 CA d_1 3.0 11101 \n", + "1 FOODS CA_1 CA d_2 0.0 11101 \n", + "2 FOODS CA_1 CA d_3 0.0 11101 \n", + "3 FOODS CA_1 CA d_4 1.0 11101 \n", + "4 FOODS CA_1 CA d_5 4.0 11101 \n", + "... ... ... ... ... ... ... \n", + "60034805 HOUSEHOLD WI_3 WI d_1965 NaN 11101 \n", + "60034806 HOUSEHOLD WI_3 WI d_1966 NaN 11101 \n", + "60034807 HOUSEHOLD WI_3 WI d_1967 NaN 11101 \n", + "60034808 HOUSEHOLD WI_3 WI d_1968 NaN 11101 \n", + "60034809 HOUSEHOLD WI_3 WI d_1969 NaN 11101 \n", + "\n", + "[60034810 rows x 9 columns]" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = grid_df.merge(release_df, on=[\"store_id\", \"item_id\"], how=\"left\")\n", + "grid_df = grid_df.sort_values(index_columns + [\"day_id\"]).reset_index(drop=True)\n", + "grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "18226d36-4e8c-4b49-ac99-b126790a21a4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "139" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "del release_df # No longer needed\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "5f1a3c6e-9411-42bc-8e2c-73f9abb7c09b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 1.2GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "markdown", + "id": "cbfca2cc-c90f-4852-b9f0-d3882603a981", + "metadata": {}, + "source": [ + "### Filter out entries with zero sales\n", + "\n", + "We can further save space by dropping rows from `grid_df` that correspond to zero sales. Since each product doesn't go on the shelf until its release week, its sale must be zero during any week that's prior to the release week.\n", + "\n", + "To make use of this insight, we bring in the `wm_yr_wk` column from `calendar_df`:" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "acdf40b5-395d-4dc2-a620-8cdd9eec70cf", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
0FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8090.01110111312
1FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8100.01110111312
2FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8112.01110111312
3FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8120.01110111312
4FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8131.01110111313
.................................
60034805HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_520.01110111108
60034806HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_530.01110111108
60034807HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_540.01110111108
60034808HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_550.01110111108
60034809HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_490.01110111107
\n", + "

60034810 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "0 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "1 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "2 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "3 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "4 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "60034805 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034806 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034807 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034808 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034809 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", + "0 FOODS WI_2 WI d_809 0.0 11101 11312 \n", + "1 FOODS WI_2 WI d_810 0.0 11101 11312 \n", + "2 FOODS WI_2 WI d_811 2.0 11101 11312 \n", + "3 FOODS WI_2 WI d_812 0.0 11101 11312 \n", + "4 FOODS WI_2 WI d_813 1.0 11101 11313 \n", + "... ... ... ... ... ... ... ... \n", + "60034805 HOUSEHOLD WI_3 WI d_52 0.0 11101 11108 \n", + "60034806 HOUSEHOLD WI_3 WI d_53 0.0 11101 11108 \n", + "60034807 HOUSEHOLD WI_3 WI d_54 0.0 11101 11108 \n", + "60034808 HOUSEHOLD WI_3 WI d_55 0.0 11101 11108 \n", + "60034809 HOUSEHOLD WI_3 WI d_49 0.0 11101 11107 \n", + "\n", + "[60034810 rows x 10 columns]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = grid_df.merge(calendar_df[[\"wm_yr_wk\", \"day_id\"]], on=[\"day_id\"], how=\"left\")\n", + "grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "43705184-96ac-420c-94b4-ac932558484a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 1.7GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "markdown", + "id": "98a935f3-d410-4f45-a745-8c875ea31a39", + "metadata": {}, + "source": [ + "The `wm_yr_wk` column identifies the week that contains the day given by the `day_id` column. Now let's filter all rows in `grid_df` for which `wm_yr_wk` is less than `release_week`:" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "d696b5dd-fbf8-4875-903a-db50b166cfe6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
6766FOODS_1_002_TX_1_evaluationFOODS_1_002FOODS_1FOODSTX_1TXd_10.01110211101
6767FOODS_1_002_TX_1_evaluationFOODS_1_002FOODS_1FOODSTX_1TXd_20.01110211101
19686FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_10.01110211101
19687FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_20.01110211101
19688FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_30.01110211101
.................................
60033493HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_200.01110611103
60033494HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_210.01110611103
60033495HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_220.01110611104
60033496HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_230.01110611104
60033497HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_240.01110611104
\n", + "

12299413 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "6766 FOODS_1_002_TX_1_evaluation FOODS_1_002 FOODS_1 \n", + "6767 FOODS_1_002_TX_1_evaluation FOODS_1_002 FOODS_1 \n", + "19686 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", + "19687 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", + "19688 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "60033493 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60033494 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60033495 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60033496 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60033497 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", + "6766 FOODS TX_1 TX d_1 0.0 11102 11101 \n", + "6767 FOODS TX_1 TX d_2 0.0 11102 11101 \n", + "19686 FOODS TX_3 TX d_1 0.0 11102 11101 \n", + "19687 FOODS TX_3 TX d_2 0.0 11102 11101 \n", + "19688 FOODS TX_3 TX d_3 0.0 11102 11101 \n", + "... ... ... ... ... ... ... ... \n", + "60033493 HOUSEHOLD WI_2 WI d_20 0.0 11106 11103 \n", + "60033494 HOUSEHOLD WI_2 WI d_21 0.0 11106 11103 \n", + "60033495 HOUSEHOLD WI_2 WI d_22 0.0 11106 11104 \n", + "60033496 HOUSEHOLD WI_2 WI d_23 0.0 11106 11104 \n", + "60033497 HOUSEHOLD WI_2 WI d_24 0.0 11106 11104 \n", + "\n", + "[12299413 rows x 10 columns]" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = grid_df[grid_df[\"wm_yr_wk\"] < grid_df[\"release_week\"]]\n", + "df" + ] + }, + { + "cell_type": "markdown", + "id": "b76d6141-1f3e-43ec-9f2e-ad7da2ab32bd", + "metadata": {}, + "source": [ + "As we suspected, the sales amount is zero during weeks that come before the release week." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "515404e3-ef4b-459f-a24e-00f3cf084151", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "assert (df[\"sales\"] == 0).all()" + ] + }, + { + "cell_type": "markdown", + "id": "42213544-9dd4-4045-88a5-215319c6ce5f", + "metadata": {}, + "source": [ + "For the purpose of our data analysis, we can safely drop the rows with zero sales:" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "2b697475-41c2-4a9c-bc00-c23085b71edc", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
0FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8090.01110111312
1FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8100.01110111312
2FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8112.01110111312
3FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8120.01110111312
4FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8131.01110111313
.................................
47735392HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_520.01110111108
47735393HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_530.01110111108
47735394HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_540.01110111108
47735395HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_550.01110111108
47735396HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_490.01110111107
\n", + "

47735397 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "0 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "1 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "2 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "3 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "4 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735393 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735394 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", + "0 FOODS WI_2 WI d_809 0.0 11101 11312 \n", + "1 FOODS WI_2 WI d_810 0.0 11101 11312 \n", + "2 FOODS WI_2 WI d_811 2.0 11101 11312 \n", + "3 FOODS WI_2 WI d_812 0.0 11101 11312 \n", + "4 FOODS WI_2 WI d_813 1.0 11101 11313 \n", + "... ... ... ... ... ... ... ... \n", + "47735392 HOUSEHOLD WI_3 WI d_52 0.0 11101 11108 \n", + "47735393 HOUSEHOLD WI_3 WI d_53 0.0 11101 11108 \n", + "47735394 HOUSEHOLD WI_3 WI d_54 0.0 11101 11108 \n", + "47735395 HOUSEHOLD WI_3 WI d_55 0.0 11101 11108 \n", + "47735396 HOUSEHOLD WI_3 WI d_49 0.0 11101 11107 \n", + "\n", + "[47735397 rows x 10 columns]" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = grid_df[grid_df[\"wm_yr_wk\"] >= grid_df[\"release_week\"]].reset_index(drop=True)\n", + "grid_df[\"wm_yr_wk\"] = grid_df[\"wm_yr_wk\"].astype(\n", + " np.int32\n", + ") # Convert wm_yr_wk column to int32, to conserve memory\n", + "grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "e7db4f50-d61b-4a44-8e7e-fb5dcf2fdc9d", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 1.2GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "markdown", + "id": "b1ac8ea2-8430-4760-8fc1-0d3f45300eb0", + "metadata": {}, + "source": [ + "### Assign weights for product items\n", + "\n", + "When we assess the accuracy of our machine learning model, we should assign a weight for each product item, to indicate the relative importance of the item. For the M5 competition, the weights are computed from the total sales amount (in US dollars) in the lastest 28 days." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "3cfffc9d-2767-42fa-b67a-5056671bd2f8", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sales_usd_sum
item_id
FOODS_1_0013516.80
FOODS_1_00212418.80
FOODS_1_0035943.20
FOODS_1_00454184.82
FOODS_1_00517877.00
......
HOUSEHOLD_2_5126034.40
HOUSEHOLD_2_5132668.80
HOUSEHOLD_2_5149574.60
HOUSEHOLD_2_515630.40
HOUSEHOLD_2_5162574.00
\n", + "

3049 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " sales_usd_sum\n", + "item_id \n", + "FOODS_1_001 3516.80\n", + "FOODS_1_002 12418.80\n", + "FOODS_1_003 5943.20\n", + "FOODS_1_004 54184.82\n", + "FOODS_1_005 17877.00\n", + "... ...\n", + "HOUSEHOLD_2_512 6034.40\n", + "HOUSEHOLD_2_513 2668.80\n", + "HOUSEHOLD_2_514 9574.60\n", + "HOUSEHOLD_2_515 630.40\n", + "HOUSEHOLD_2_516 2574.00\n", + "\n", + "[3049 rows x 1 columns]" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Convert day_id to integers\n", + "grid_df[\"day_id_int\"] = grid_df[\"day_id\"].to_pandas().apply(lambda x: x[2:]).astype(int)\n", + "\n", + "# Compute the total sales over the latest 28 days, per product item\n", + "last28 = grid_df[(grid_df[\"day_id_int\"] >= 1914) & (grid_df[\"day_id_int\"] < 1942)]\n", + "last28 = last28[[\"item_id\", \"wm_yr_wk\", \"sales\"]].merge(\n", + " prices_df[[\"item_id\", \"wm_yr_wk\", \"sell_price\"]], on=[\"item_id\", \"wm_yr_wk\"]\n", + ")\n", + "last28[\"sales_usd\"] = last28[\"sales\"] * last28[\"sell_price\"]\n", + "total_sales_usd = last28.groupby(\"item_id\")[[\"sales_usd\"]].agg([\"sum\"]).sort_index()\n", + "total_sales_usd.columns = total_sales_usd.columns.map(\"_\".join)\n", + "total_sales_usd" + ] + }, + { + "cell_type": "markdown", + "id": "ea0c6f49-60af-48b5-b020-42d35b3cc2eb", + "metadata": {}, + "source": [ + "To obtain weights, we normalize the sales amount for one item by the total sales for all items." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "e651c504-de65-4765-b9b1-0087a5dab3b3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
weights
item_id
FOODS_1_0010.000090
FOODS_1_0020.000318
FOODS_1_0030.000152
FOODS_1_0040.001389
FOODS_1_0050.000458
......
HOUSEHOLD_2_5120.000155
HOUSEHOLD_2_5130.000068
HOUSEHOLD_2_5140.000245
HOUSEHOLD_2_5150.000016
HOUSEHOLD_2_5160.000066
\n", + "

3049 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " weights\n", + "item_id \n", + "FOODS_1_001 0.000090\n", + "FOODS_1_002 0.000318\n", + "FOODS_1_003 0.000152\n", + "FOODS_1_004 0.001389\n", + "FOODS_1_005 0.000458\n", + "... ...\n", + "HOUSEHOLD_2_512 0.000155\n", + "HOUSEHOLD_2_513 0.000068\n", + "HOUSEHOLD_2_514 0.000245\n", + "HOUSEHOLD_2_515 0.000016\n", + "HOUSEHOLD_2_516 0.000066\n", + "\n", + "[3049 rows x 1 columns]" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "weights = total_sales_usd / total_sales_usd.sum()\n", + "weights = weights.rename(columns={\"sales_usd_sum\": \"weights\"})\n", + "weights" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "9882bfec-f7e2-45f9-97e4-93a836b28a0a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# No longer needed\n", + "del grid_df[\"day_id_int\"]" + ] + }, + { + "cell_type": "markdown", + "id": "dd8706a0-32d5-4cb3-9a2e-bf42fcc17254", + "metadata": {}, + "source": [ + "### Generate price-related features\n", + "Let us engineer additional features that are related to the sale price. We consider the distribution of the price of a given product over time and ask how the current price compares to the historical trend." + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "451c180c-13a0-4517-a25e-97557bb12718", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Highest price over all weeks\n", + "prices_df[\"price_max\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"max\")\n", + "# Lowest price over all weeks\n", + "prices_df[\"price_min\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"min\")\n", + "# Standard deviation of the price\n", + "prices_df[\"price_std\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"std\")\n", + "# Mean (average) price over all weeks\n", + "prices_df[\"price_mean\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"mean\")" + ] + }, + { + "cell_type": "markdown", + "id": "ffbadd58-5836-4668-8011-0f110e333160", + "metadata": {}, + "source": [ + "We also consider the ratio of the current price to the max price." + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "a737282a-a38b-497c-a611-64df7343d01a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "prices_df[\"price_norm\"] = prices_df[\"sell_price\"] / prices_df[\"price_max\"]" + ] + }, + { + "cell_type": "markdown", + "id": "cdde0f24-4ae3-4352-a210-7c2a2a3ac503", + "metadata": {}, + "source": [ + "Some items have a very stable price, whereas other items respond to inflation quickly and rise in price. To capture the price elasticity, we count the number of unique price values for a given product over time." + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "7369eb2f-3055-47e6-a9ed-cb225697e2e5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "prices_df[\"price_nunique\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"nunique\")" + ] + }, + { + "cell_type": "markdown", + "id": "96bdd61a-db51-4a65-935f-65153fcb5b4a", + "metadata": {}, + "source": [ + "We also consider, for a given price, how many other items are being sold at the exact same price." + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "1c264e78-abca-411f-a8c7-0d42f3b3910d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "prices_df[\"item_nunique\"] = prices_df.groupby([\"store_id\", \"sell_price\"])[\n", + " \"item_id\"\n", + "].transform(\"nunique\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "d7eda764-0745-48c6-b368-33e93c6df618", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idwm_yr_wksell_priceprice_maxprice_minprice_stdprice_meanprice_normprice_nuniqueitem_nunique
0CA_1HOBBIES_1_001113259.589.588.260.1521398.2857141.00000033
1CA_1HOBBIES_1_001113269.589.588.260.1521398.2857141.00000033
2CA_1HOBBIES_1_001113278.269.588.260.1521398.2857140.86221335
3CA_1HOBBIES_1_001113288.269.588.260.1521398.2857140.86221335
4CA_1HOBBIES_1_001113298.269.588.260.1521398.2857140.86221335
....................................
6841116WI_3FOODS_3_827116171.001.001.000.0000001.0000001.0000001142
6841117WI_3FOODS_3_827116181.001.001.000.0000001.0000001.0000001142
6841118WI_3FOODS_3_827116191.001.001.000.0000001.0000001.0000001142
6841119WI_3FOODS_3_827116201.001.001.000.0000001.0000001.0000001142
6841120WI_3FOODS_3_827116211.001.001.000.0000001.0000001.0000001142
\n", + "

6841121 rows × 11 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id wm_yr_wk sell_price price_max price_min \\\n", + "0 CA_1 HOBBIES_1_001 11325 9.58 9.58 8.26 \n", + "1 CA_1 HOBBIES_1_001 11326 9.58 9.58 8.26 \n", + "2 CA_1 HOBBIES_1_001 11327 8.26 9.58 8.26 \n", + "3 CA_1 HOBBIES_1_001 11328 8.26 9.58 8.26 \n", + "4 CA_1 HOBBIES_1_001 11329 8.26 9.58 8.26 \n", + "... ... ... ... ... ... ... \n", + "6841116 WI_3 FOODS_3_827 11617 1.00 1.00 1.00 \n", + "6841117 WI_3 FOODS_3_827 11618 1.00 1.00 1.00 \n", + "6841118 WI_3 FOODS_3_827 11619 1.00 1.00 1.00 \n", + "6841119 WI_3 FOODS_3_827 11620 1.00 1.00 1.00 \n", + "6841120 WI_3 FOODS_3_827 11621 1.00 1.00 1.00 \n", + "\n", + " price_std price_mean price_norm price_nunique item_nunique \n", + "0 0.152139 8.285714 1.000000 3 3 \n", + "1 0.152139 8.285714 1.000000 3 3 \n", + "2 0.152139 8.285714 0.862213 3 5 \n", + "3 0.152139 8.285714 0.862213 3 5 \n", + "4 0.152139 8.285714 0.862213 3 5 \n", + "... ... ... ... ... ... \n", + "6841116 0.000000 1.000000 1.000000 1 142 \n", + "6841117 0.000000 1.000000 1.000000 1 142 \n", + "6841118 0.000000 1.000000 1.000000 1 142 \n", + "6841119 0.000000 1.000000 1.000000 1 142 \n", + "6841120 0.000000 1.000000 1.000000 1 142 \n", + "\n", + "[6841121 rows x 11 columns]" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df" + ] + }, + { + "cell_type": "markdown", + "id": "56187024-fd76-4782-8e05-2ebab5daf3d3", + "metadata": {}, + "source": [ + "Another useful way to put prices in context is to compare the price of a product to its historical price a week ago, a month ago, or an year ago." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "d66e5c04-afaf-49d9-b7f6-fceb756939a9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Add \"month\" and \"year\" columns to prices_df\n", + "week_to_month_map = calendar_df[[\"wm_yr_wk\", \"month\", \"year\"]].drop_duplicates(\n", + " subset=[\"wm_yr_wk\"]\n", + ")\n", + "prices_df = prices_df.merge(week_to_month_map, on=[\"wm_yr_wk\"], how=\"left\")\n", + "\n", + "# Sort by wm_yr_wk. The rows will also be sorted in ascending months and years.\n", + "prices_df = prices_df.sort_values([\"store_id\", \"item_id\", \"wm_yr_wk\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "01a27a82-daf8-4492-b2ab-9e1d77ecbdfc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Compare with the average price in the previous week\n", + "prices_df[\"price_momentum\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", + " [\"store_id\", \"item_id\"]\n", + ")[\"sell_price\"].shift(1)\n", + "# Compare with the average price in the previous month\n", + "prices_df[\"price_momentum_m\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", + " [\"store_id\", \"item_id\", \"month\"]\n", + ")[\"sell_price\"].transform(\"mean\")\n", + "# Compare with the average price in the previous year\n", + "prices_df[\"price_momentum_y\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", + " [\"store_id\", \"item_id\", \"year\"]\n", + ")[\"sell_price\"].transform(\"mean\")" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "ad812423-bde6-4d1a-9677-222f81af6e0e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Remove \"month\" and \"year\" columns, as we don't need them any more\n", + "del prices_df[\"month\"], prices_df[\"year\"]\n", + "\n", + "# Convert float64 columns into float32 type to save memory\n", + "columns = [\n", + " \"sell_price\",\n", + " \"price_max\",\n", + " \"price_min\",\n", + " \"price_std\",\n", + " \"price_mean\",\n", + " \"price_norm\",\n", + " \"price_momentum\",\n", + " \"price_momentum_m\",\n", + " \"price_momentum_y\",\n", + "]\n", + "for col in columns:\n", + " prices_df[col] = prices_df[col].astype(np.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "9ad80497-bfd8-4421-ad09-dfde9d12e943", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "store_id object\n", + "item_id object\n", + "wm_yr_wk int64\n", + "sell_price float32\n", + "price_max float32\n", + "price_min float32\n", + "price_std float32\n", + "price_mean float32\n", + "price_norm float32\n", + "price_nunique int32\n", + "item_nunique int32\n", + "price_momentum float32\n", + "price_momentum_m float32\n", + "price_momentum_y float32\n", + "dtype: object" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df.dtypes" + ] + }, + { + "cell_type": "markdown", + "id": "04ddd07f-f9dc-4b01-ada3-d0ea9710024d", + "metadata": {}, + "source": [ + "### Bring in price-related features into `grid_df`" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "9a5a5f3e-585d-462e-9c96-c7432461a458", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idsell_priceprice_maxprice_minprice_stdprice_meanprice_normprice_nuniqueitem_nuniqueprice_momentumprice_momentum_mprice_momentum_y
0FOODS_1_002_TX_3_evaluationd_3628.889.487.884.849956e-018.9353190.9367093101.00.9761041.0
1FOODS_1_002_TX_3_evaluationd_3638.889.487.884.849956e-018.9353190.9367093101.00.9761041.0
2FOODS_1_002_TX_3_evaluationd_3648.889.487.884.849956e-018.9353190.9367093101.00.9761041.0
3FOODS_1_002_TX_3_evaluationd_3658.889.487.884.849956e-018.9353190.9367093101.00.9761041.0
4FOODS_1_002_TX_3_evaluationd_3668.889.487.884.849956e-018.9353190.9367093101.00.9761041.0
..........................................
47735392HOUSEHOLD_2_516_WI_2_evaluationd_6955.945.945.943.648122e-145.9400001.0000001471.01.0000001.0
47735393HOUSEHOLD_2_516_WI_2_evaluationd_6965.945.945.943.648122e-145.9400001.0000001471.01.0000001.0
47735394HOUSEHOLD_2_516_WI_2_evaluationd_6905.945.945.943.648122e-145.9400001.0000001471.01.0000001.0
47735395HOUSEHOLD_2_516_WI_2_evaluationd_6915.945.945.943.648122e-145.9400001.0000001471.01.0000001.0
47735396HOUSEHOLD_2_516_WI_2_evaluationd_6945.945.945.943.648122e-145.9400001.0000001471.01.0000001.0
\n", + "

47735397 rows × 13 columns

\n", + "
" + ], + "text/plain": [ + " id day_id sell_price price_max \\\n", + "0 FOODS_1_002_TX_3_evaluation d_362 8.88 9.48 \n", + "1 FOODS_1_002_TX_3_evaluation d_363 8.88 9.48 \n", + "2 FOODS_1_002_TX_3_evaluation d_364 8.88 9.48 \n", + "3 FOODS_1_002_TX_3_evaluation d_365 8.88 9.48 \n", + "4 FOODS_1_002_TX_3_evaluation d_366 8.88 9.48 \n", + "... ... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_2_evaluation d_695 5.94 5.94 \n", + "47735393 HOUSEHOLD_2_516_WI_2_evaluation d_696 5.94 5.94 \n", + "47735394 HOUSEHOLD_2_516_WI_2_evaluation d_690 5.94 5.94 \n", + "47735395 HOUSEHOLD_2_516_WI_2_evaluation d_691 5.94 5.94 \n", + "47735396 HOUSEHOLD_2_516_WI_2_evaluation d_694 5.94 5.94 \n", + "\n", + " price_min price_std price_mean price_norm price_nunique \\\n", + "0 7.88 4.849956e-01 8.935319 0.936709 3 \n", + "1 7.88 4.849956e-01 8.935319 0.936709 3 \n", + "2 7.88 4.849956e-01 8.935319 0.936709 3 \n", + "3 7.88 4.849956e-01 8.935319 0.936709 3 \n", + "4 7.88 4.849956e-01 8.935319 0.936709 3 \n", + "... ... ... ... ... ... \n", + "47735392 5.94 3.648122e-14 5.940000 1.000000 1 \n", + "47735393 5.94 3.648122e-14 5.940000 1.000000 1 \n", + "47735394 5.94 3.648122e-14 5.940000 1.000000 1 \n", + "47735395 5.94 3.648122e-14 5.940000 1.000000 1 \n", + "47735396 5.94 3.648122e-14 5.940000 1.000000 1 \n", + "\n", + " item_nunique price_momentum price_momentum_m price_momentum_y \n", + "0 10 1.0 0.976104 1.0 \n", + "1 10 1.0 0.976104 1.0 \n", + "2 10 1.0 0.976104 1.0 \n", + "3 10 1.0 0.976104 1.0 \n", + "4 10 1.0 0.976104 1.0 \n", + "... ... ... ... ... \n", + "47735392 47 1.0 1.000000 1.0 \n", + "47735393 47 1.0 1.000000 1.0 \n", + "47735394 47 1.0 1.000000 1.0 \n", + "47735395 47 1.0 1.000000 1.0 \n", + "47735396 47 1.0 1.000000 1.0 \n", + "\n", + "[47735397 rows x 13 columns]" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# After merging price_df, keep columns id and day_id from grid_df and drop all other columns from grid_df\n", + "original_columns = list(grid_df)\n", + "grid_df_with_price = grid_df.copy()\n", + "grid_df_with_price = grid_df_with_price.merge(\n", + " prices_df, on=[\"store_id\", \"item_id\", \"wm_yr_wk\"], how=\"left\"\n", + ")\n", + "columns_to_keep = [\"id\", \"day_id\"] + [\n", + " col for col in list(grid_df_with_price) if col not in original_columns\n", + "]\n", + "grid_df_with_price = grid_df_with_price[[\"id\", \"day_id\"] + columns_to_keep]\n", + "grid_df_with_price" + ] + }, + { + "cell_type": "markdown", + "id": "60d7cd4d-9cf8-4097-90a2-a46f698ce4b7", + "metadata": {}, + "source": [ + "### Generate date-related features\n", + "We identify the date in each row of `grid_df` using information from `calendar_df`." + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "6bb4475a-787c-4ec4-ad2c-1c72ec369b18", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_iddateevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WI
0FOODS_1_002_WI_3_evaluationd_16782015-09-02<NA><NA><NA><NA>101
1FOODS_1_002_WI_3_evaluationd_16792015-09-03<NA><NA><NA><NA>111
2FOODS_1_002_WI_3_evaluationd_16802015-09-04<NA><NA><NA><NA>100
3FOODS_1_002_WI_3_evaluationd_16812015-09-05<NA><NA><NA><NA>111
4FOODS_1_002_WI_3_evaluationd_16822015-09-06<NA><NA><NA><NA>111
.................................
47735392HOUSEHOLD_2_516_WI_2_evaluationd_9772013-10-01<NA><NA><NA><NA>110
47735393HOUSEHOLD_2_516_WI_2_evaluationd_9782013-10-02<NA><NA><NA><NA>101
47735394HOUSEHOLD_2_516_WI_2_evaluationd_9792013-10-03<NA><NA><NA><NA>111
47735395HOUSEHOLD_2_516_WI_3_evaluationd_9032013-07-19<NA><NA><NA><NA>000
47735396HOUSEHOLD_2_516_WI_3_evaluationd_8882013-07-04IndependenceDayNational<NA><NA>100
\n", + "

47735397 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id day_id date \\\n", + "0 FOODS_1_002_WI_3_evaluation d_1678 2015-09-02 \n", + "1 FOODS_1_002_WI_3_evaluation d_1679 2015-09-03 \n", + "2 FOODS_1_002_WI_3_evaluation d_1680 2015-09-04 \n", + "3 FOODS_1_002_WI_3_evaluation d_1681 2015-09-05 \n", + "4 FOODS_1_002_WI_3_evaluation d_1682 2015-09-06 \n", + "... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_2_evaluation d_977 2013-10-01 \n", + "47735393 HOUSEHOLD_2_516_WI_2_evaluation d_978 2013-10-02 \n", + "47735394 HOUSEHOLD_2_516_WI_2_evaluation d_979 2013-10-03 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_903 2013-07-19 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_888 2013-07-04 \n", + "\n", + " event_name_1 event_type_1 event_name_2 event_type_2 snap_CA \\\n", + "0 1 \n", + "1 1 \n", + "2 1 \n", + "3 1 \n", + "4 1 \n", + "... ... ... ... ... ... \n", + "47735392 1 \n", + "47735393 1 \n", + "47735394 1 \n", + "47735395 0 \n", + "47735396 IndependenceDay National 1 \n", + "\n", + " snap_TX snap_WI \n", + "0 0 1 \n", + "1 1 1 \n", + "2 0 0 \n", + "3 1 1 \n", + "4 1 1 \n", + "... ... ... \n", + "47735392 1 0 \n", + "47735393 0 1 \n", + "47735394 1 1 \n", + "47735395 0 0 \n", + "47735396 0 0 \n", + "\n", + "[47735397 rows x 10 columns]" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Bring in the following columns from calendar_df into grid_df\n", + "grid_df_id_only = grid_df[[\"id\", \"day_id\"]].copy()\n", + "\n", + "icols = [\n", + " \"date\",\n", + " \"day_id\",\n", + " \"event_name_1\",\n", + " \"event_type_1\",\n", + " \"event_name_2\",\n", + " \"event_type_2\",\n", + " \"snap_CA\",\n", + " \"snap_TX\",\n", + " \"snap_WI\",\n", + "]\n", + "grid_df_with_calendar = grid_df_id_only.merge(\n", + " calendar_df[icols], on=[\"day_id\"], how=\"left\"\n", + ")\n", + "grid_df_with_calendar" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "4dfcf0ed-d461-48d4-acb9-6aa1d0fa973e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Convert columns into categorical type to save memory\n", + "for col in [\n", + " \"event_name_1\",\n", + " \"event_type_1\",\n", + " \"event_name_2\",\n", + " \"event_type_2\",\n", + " \"snap_CA\",\n", + " \"snap_TX\",\n", + " \"snap_WI\",\n", + "]:\n", + " grid_df_with_calendar[col] = grid_df_with_calendar[col].astype(\"category\")\n", + "# Convert \"date\" column into timestamp type\n", + "grid_df_with_calendar[\"date\"] = cudf.to_datetime(grid_df_with_calendar[\"date\"])" + ] + }, + { + "cell_type": "markdown", + "id": "9ee43075-54f8-4c62-a173-1e016e9d9697", + "metadata": {}, + "source": [ + "Using the `date` column, we can generate related features, such as day, week, or month." + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "7a530f96-67b8-49f3-9f45-a1533eeef2c5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WItm_dtm_wtm_mtm_ytm_wmtm_dwtm_w_end
0FOODS_1_002_WI_3_evaluationd_1678<NA><NA><NA><NA>10123694120
1FOODS_1_002_WI_3_evaluationd_1679<NA><NA><NA><NA>11133694130
2FOODS_1_002_WI_3_evaluationd_1680<NA><NA><NA><NA>10043694140
3FOODS_1_002_WI_3_evaluationd_1681<NA><NA><NA><NA>11153694151
4FOODS_1_002_WI_3_evaluationd_1682<NA><NA><NA><NA>11163694161
...................................................
47735392HOUSEHOLD_2_516_WI_2_evaluationd_977<NA><NA><NA><NA>110140102110
47735393HOUSEHOLD_2_516_WI_2_evaluationd_978<NA><NA><NA><NA>101240102120
47735394HOUSEHOLD_2_516_WI_2_evaluationd_979<NA><NA><NA><NA>111340102130
47735395HOUSEHOLD_2_516_WI_3_evaluationd_903<NA><NA><NA><NA>000192972340
47735396HOUSEHOLD_2_516_WI_3_evaluationd_888IndependenceDayNational<NA><NA>10042772130
\n", + "

47735397 rows × 16 columns

\n", + "
" + ], + "text/plain": [ + " id day_id event_name_1 \\\n", + "0 FOODS_1_002_WI_3_evaluation d_1678 \n", + "1 FOODS_1_002_WI_3_evaluation d_1679 \n", + "2 FOODS_1_002_WI_3_evaluation d_1680 \n", + "3 FOODS_1_002_WI_3_evaluation d_1681 \n", + "4 FOODS_1_002_WI_3_evaluation d_1682 \n", + "... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_2_evaluation d_977 \n", + "47735393 HOUSEHOLD_2_516_WI_2_evaluation d_978 \n", + "47735394 HOUSEHOLD_2_516_WI_2_evaluation d_979 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_903 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_888 IndependenceDay \n", + "\n", + " event_type_1 event_name_2 event_type_2 snap_CA snap_TX snap_WI tm_d \\\n", + "0 1 0 1 2 \n", + "1 1 1 1 3 \n", + "2 1 0 0 4 \n", + "3 1 1 1 5 \n", + "4 1 1 1 6 \n", + "... ... ... ... ... ... ... ... \n", + "47735392 1 1 0 1 \n", + "47735393 1 0 1 2 \n", + "47735394 1 1 1 3 \n", + "47735395 0 0 0 19 \n", + "47735396 National 1 0 0 4 \n", + "\n", + " tm_w tm_m tm_y tm_wm tm_dw tm_w_end \n", + "0 36 9 4 1 2 0 \n", + "1 36 9 4 1 3 0 \n", + "2 36 9 4 1 4 0 \n", + "3 36 9 4 1 5 1 \n", + "4 36 9 4 1 6 1 \n", + "... ... ... ... ... ... ... \n", + "47735392 40 10 2 1 1 0 \n", + "47735393 40 10 2 1 2 0 \n", + "47735394 40 10 2 1 3 0 \n", + "47735395 29 7 2 3 4 0 \n", + "47735396 27 7 2 1 3 0 \n", + "\n", + "[47735397 rows x 16 columns]" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import cupy as cp\n", + "\n", + "grid_df_with_calendar[\"tm_d\"] = grid_df_with_calendar[\"date\"].dt.day.astype(np.int8)\n", + "grid_df_with_calendar[\"tm_w\"] = (\n", + " grid_df_with_calendar[\"date\"].dt.isocalendar().week.astype(np.int8)\n", + ")\n", + "grid_df_with_calendar[\"tm_m\"] = grid_df_with_calendar[\"date\"].dt.month.astype(np.int8)\n", + "grid_df_with_calendar[\"tm_y\"] = grid_df_with_calendar[\"date\"].dt.year\n", + "grid_df_with_calendar[\"tm_y\"] = (\n", + " grid_df_with_calendar[\"tm_y\"] - grid_df_with_calendar[\"tm_y\"].min()\n", + ").astype(np.int8)\n", + "grid_df_with_calendar[\"tm_wm\"] = cp.ceil(\n", + " grid_df_with_calendar[\"tm_d\"].to_cupy() / 7\n", + ").astype(\n", + " np.int8\n", + ") # which week in tje month?\n", + "grid_df_with_calendar[\"tm_dw\"] = grid_df_with_calendar[\"date\"].dt.dayofweek.astype(\n", + " np.int8\n", + ") # which day in the week?\n", + "grid_df_with_calendar[\"tm_w_end\"] = (grid_df_with_calendar[\"tm_dw\"] >= 5).astype(\n", + " np.int8\n", + ") # whether today is in the weekend\n", + "del grid_df_with_calendar[\"date\"] # no longer needed\n", + "\n", + "grid_df_with_calendar" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "f1f05d6a-7d92-4613-92fb-57c55a01f3db", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "96" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "del grid_df_id_only # No longer needed\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "0cd47558-1960-4a06-b941-884c6b0464fa", + "metadata": {}, + "source": [ + "### Generate lag features\n", + "\n", + "**Lag features** are the value of the target variable at prior timestamps. Lag features are useful because what happens in the past often influences what would happen in the future. In our example, we generate lag features by reading the sales amount at X days prior, where X = 28, 29, ..., 42." + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "18792591-f6e1-49d9-b3b9-721fae1675aa", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "SHIFT_DAY = 28\n", + "LAG_DAYS = [col for col in range(SHIFT_DAY, SHIFT_DAY + 15)]\n", + "\n", + "# Need to first ensure that rows in each time series are sorted by day_id\n", + "grid_df_lags = grid_df[[\"id\", \"day_id\", \"sales\"]].copy()\n", + "grid_df_lags = grid_df_lags.sort_values([\"id\", \"day_id\"])\n", + "\n", + "grid_df_lags = grid_df_lags.assign(\n", + " **{\n", + " f\"sales_lag_{l}\": grid_df_lags.groupby([\"id\"])[\"sales\"].shift(l)\n", + " for l in LAG_DAYS\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "ed47d972-f6c9-4618-986a-08649dfbfcb8", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idsalessales_lag_28sales_lag_29sales_lag_30sales_lag_31sales_lag_32sales_lag_33sales_lag_34sales_lag_35sales_lag_36sales_lag_37sales_lag_38sales_lag_39sales_lag_40sales_lag_41sales_lag_42
34023FOODS_1_001_CA_1_evaluationd_13.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34024FOODS_1_001_CA_1_evaluationd_20.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34025FOODS_1_001_CA_1_evaluationd_30.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34026FOODS_1_001_CA_1_evaluationd_41.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34027FOODS_1_001_CA_1_evaluationd_54.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
.........................................................
47733744HOUSEHOLD_2_516_WI_3_evaluationd_1965NaN0.00.00.00.01.00.00.00.00.00.00.00.00.00.00.0
47733745HOUSEHOLD_2_516_WI_3_evaluationd_1966NaN0.00.00.00.00.01.00.00.00.00.00.00.00.00.00.0
47733746HOUSEHOLD_2_516_WI_3_evaluationd_1967NaN0.00.00.00.00.00.01.00.00.00.00.00.00.00.00.0
47733747HOUSEHOLD_2_516_WI_3_evaluationd_1968NaN0.00.00.00.00.00.00.01.00.00.00.00.00.00.00.0
47733748HOUSEHOLD_2_516_WI_3_evaluationd_1969NaN0.00.00.00.00.00.00.00.01.00.00.00.00.00.00.0
\n", + "

47735397 rows × 18 columns

\n", + "
" + ], + "text/plain": [ + " id day_id sales sales_lag_28 \\\n", + "34023 FOODS_1_001_CA_1_evaluation d_1 3.0 \n", + "34024 FOODS_1_001_CA_1_evaluation d_2 0.0 \n", + "34025 FOODS_1_001_CA_1_evaluation d_3 0.0 \n", + "34026 FOODS_1_001_CA_1_evaluation d_4 1.0 \n", + "34027 FOODS_1_001_CA_1_evaluation d_5 4.0 \n", + "... ... ... ... ... \n", + "47733744 HOUSEHOLD_2_516_WI_3_evaluation d_1965 NaN 0.0 \n", + "47733745 HOUSEHOLD_2_516_WI_3_evaluation d_1966 NaN 0.0 \n", + "47733746 HOUSEHOLD_2_516_WI_3_evaluation d_1967 NaN 0.0 \n", + "47733747 HOUSEHOLD_2_516_WI_3_evaluation d_1968 NaN 0.0 \n", + "47733748 HOUSEHOLD_2_516_WI_3_evaluation d_1969 NaN 0.0 \n", + "\n", + " sales_lag_29 sales_lag_30 sales_lag_31 sales_lag_32 sales_lag_33 \\\n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... ... ... \n", + "47733744 0.0 0.0 0.0 1.0 0.0 \n", + "47733745 0.0 0.0 0.0 0.0 1.0 \n", + "47733746 0.0 0.0 0.0 0.0 0.0 \n", + "47733747 0.0 0.0 0.0 0.0 0.0 \n", + "47733748 0.0 0.0 0.0 0.0 0.0 \n", + "\n", + " sales_lag_34 sales_lag_35 sales_lag_36 sales_lag_37 sales_lag_38 \\\n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... ... ... \n", + "47733744 0.0 0.0 0.0 0.0 0.0 \n", + "47733745 0.0 0.0 0.0 0.0 0.0 \n", + "47733746 1.0 0.0 0.0 0.0 0.0 \n", + "47733747 0.0 1.0 0.0 0.0 0.0 \n", + "47733748 0.0 0.0 1.0 0.0 0.0 \n", + "\n", + " sales_lag_39 sales_lag_40 sales_lag_41 sales_lag_42 \n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... ... \n", + "47733744 0.0 0.0 0.0 0.0 \n", + "47733745 0.0 0.0 0.0 0.0 \n", + "47733746 0.0 0.0 0.0 0.0 \n", + "47733747 0.0 0.0 0.0 0.0 \n", + "47733748 0.0 0.0 0.0 0.0 \n", + "\n", + "[47735397 rows x 18 columns]" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_lags" + ] + }, + { + "cell_type": "markdown", + "id": "9e85f484-a5e9-4acd-a60b-f7cc407fe7a4", + "metadata": {}, + "source": [ + "### Compute rolling window statistics\n", + "\n", + "In the previous cell, we used the value of sales at a single timestamp to generate lag features. To capture richer information about the past, let us also get the distribution of the sales value over multiple timestamps, by computing **rolling window statistics**. Rolling window statistics are statistics (e.g. mean, standard deviation) over a time duration in the past. Rolling windows statistics complement lag features and provide more information about the past behavior of the target variable.\n", + "\n", + "Read more about lag features and rolling window statistics in [Introduction to feature engineering for time series forecasting](https://medium.com/data-science-at-microsoft/introduction-to-feature-engineering-for-time-series-forecasting-620aa55fcab0)." + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "485b447e-213f-4091-a0df-874c4d35ac64", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shift size: 28\n", + " Window size: 7\n", + " Window size: 14\n", + " Window size: 30\n", + " Window size: 60\n", + " Window size: 180\n" + ] + } + ], + "source": [ + "# Shift by 28 days and apply windows of various sizes\n", + "print(f\"Shift size: {SHIFT_DAY}\")\n", + "for i in [7, 14, 30, 60, 180]:\n", + " print(f\" Window size: {i}\")\n", + " grid_df_lags[f\"rolling_mean_{i}\"] = (\n", + " grid_df_lags.groupby([\"id\"])[\"sales\"]\n", + " .shift(SHIFT_DAY)\n", + " .rolling(i)\n", + " .mean()\n", + " .astype(np.float32)\n", + " )\n", + " grid_df_lags[f\"rolling_std_{i}\"] = (\n", + " grid_df_lags.groupby([\"id\"])[\"sales\"]\n", + " .shift(SHIFT_DAY)\n", + " .rolling(i)\n", + " .std()\n", + " .astype(np.float32)\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "64d0e7bc-8c79-43be-9809-192807d57ada", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Index(['id', 'day_id', 'sales', 'sales_lag_28', 'sales_lag_29', 'sales_lag_30',\n", + " 'sales_lag_31', 'sales_lag_32', 'sales_lag_33', 'sales_lag_34',\n", + " 'sales_lag_35', 'sales_lag_36', 'sales_lag_37', 'sales_lag_38',\n", + " 'sales_lag_39', 'sales_lag_40', 'sales_lag_41', 'sales_lag_42',\n", + " 'rolling_mean_7', 'rolling_std_7', 'rolling_mean_14', 'rolling_std_14',\n", + " 'rolling_mean_30', 'rolling_std_30', 'rolling_mean_60',\n", + " 'rolling_std_60', 'rolling_mean_180', 'rolling_std_180'],\n", + " dtype='object')" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_lags.columns" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "977f501a-f358-423c-9248-210405a0e51b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "id category\n", + "day_id category\n", + "sales float32\n", + "sales_lag_28 float32\n", + "sales_lag_29 float32\n", + "sales_lag_30 float32\n", + "sales_lag_31 float32\n", + "sales_lag_32 float32\n", + "sales_lag_33 float32\n", + "sales_lag_34 float32\n", + "sales_lag_35 float32\n", + "sales_lag_36 float32\n", + "sales_lag_37 float32\n", + "sales_lag_38 float32\n", + "sales_lag_39 float32\n", + "sales_lag_40 float32\n", + "sales_lag_41 float32\n", + "sales_lag_42 float32\n", + "rolling_mean_7 float32\n", + "rolling_std_7 float32\n", + "rolling_mean_14 float32\n", + "rolling_std_14 float32\n", + "rolling_mean_30 float32\n", + "rolling_std_30 float32\n", + "rolling_mean_60 float32\n", + "rolling_std_60 float32\n", + "rolling_mean_180 float32\n", + "rolling_std_180 float32\n", + "dtype: object" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_lags.dtypes" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "22d8519f-73a5-4dde-be3a-1896986057e6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idsalessales_lag_28sales_lag_29sales_lag_30sales_lag_31sales_lag_32sales_lag_33sales_lag_34...rolling_mean_7rolling_std_7rolling_mean_14rolling_std_14rolling_mean_30rolling_std_30rolling_mean_60rolling_std_60rolling_mean_180rolling_std_180
34023FOODS_1_001_CA_1_evaluationd_13.0<NA><NA><NA><NA><NA><NA><NA>...<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34024FOODS_1_001_CA_1_evaluationd_20.0<NA><NA><NA><NA><NA><NA><NA>...<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34025FOODS_1_001_CA_1_evaluationd_30.0<NA><NA><NA><NA><NA><NA><NA>...<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34026FOODS_1_001_CA_1_evaluationd_41.0<NA><NA><NA><NA><NA><NA><NA>...<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34027FOODS_1_001_CA_1_evaluationd_54.0<NA><NA><NA><NA><NA><NA><NA>...<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
..................................................................
47733744HOUSEHOLD_2_516_WI_3_evaluationd_1965NaN0.00.00.00.01.00.00.0...0.1428571490.3779644670.0714285750.2672612370.066666670.2537081240.0333333350.1810203340.0777777810.288621366
47733745HOUSEHOLD_2_516_WI_3_evaluationd_1966NaN0.00.00.00.00.01.00.0...0.1428571490.3779644670.0714285750.2672612370.066666670.2537081240.0333333350.1810203340.0777777810.288621366
47733746HOUSEHOLD_2_516_WI_3_evaluationd_1967NaN0.00.00.00.00.00.01.0...0.1428571490.3779644670.0714285750.2672612370.066666670.2537081240.0333333350.1810203340.0777777810.288621366
47733747HOUSEHOLD_2_516_WI_3_evaluationd_1968NaN0.00.00.00.00.00.00.0...0.00.00.0714285750.2672612370.066666670.2537081240.0333333350.1810203340.0777777810.288621366
47733748HOUSEHOLD_2_516_WI_3_evaluationd_1969NaN0.00.00.00.00.00.00.0...0.00.00.0714285750.2672612370.066666670.2537081240.0333333350.1810203340.0777777810.288621366
\n", + "

47735397 rows × 28 columns

\n", + "
" + ], + "text/plain": [ + " id day_id sales sales_lag_28 \\\n", + "34023 FOODS_1_001_CA_1_evaluation d_1 3.0 \n", + "34024 FOODS_1_001_CA_1_evaluation d_2 0.0 \n", + "34025 FOODS_1_001_CA_1_evaluation d_3 0.0 \n", + "34026 FOODS_1_001_CA_1_evaluation d_4 1.0 \n", + "34027 FOODS_1_001_CA_1_evaluation d_5 4.0 \n", + "... ... ... ... ... \n", + "47733744 HOUSEHOLD_2_516_WI_3_evaluation d_1965 NaN 0.0 \n", + "47733745 HOUSEHOLD_2_516_WI_3_evaluation d_1966 NaN 0.0 \n", + "47733746 HOUSEHOLD_2_516_WI_3_evaluation d_1967 NaN 0.0 \n", + "47733747 HOUSEHOLD_2_516_WI_3_evaluation d_1968 NaN 0.0 \n", + "47733748 HOUSEHOLD_2_516_WI_3_evaluation d_1969 NaN 0.0 \n", + "\n", + " sales_lag_29 sales_lag_30 sales_lag_31 sales_lag_32 sales_lag_33 \\\n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... ... ... \n", + "47733744 0.0 0.0 0.0 1.0 0.0 \n", + "47733745 0.0 0.0 0.0 0.0 1.0 \n", + "47733746 0.0 0.0 0.0 0.0 0.0 \n", + "47733747 0.0 0.0 0.0 0.0 0.0 \n", + "47733748 0.0 0.0 0.0 0.0 0.0 \n", + "\n", + " sales_lag_34 ... rolling_mean_7 rolling_std_7 rolling_mean_14 \\\n", + "34023 ... \n", + "34024 ... \n", + "34025 ... \n", + "34026 ... \n", + "34027 ... \n", + "... ... ... ... ... ... \n", + "47733744 0.0 ... 0.142857149 0.377964467 0.071428575 \n", + "47733745 0.0 ... 0.142857149 0.377964467 0.071428575 \n", + "47733746 1.0 ... 0.142857149 0.377964467 0.071428575 \n", + "47733747 0.0 ... 0.0 0.0 0.071428575 \n", + "47733748 0.0 ... 0.0 0.0 0.071428575 \n", + "\n", + " rolling_std_14 rolling_mean_30 rolling_std_30 rolling_mean_60 \\\n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... ... \n", + "47733744 0.267261237 0.06666667 0.253708124 0.033333335 \n", + "47733745 0.267261237 0.06666667 0.253708124 0.033333335 \n", + "47733746 0.267261237 0.06666667 0.253708124 0.033333335 \n", + "47733747 0.267261237 0.06666667 0.253708124 0.033333335 \n", + "47733748 0.267261237 0.06666667 0.253708124 0.033333335 \n", + "\n", + " rolling_std_60 rolling_mean_180 rolling_std_180 \n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... \n", + "47733744 0.181020334 0.077777781 0.288621366 \n", + "47733745 0.181020334 0.077777781 0.288621366 \n", + "47733746 0.181020334 0.077777781 0.288621366 \n", + "47733747 0.181020334 0.077777781 0.288621366 \n", + "47733748 0.181020334 0.077777781 0.288621366 \n", + "\n", + "[47735397 rows x 28 columns]" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_lags" + ] + }, + { + "cell_type": "markdown", + "id": "898c2f27-50b0-4160-abd0-d6554c767c7f", + "metadata": {}, + "source": [ + "### Target encoding\n", + "Categorical variables present challenges to many machine learning algorithms such as XGBoost. One way to overcome the challenge is to use **target encoding**, where we encode categorical variables by replacing them with a statistic for the target variable. In this example, we will use the mean and the standard deviation.\n", + "\n", + "Read more about target encoding in [Target-encoding Categorical Variables](https://towardsdatascience.com/dealing-with-categorical-variables-by-using-target-encoder-a0f1733a4c69)." + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "62cfbd3b-ede8-4bd1-967b-5a8785df49ea", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Encoding columns ['store_id', 'dept_id']\n", + "Encoding columns ['item_id', 'state_id']\n" + ] + } + ], + "source": [ + "icols = [[\"store_id\", \"dept_id\"], [\"item_id\", \"state_id\"]]\n", + "new_columns = []\n", + "\n", + "grid_df_target_enc = grid_df[\n", + " [\"id\", \"day_id\", \"item_id\", \"state_id\", \"store_id\", \"dept_id\", \"sales\"]\n", + "].copy()\n", + "grid_df_target_enc[\"sales\"].fillna(value=0, inplace=True)\n", + "\n", + "for col in icols:\n", + " print(f\"Encoding columns {col}\")\n", + " col_name = \"_\" + \"_\".join(col) + \"_\"\n", + " grid_df_target_enc[\"enc\" + col_name + \"mean\"] = (\n", + " grid_df_target_enc.groupby(col)[\"sales\"].transform(\"mean\").astype(np.float32)\n", + " )\n", + " grid_df_target_enc[\"enc\" + col_name + \"std\"] = (\n", + " grid_df_target_enc.groupby(col)[\"sales\"].transform(\"std\").astype(np.float32)\n", + " )\n", + " new_columns.extend([\"enc\" + col_name + \"mean\", \"enc\" + col_name + \"std\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "8c83b674-7ec1-4203-acc3-7f12c0419001", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idenc_store_id_dept_id_meanenc_store_id_dept_id_stdenc_item_id_state_id_meanenc_item_id_state_id_std
0FOODS_1_001_WI_2_evaluationd_8091.4929883.9876570.4335530.851153
1FOODS_1_001_WI_2_evaluationd_8101.4929883.9876570.4335530.851153
2FOODS_1_001_WI_2_evaluationd_8111.4929883.9876570.4335530.851153
3FOODS_1_001_WI_2_evaluationd_8121.4929883.9876570.4335530.851153
4FOODS_1_001_WI_2_evaluationd_8131.4929883.9876570.4335530.851153
.....................
47735392HOUSEHOLD_2_516_WI_3_evaluationd_520.2570270.6615410.0820840.299445
47735393HOUSEHOLD_2_516_WI_3_evaluationd_530.2570270.6615410.0820840.299445
47735394HOUSEHOLD_2_516_WI_3_evaluationd_540.2570270.6615410.0820840.299445
47735395HOUSEHOLD_2_516_WI_3_evaluationd_550.2570270.6615410.0820840.299445
47735396HOUSEHOLD_2_516_WI_3_evaluationd_490.2570270.6615410.0820840.299445
\n", + "

47735397 rows × 6 columns

\n", + "
" + ], + "text/plain": [ + " id day_id enc_store_id_dept_id_mean \\\n", + "0 FOODS_1_001_WI_2_evaluation d_809 1.492988 \n", + "1 FOODS_1_001_WI_2_evaluation d_810 1.492988 \n", + "2 FOODS_1_001_WI_2_evaluation d_811 1.492988 \n", + "3 FOODS_1_001_WI_2_evaluation d_812 1.492988 \n", + "4 FOODS_1_001_WI_2_evaluation d_813 1.492988 \n", + "... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_3_evaluation d_52 0.257027 \n", + "47735393 HOUSEHOLD_2_516_WI_3_evaluation d_53 0.257027 \n", + "47735394 HOUSEHOLD_2_516_WI_3_evaluation d_54 0.257027 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_55 0.257027 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_49 0.257027 \n", + "\n", + " enc_store_id_dept_id_std enc_item_id_state_id_mean \\\n", + "0 3.987657 0.433553 \n", + "1 3.987657 0.433553 \n", + "2 3.987657 0.433553 \n", + "3 3.987657 0.433553 \n", + "4 3.987657 0.433553 \n", + "... ... ... \n", + "47735392 0.661541 0.082084 \n", + "47735393 0.661541 0.082084 \n", + "47735394 0.661541 0.082084 \n", + "47735395 0.661541 0.082084 \n", + "47735396 0.661541 0.082084 \n", + "\n", + " enc_item_id_state_id_std \n", + "0 0.851153 \n", + "1 0.851153 \n", + "2 0.851153 \n", + "3 0.851153 \n", + "4 0.851153 \n", + "... ... \n", + "47735392 0.299445 \n", + "47735393 0.299445 \n", + "47735394 0.299445 \n", + "47735395 0.299445 \n", + "47735396 0.299445 \n", + "\n", + "[47735397 rows x 6 columns]" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_target_enc = grid_df_target_enc[[\"id\", \"day_id\"] + new_columns]\n", + "grid_df_target_enc" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "02c6b045-331e-4c4c-baa7-07a44e44ab39", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "id category\n", + "day_id category\n", + "enc_store_id_dept_id_mean float32\n", + "enc_store_id_dept_id_std float32\n", + "enc_item_id_state_id_mean float32\n", + "enc_item_id_state_id_std float32\n", + "dtype: object" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_target_enc.dtypes" + ] + }, + { + "cell_type": "markdown", + "id": "e786fff5-0567-476d-be1a-aa440132fc3c", + "metadata": {}, + "source": [ + "### Filter by store and product department and create data segments\n", + "After combining all columns produced in the previous notebooks, we filter the rows in the data set by `store_id` and `dept_id` and create a segment. Each segment is saved as a pickle file and then upload to Cloud Storage." + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "4ec8a3a1-f5a6-4ac9-aaef-2debcbe02919", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "segmented_data_dir = pathlib.Path(\"./segmented_data/\")\n", + "segmented_data_dir.mkdir(exist_ok=True)\n", + "\n", + "STORES = [\n", + " \"CA_1\",\n", + " \"CA_2\",\n", + " \"CA_3\",\n", + " \"CA_4\",\n", + " \"TX_1\",\n", + " \"TX_2\",\n", + " \"TX_3\",\n", + " \"WI_1\",\n", + " \"WI_2\",\n", + " \"WI_3\",\n", + "]\n", + "DEPTS = [\n", + " \"HOBBIES_1\",\n", + " \"HOBBIES_2\",\n", + " \"HOUSEHOLD_1\",\n", + " \"HOUSEHOLD_2\",\n", + " \"FOODS_1\",\n", + " \"FOODS_2\",\n", + " \"FOODS_3\",\n", + "]\n", + "\n", + "grid2_colnm = [\n", + " \"sell_price\",\n", + " \"price_max\",\n", + " \"price_min\",\n", + " \"price_std\",\n", + " \"price_mean\",\n", + " \"price_norm\",\n", + " \"price_nunique\",\n", + " \"item_nunique\",\n", + " \"price_momentum\",\n", + " \"price_momentum_m\",\n", + " \"price_momentum_y\",\n", + "]\n", + "\n", + "grid3_colnm = [\n", + " \"event_name_1\",\n", + " \"event_type_1\",\n", + " \"event_name_2\",\n", + " \"event_type_2\",\n", + " \"snap_CA\",\n", + " \"snap_TX\",\n", + " \"snap_WI\",\n", + " \"tm_d\",\n", + " \"tm_w\",\n", + " \"tm_m\",\n", + " \"tm_y\",\n", + " \"tm_wm\",\n", + " \"tm_dw\",\n", + " \"tm_w_end\",\n", + "]\n", + "\n", + "lag_colnm = [\n", + " \"sales_lag_28\",\n", + " \"sales_lag_29\",\n", + " \"sales_lag_30\",\n", + " \"sales_lag_31\",\n", + " \"sales_lag_32\",\n", + " \"sales_lag_33\",\n", + " \"sales_lag_34\",\n", + " \"sales_lag_35\",\n", + " \"sales_lag_36\",\n", + " \"sales_lag_37\",\n", + " \"sales_lag_38\",\n", + " \"sales_lag_39\",\n", + " \"sales_lag_40\",\n", + " \"sales_lag_41\",\n", + " \"sales_lag_42\",\n", + " \"rolling_mean_7\",\n", + " \"rolling_std_7\",\n", + " \"rolling_mean_14\",\n", + " \"rolling_std_14\",\n", + " \"rolling_mean_30\",\n", + " \"rolling_std_30\",\n", + " \"rolling_mean_60\",\n", + " \"rolling_std_60\",\n", + " \"rolling_mean_180\",\n", + " \"rolling_std_180\",\n", + "]\n", + "\n", + "target_enc_colnm = [\n", + " \"enc_store_id_dept_id_mean\",\n", + " \"enc_store_id_dept_id_std\",\n", + " \"enc_item_id_state_id_mean\",\n", + " \"enc_item_id_state_id_std\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "23674edc-4f63-49dd-a421-a15dc4fb0a70", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def prepare_data(store, dept=None):\n", + " \"\"\"\n", + " Filter and clean data according to stores and product departments\n", + "\n", + " Parameters\n", + " ----------\n", + " store: Filter data by retaining rows whose store_id matches this parameter.\n", + " dept: Filter data by retaining rows whose dept_id matches this parameter.\n", + " This parameter can be set to None to indicate that we shouldn't filter by dept_id.\n", + " \"\"\"\n", + " if store is None:\n", + " raise ValueError(f\"store parameter must not be None\")\n", + "\n", + " if dept is None:\n", + " grid1 = grid_df[grid_df[\"store_id\"] == store]\n", + " else:\n", + " grid1 = grid_df[\n", + " (grid_df[\"store_id\"] == store) & (grid_df[\"dept_id\"] == dept)\n", + " ].drop(columns=[\"dept_id\"])\n", + " grid1 = grid1.drop(columns=[\"release_week\", \"wm_yr_wk\", \"store_id\", \"state_id\"])\n", + "\n", + " grid2 = grid_df_with_price[[\"id\", \"day_id\"] + grid2_colnm]\n", + " grid_combined = grid1.merge(grid2, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del grid1, grid2\n", + "\n", + " grid3 = grid_df_with_calendar[[\"id\", \"day_id\"] + grid3_colnm]\n", + " grid_combined = grid_combined.merge(grid3, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del grid3\n", + "\n", + " lag_df = grid_df_lags[[\"id\", \"day_id\"] + lag_colnm]\n", + " grid_combined = grid_combined.merge(lag_df, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del lag_df\n", + "\n", + " target_enc_df = grid_df_target_enc[[\"id\", \"day_id\"] + target_enc_colnm]\n", + " grid_combined = grid_combined.merge(target_enc_df, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del target_enc_df\n", + " gc.collect()\n", + "\n", + " grid_combined = grid_combined.drop(columns=[\"id\"])\n", + " grid_combined[\"day_id\"] = (\n", + " grid_combined[\"day_id\"]\n", + " .to_pandas()\n", + " .astype(\"str\")\n", + " .apply(lambda x: x[2:])\n", + " .astype(np.int16)\n", + " )\n", + "\n", + " return grid_combined" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "9762f74a-3842-475a-9569-231df886cfdd", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1...\n", + "Processing store CA_2...\n", + "Processing store CA_3...\n", + "Processing store CA_4...\n", + "Processing store TX_1...\n", + "Processing store TX_2...\n", + "Processing store TX_3...\n", + "Processing store WI_1...\n", + "Processing store WI_2...\n", + "Processing store WI_3...\n", + "Processing (store CA_1, department HOBBIES_1)...\n", + "Processing (store CA_1, department HOBBIES_2)...\n", + "Processing (store CA_1, department HOUSEHOLD_1)...\n", + "Processing (store CA_1, department HOUSEHOLD_2)...\n", + "Processing (store CA_1, department FOODS_1)...\n", + "Processing (store CA_1, department FOODS_2)...\n", + "Processing (store CA_1, department FOODS_3)...\n", + "Processing (store CA_2, department HOBBIES_1)...\n", + "Processing (store CA_2, department HOBBIES_2)...\n", + "Processing (store CA_2, department HOUSEHOLD_1)...\n", + "Processing (store CA_2, department HOUSEHOLD_2)...\n", + "Processing (store CA_2, department FOODS_1)...\n", + "Processing (store CA_2, department FOODS_2)...\n", + "Processing (store CA_2, department FOODS_3)...\n", + "Processing (store CA_3, department HOBBIES_1)...\n", + "Processing (store CA_3, department HOBBIES_2)...\n", + "Processing (store CA_3, department HOUSEHOLD_1)...\n", + "Processing (store CA_3, department HOUSEHOLD_2)...\n", + "Processing (store CA_3, department FOODS_1)...\n", + "Processing (store CA_3, department FOODS_2)...\n", + "Processing (store CA_3, department FOODS_3)...\n", + "Processing (store CA_4, department HOBBIES_1)...\n", + "Processing (store CA_4, department HOBBIES_2)...\n", + "Processing (store CA_4, department HOUSEHOLD_1)...\n", + "Processing (store CA_4, department HOUSEHOLD_2)...\n", + "Processing (store CA_4, department FOODS_1)...\n", + "Processing (store CA_4, department FOODS_2)...\n", + "Processing (store CA_4, department FOODS_3)...\n", + "Processing (store TX_1, department HOBBIES_1)...\n", + "Processing (store TX_1, department HOBBIES_2)...\n", + "Processing (store TX_1, department HOUSEHOLD_1)...\n", + "Processing (store TX_1, department HOUSEHOLD_2)...\n", + "Processing (store TX_1, department FOODS_1)...\n", + "Processing (store TX_1, department FOODS_2)...\n", + "Processing (store TX_1, department FOODS_3)...\n", + "Processing (store TX_2, department HOBBIES_1)...\n", + "Processing (store TX_2, department HOBBIES_2)...\n", + "Processing (store TX_2, department HOUSEHOLD_1)...\n", + "Processing (store TX_2, department HOUSEHOLD_2)...\n", + "Processing (store TX_2, department FOODS_1)...\n", + "Processing (store TX_2, department FOODS_2)...\n", + "Processing (store TX_2, department FOODS_3)...\n", + "Processing (store TX_3, department HOBBIES_1)...\n", + "Processing (store TX_3, department HOBBIES_2)...\n", + "Processing (store TX_3, department HOUSEHOLD_1)...\n", + "Processing (store TX_3, department HOUSEHOLD_2)...\n", + "Processing (store TX_3, department FOODS_1)...\n", + "Processing (store TX_3, department FOODS_2)...\n", + "Processing (store TX_3, department FOODS_3)...\n", + "Processing (store WI_1, department HOBBIES_1)...\n", + "Processing (store WI_1, department HOBBIES_2)...\n", + "Processing (store WI_1, department HOUSEHOLD_1)...\n", + "Processing (store WI_1, department HOUSEHOLD_2)...\n", + "Processing (store WI_1, department FOODS_1)...\n", + "Processing (store WI_1, department FOODS_2)...\n", + "Processing (store WI_1, department FOODS_3)...\n", + "Processing (store WI_2, department HOBBIES_1)...\n", + "Processing (store WI_2, department HOBBIES_2)...\n", + "Processing (store WI_2, department HOUSEHOLD_1)...\n", + "Processing (store WI_2, department HOUSEHOLD_2)...\n", + "Processing (store WI_2, department FOODS_1)...\n", + "Processing (store WI_2, department FOODS_2)...\n", + "Processing (store WI_2, department FOODS_3)...\n", + "Processing (store WI_3, department HOBBIES_1)...\n", + "Processing (store WI_3, department HOBBIES_2)...\n", + "Processing (store WI_3, department HOUSEHOLD_1)...\n", + "Processing (store WI_3, department HOUSEHOLD_2)...\n", + "Processing (store WI_3, department FOODS_1)...\n", + "Processing (store WI_3, department FOODS_2)...\n", + "Processing (store WI_3, department FOODS_3)...\n" + ] + } + ], + "source": [ + "# First save the segment to the disk\n", + "for store in STORES:\n", + " print(f\"Processing store {store}...\")\n", + " segment_df = prepare_data(store=store)\n", + " segment_df.to_pandas().to_pickle(\n", + " segmented_data_dir / f\"combined_df_store_{store}.pkl\"\n", + " )\n", + " del segment_df\n", + " gc.collect()\n", + "\n", + "for store in STORES:\n", + " for dept in DEPTS:\n", + " print(f\"Processing (store {store}, department {dept})...\")\n", + " segment_df = prepare_data(store=store, dept=dept)\n", + " segment_df.to_pandas().to_pickle(\n", + " segmented_data_dir / f\"combined_df_store_{store}_dept_{dept}.pkl\"\n", + " )\n", + " del segment_df\n", + " gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "17c6859b-6159-424b-a041-1bc6b8201716", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Uploading segmented_data/combined_df_store_CA_3_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2.pkl...\n" + ] + } + ], + "source": [ + "# Then copy the segment to Cloud Storage\n", + "fs = gcsfs.GCSFileSystem()\n", + "\n", + "for e in segmented_data_dir.glob(\"*.pkl\"):\n", + " print(f\"Uploading {e}...\")\n", + " basename = e.name\n", + " fs.put_file(e, f\"{bucket_name}/{basename}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "8357de8c-0b19-4914-9d41-67cf1a9359c4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Also upload the product weights\n", + "fs = gcsfs.GCSFileSystem()\n", + "\n", + "weights.to_pandas().to_pickle(\"product_weights.pkl\")\n", + "fs.put_file(\"product_weights.pkl\", f\"{bucket_name}/product_weights.pkl\")" + ] + }, + { + "cell_type": "markdown", + "id": "8777971f-a46d-4954-b9c8-45a7c182ac96", + "metadata": {}, + "source": [ + "## Training and Evaluation with Hyperparameter Optimization (HPO)\n", + "\n", + "Now that we finished processing the data, we are now ready to train a model to forecast future sales. We will leverage the worker pods to run multiple training jobs in parallel, speeding up the hyperparameter search." + ] + }, + { + "cell_type": "markdown", + "id": "d14cff4c-5761-4c08-bd36-92710b2bfc32", + "metadata": {}, + "source": [ + "### Import modules and define constants" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "b2d02ade-247c-4fff-bcd0-afd9ff30d026", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import cudf\n", + "import gcsfs\n", + "import xgboost as xgb\n", + "import pandas as pd\n", + "import numpy as np\n", + "import optuna\n", + "import gc\n", + "import time\n", + "import pickle\n", + "import copy\n", + "import json\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.patches import Patch\n", + "import matplotlib\n", + "\n", + "from dask.distributed import wait\n", + "from dask_kubernetes.operator import KubeCluster\n", + "from dask.distributed import Client" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "5343eb9a-2eb5-4e88-b18c-3cc5a1c84cec", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Choose the same RAPIDS image you used for launching the notebook session\n", + "rapids_image = \"rapidsai/notebooks:23.10a-cuda12.0-py3.10\"\n", + "# Use the number of worker nodes in your Kubernetes cluster.\n", + "n_workers = 2\n", + "# Bucket that contains the processed data pickles\n", + "bucket_name = \"\"\n", + "bucket_name = \"phcho-m5-competition-hpo-example\"\n", + "\n", + "# List of stores and product departments\n", + "STORES = [\n", + " \"CA_1\",\n", + " \"CA_2\",\n", + " \"CA_3\",\n", + " \"CA_4\",\n", + " \"TX_1\",\n", + " \"TX_2\",\n", + " \"TX_3\",\n", + " \"WI_1\",\n", + " \"WI_2\",\n", + " \"WI_3\",\n", + "]\n", + "DEPTS = [\n", + " \"HOBBIES_1\",\n", + " \"HOBBIES_2\",\n", + " \"HOUSEHOLD_1\",\n", + " \"HOUSEHOLD_2\",\n", + " \"FOODS_1\",\n", + " \"FOODS_2\",\n", + " \"FOODS_3\",\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "9b4d2eea-7b13-4249-9b94-2573be07ad84", + "metadata": {}, + "source": [ + "### Define cross-validation folds\n", + "**[Cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics))** is a statistical method for estimating how well a machine learning model generalizes to an independent data set. The method is also useful for evaluating the choice of a given combination of model hyperparameters.\n", + "\n", + "To estimate the capacity to generalize, we define multiple cross-validation **folds** consisting of mulitple pairs of `(training set, validation set)`. For each fold, we fit a model using the training set and evaluate its accuracy on the validation set. The \"goodness\" score for a given hyperparameter combination is the accuracy of the model on each validation set, averaged over all cross-validation folds.\n", + "\n", + "Great care must be taken when defining cross-validation folds for time-series data. We are not allowed to use the future to predict the past, so the training set must precede (in time) the validation set. Consequently, we partition the data set in the time dimension and assign the training and validation sets using time ranges:" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "525c26ce-df6d-42a1-90c4-2b3369105728", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Cross-validation folds and held-out test set (in time dimension)\n", + "# The held-out test set is used for final evaluation\n", + "cv_folds = [ # (train_set, validation_set)\n", + " ([0, 1114], [1114, 1314]),\n", + " ([0, 1314], [1314, 1514]),\n", + " ([0, 1514], [1514, 1714]),\n", + " ([0, 1714], [1714, 1914]),\n", + "]\n", + "n_folds = len(cv_folds)\n", + "holdout = [1914, 1942]\n", + "time_horizon = 1942" + ] + }, + { + "cell_type": "markdown", + "id": "50cd8aea-18bd-4005-be43-6906e8b31de2", + "metadata": {}, + "source": [ + "It is helpful to visualize the cross-validation folds using Matplotlib." + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "c0dddf7f-bc70-4d12-a108-9b1ecbb784ba", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxQAAAEiCAYAAABgP5QIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA2+UlEQVR4nO3deXxNd+L/8fdNIhHZLZEgEXtC7VvplFBLtIyOztDWIEU6Ru10MLV201W1X0Nnphp0fMuYajs/2k6lCJrY0kaVSDGxR9MaEoRE5Pz+8M0dV7ab4yY35PV8PO7j4Z7zOefzOZ987nXf95xzPxbDMAwBAAAAgAkuzm4AAAAAgLsXgQIAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJjm5uwG3A3y8/N19uxZ+fj4yGKxOLs5AAAAgMMYhqFLly6pXr16cnEp+/kGAoUdzp49q5CQEGc3AwAAACg3p06dUoMGDcq8HYHCDj4+PpJudrKvr6+TWwMAAAA4TlZWlkJCQqyfecuKQGGHgsucfH19CRQAAAC4J5m9tJ+bsgEAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKa5ObsBKF1K4jTl5V5wdjMAAHZYH1dXksXZzXA6X19fzXh2trObAaACcIbiLkCYAIC7CWFCkrKyspzdBAAVhEABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMq9SBIjIyUlOmTCmxTFhYmJYsWVIh7QEAAABgq1wDRXR0tCwWS6HH0aNHy7PaQj766CO1bNlSHh4eatmypT7++OMKrR8AAAC4V5X7GYqoqCilp6fbPBo1alTe1VolJiZq2LBhGjFihPbv368RI0Zo6NCh2r17d4W1AQAAALhXlXug8PDwUFBQkM3D1dVVkhQfH68uXbrIw8NDwcHBmjVrlvLy8ordV0ZGhgYNGiRPT081atRIa9asKbX+JUuWqG/fvpo9e7bCw8M1e/ZsPfTQQ1wmBQAAADiA0+6hOHPmjB5++GF17txZ+/fv1/Lly7VixQq9+OKLxW4THR2t48ePa8uWLfrHP/6hZcuWKSMjo8R6EhMT1a9fP5tl/fv3V0JCQrHb5OTkKCsry+YBAAAAoDC38q5g48aN8vb2tj4fMGCA1q9fr2XLlikkJERLly6VxWJReHi4zp49q5kzZ2revHlycbHNOj/88IM+//xz7dq1S127dpUkrVixQhERESXWf+7cOdWtW9dmWd26dXXu3Llit1m0aJEWLlxY1kMFAAAAqpxyDxS9evXS8uXLrc+9vLwkSSkpKerWrZssFot13QMPPKDLly/r9OnTCg0NtdlPSkqK3Nzc1KlTJ+uy8PBw+fv7l9qGW+uQJMMwCi271ezZszVt2jTr86ysLIWEhJRaDwAAAFDVlHug8PLyUtOmTQstL+pDvWEYkgoHgNLWlSQoKKjQ2YiMjIxCZy1u5eHhIQ8PjzLVAwAAAFRFTruHomXLlkpISLAGBUlKSEiQj4+P6tevX6h8RESE8vLytG/fPuuy1NRUXbx4scR6unXrps2bN9ss+/LLL9W9e/c7OwAAAAAAzgsU48eP16lTpzRx4kQdPnxYn376qebPn69p06YVun9Cklq0aKGoqCjFxMRo9+7dSkpK0tixY+Xp6VliPZMnT9aXX36pV199VYcPH9arr76quLi4UifMAwAAAFA6pwWK+vXr67PPPtOePXvUtm1bjRs3TmPGjNGcOXOK3SY2NlYhISHq2bOnhgwZoqefflqBgYEl1tO9e3etXbtWsbGxatOmjVauXKl169ZZb+wGAAAAYJ7FuPWaIxQpKytLfn5+yszMlK+vb4XXfyD+qQqvEwBgzvq4IGc3odJ4/oVFzm4CADvc6Wddp52hAAAAAHD3I1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUBxF3BzD3B2EwAAduPX2CU55WfWATiHm7MbgNJFdFvs7CYAAOzUuqezWwAAFYszFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMc3N2A1C6lMRpysu94OxmAABgl4076uhqjoski7Ob4lQWi0ULn3/Z2c0Ayh2B4i5AmAAA3E2u5rg6uwmVgmEYzm4CUCG45AkAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhWqQNFZGSkpkyZUmKZsLAwLVmypELaAwAAAMBWuQaK6OhoWSyWQo+jR4+WZ7U2Dh48qMcee0xhYWGyWCyEDwAAAMCByv0MRVRUlNLT020ejRo1Ku9qrbKzs9W4cWO98sorCgoKqrB6AQAAgKqg3AOFh4eHgoKCbB6urq6SpPj4eHXp0kUeHh4KDg7WrFmzlJeXV+y+MjIyNGjQIHl6eqpRo0Zas2ZNqfV37txZr7/+uh5//HF5eHg47LgAAAAASG7OqvjMmTN6+OGHFR0drdWrV+vw4cOKiYlR9erVtWDBgiK3iY6O1qlTp7Rlyxa5u7tr0qRJysjIcHjbcnJylJOTY32elZXl8DoAAACAe0G5B4qNGzfK29vb+nzAgAFav369li1bppCQEC1dulQWi0Xh4eE6e/asZs6cqXnz5snFxfbkyQ8//KDPP/9cu3btUteuXSVJK1asUEREhMPbvGjRIi1cuNDh+wUAAADuNeUeKHr16qXly5dbn3t5eUmSUlJS1K1bN1ksFuu6Bx54QJcvX9bp06cVGhpqs5+UlBS5ubmpU6dO1mXh4eHy9/d3eJtnz56tadOmWZ9nZWUpJCTE4fUAAAAAd7tyDxReXl5q2rRpoeWGYdiEiYJlkgotL22do3l4eHC/BQAAAGAHp81D0bJlSyUkJFiDgiQlJCTIx8dH9evXL1Q+IiJCeXl52rdvn3VZamqqLl68WBHNBQAAAFAEpwWK8ePH69SpU5o4caIOHz6sTz/9VPPnz9e0adMK3T8hSS1atFBUVJRiYmK0e/duJSUlaezYsfL09CyxntzcXCUnJys5OVm5ubk6c+aMkpOTK3QuDAAAAOBe5bRAUb9+fX322Wfas2eP2rZtq3HjxmnMmDGaM2dOsdvExsYqJCREPXv21JAhQ/T0008rMDCwxHrOnj2r9u3bq3379kpPT9cbb7yh9u3ba+zYsY4+JAAAAKDKsRi3XnOEImVlZcnPz0+ZmZny9fWt8PoPxD9V4XUCAGDW+jgmki3w/AuLnN0EoFR3+lnXaWcoAAAAANz9CBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1DcBdzcA5zdBAAA7ObpcUMSv0pvsVic3QSgQrg5uwEoXUS3xc5uAgAAdmvd09ktAFCROEMBAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQ3ZzcApUtJnKa83AvObgYAACiD9XF1JVmc3Qyn8/X11YxnZzu7GShHnKG4CxAmAAC4GxEmJCkrK8vZTUA5I1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTKnWgiIyM1JQpU0osExYWpiVLllRIewAAAADYKtdAER0dLYvFUuhx9OjR8qzWxl//+lc9+OCDCggIUEBAgPr06aM9e/ZUWP0AAADAvazcz1BERUUpPT3d5tGoUaPyrtZq27ZteuKJJ7R161YlJiYqNDRU/fr105kzZyqsDQAAAMC9qtwDhYeHh4KCgmwerq6ukqT4+Hh16dJFHh4eCg4O1qxZs5SXl1fsvjIyMjRo0CB5enqqUaNGWrNmTan1r1mzRuPHj1e7du0UHh6uv/71r8rPz9dXX33lsGMEAAAAqio3Z1V85swZPfzww4qOjtbq1at1+PBhxcTEqHr16lqwYEGR20RHR+vUqVPasmWL3N3dNWnSJGVkZJSp3uzsbF2/fl01a9YstkxOTo5ycnKsz7OysspUBwAAAFBVlHug2Lhxo7y9va3PBwwYoPXr12vZsmUKCQnR0qVLZbFYFB4errNnz2rmzJmaN2+eXFxsT5788MMP+vzzz7Vr1y517dpVkrRixQpFRESUqT2zZs1S/fr11adPn2LLLFq0SAsXLizTfgEAAICqqNwDRa9evbR8+XLrcy8vL0lSSkqKunXrJovFYl33wAMP6PLlyzp9+rRCQ0Nt9pOSkiI3Nzd16tTJuiw8PFz+/v52t+W1117Thx9+qG3btql69erFlps9e7amTZtmfZ6VlaWQkBC76wEAAACqinIPFF5eXmratGmh5YZh2ISJgmWSCi0vbZ093njjDb388suKi4tTmzZtSizr4eEhDw8PU/UAAAAAVYnT5qFo2bKlEhISrEFBkhISEuTj46P69esXKh8REaG8vDzt27fPuiw1NVUXL14sta7XX39dL7zwgr744gubMxwAAAAA7ozTAsX48eN16tQpTZw4UYcPH9ann36q+fPna9q0aYXun5CkFi1aKCoqSjExMdq9e7eSkpI0duxYeXp6lljPa6+9pjlz5uj9999XWFiYzp07p3Pnzuny5cvldWgAAABAleG0QFG/fn199tln2rNnj9q2batx48ZpzJgxmjNnTrHbxMbGKiQkRD179tSQIUP09NNPKzAwsMR6li1bptzcXP36179WcHCw9fHGG284+pAAAACAKsdi3HrNEYqUlZUlPz8/ZWZmytfXt8LrPxD/VIXXCQAA7sz6uCBnN6HSeP6FRc5uAkpwp591nXaGAgAAAMDdj0ABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANALFXcDNPcDZTQAAAGXGL/NLcspP7qNiuTm7AShdRLfFzm4CAAAoo9Y9nd0CoGJwhgIAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaW7ObgBKl5I4TXm5F5zdDAAAgDLZuKOOrua4SLI4uylOZbFYtPD5l53djHJDoLgLECYAAMDd6GqOq7ObUCkYhuHsJpQrLnkCAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmVepAERkZqSlTppRYJiwsTEuWLKmQ9gAAAACwVa6BIjo6WhaLpdDj6NGj5VmtjQ0bNqhTp07y9/eXl5eX2rVrpw8++KDC6gcAAADuZeU+U3ZUVJRiY2NtltWpU6e8q7WqWbOmnnvuOYWHh8vd3V0bN27UU089pcDAQPXv37/C2gEAAADci8r9kicPDw8FBQXZPFxdb07DHh8fry5dusjDw0PBwcGaNWuW8vLyit1XRkaGBg0aJE9PTzVq1Ehr1qwptf7IyEj96le/UkREhJo0aaLJkyerTZs22rlzp8OOEQAAAKiqyv0MRXHOnDmjhx9+WNHR0Vq9erUOHz6smJgYVa9eXQsWLChym+joaJ06dUpbtmyRu7u7Jk2apIyMDLvrNAxDW7ZsUWpqql599dViy+Xk5CgnJ8f6PCsry+46AAAAgKqk3APFxo0b5e3tbX0+YMAArV+/XsuWLVNISIiWLl0qi8Wi8PBwnT17VjNnztS8efPk4mJ78uSHH37Q559/rl27dqlr166SpBUrVigiIqLUNmRmZqp+/frKycmRq6urli1bpr59+xZbftGiRVq4cKHJIwYAAACqjnIPFL169dLy5cutz728vCRJKSkp6tatmywWi3XdAw88oMuXL+v06dMKDQ212U9KSorc3NzUqVMn67Lw8HD5+/uX2gYfHx8lJyfr8uXL+uqrrzRt2jQ1btxYkZGRRZafPXu2pk2bZn2elZWlkJAQew4XAAAAqFLKPVB4eXmpadOmhZYbhmETJgqWSSq0vLR1pXFxcbG2oV27dkpJSdGiRYuKDRQeHh7y8PAocz0AAABAVeO0eShatmyphIQEa1CQpISEBPn4+Kh+/fqFykdERCgvL0/79u2zLktNTdXFixfLXLdhGDb3SAAAAAAwx2mBYvz48Tp16pQmTpyow4cP69NPP9X8+fM1bdq0QvdPSFKLFi0UFRWlmJgY7d69W0lJSRo7dqw8PT1LrGfRokXavHmz/v3vf+vw4cNavHixVq9erd/+9rfldWgAAABAleG0X3mqX7++PvvsMz377LNq27atatasqTFjxmjOnDnFbhMbG6uxY8eqZ8+eqlu3rl588UXNnTu3xHquXLmi8ePH6/Tp0/L09FR4eLj+9re/adiwYY4+JAAAAKDKsRi3XnOEImVlZcnPz0+ZmZny9fWt8PoPxD9V4XUCAADcqfVxQc5uQqXx/AuLnN2EYt3pZ12nXfIEAAAA4O5HoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAageIu4OYe4OwmAAAAlJmnxw1JzFBgsVic3YRy5bSJ7WC/iG6Lnd0EAACAMmvd09ktQEXgDAUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA09yc3QAAAADgXta692Tl5xullju47Z0KaI3jcYYCAAAAKEf2hIm7GYECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKYRKAAAAACY5pRAERYWpiVLlpRYxmKx6JNPPqmQ9gAAAAAwp0yBIjo6Wo8++mih5du2bZPFYtHFixcd1KzyYU+QAQAAAGA/N2c34F5hGIby8vJ048YNZzcFAIB7RrVq1eTq6ursZgAoQbkEio8++kjz5s3T0aNHFRwcrIkTJ2r69OnFlj9y5IjGjBmjPXv2qHHjxnr77bcLlTlw4IAmT56sxMRE1ahRQ4899pgWL14sb29vSVJkZKTatWtncwbi0Ucflb+/v1auXKnIyEidOHFCU6dO1dSpUyXdDAGOkJubq/T0dGVnZztkfwAA4CaLxaIGDRpY/78HUPk4PFAkJSVp6NChWrBggYYNG6aEhASNHz9etWrVUnR0dKHy+fn5GjJkiGrXrq1du3YpKytLU6ZMsSmTnZ2tqKgo3X///dq7d68yMjI0duxYTZgwQStXrrSrXRs2bFDbtm319NNPKyYmpsSyOTk5ysnJsT7Pysoqtmx+fr7S0tLk6uqqevXqyd3dXRaLxa42AQCA4hmGoZ9++kmnT59Ws2bNOFMBVFJlDhQbN24s9C3BrZf5LF68WA899JDmzp0rSWrevLkOHTqk119/vchAERcXp5SUFB0/flwNGjSQJL388ssaMGCAtcyaNWt09epVrV69Wl5eXpKkpUuXatCgQXr11VdVt27dUttds2ZNubq6ysfHR0FBQSWWXbRokRYuXFjqPqWbZyfy8/MVEhKiGjVq2LUNAACwT506dXT8+HFdv36dQAFUUmX+ladevXopOTnZ5vHee+9Z16ekpOiBBx6w2eaBBx7QkSNHiry/ICUlRaGhodYwIUndunUrVKZt27bWMFGwz/z8fKWmppb1EEo1e/ZsZWZmWh+nTp0qdRsXF36BFwAAR+OsP1D5lfkMhZeXl5o2bWqz7PTp09Z/G4ZR6MVf0r0KRa0ravvi3lAKlru4uBTa1/Xr14uttyQeHh7y8PAwtS0AAABQlTj8a/WWLVtq586dNssSEhLUvHnzIk9VtmzZUidPntTZs2etyxITEwuVSU5O1pUrV6zLvv76a7m4uKh58+aSbp4STU9Pt66/ceOGvv/+e5v9uLu78ytM5SgyMrLQ/S8lOX78uCwWi5KTk8utTbh73D5+KnK+Gua9wa2YKwkAysbhN2VPnz5dnTt31gsvvKBhw4YpMTFRS5cu1bJly4os36dPH7Vo0UIjR47Um2++qaysLD333HM2ZYYPH6758+dr1KhRWrBggX766SdNnDhRI0aMsN4/0bt3b02bNk2bNm1SkyZN9NZbbxWaFyMsLEzbt2/X448/Lg8PD9WuXdvRh28j99p53bh+qVzrKOBazUfu1WvZVba008ejRo2y+2b3W23YsEHVqlWzu3xISIjS09PL/e9wp6Kjo3Xx4kWnfni4ePGisrOvlF7QQWrU8JK/v79dZQcNGqSrV68qLi6u0LrExER1795dSUlJ6tChQ5nasHfvXpvLHB1hwYIF+uSTTwqF2PT0dAUEBDi0LkcLCwvTlClTyhTaHensj//RxcyKG4P+fl6qV7em3eWLe51u27ZNvXr10oULF+we05VJUb9gWJTyGB/21n0n7va/D4CbHB4oOnTooL///e+aN2+eXnjhBQUHB+v5558v8oZs6ealSh9//LHGjBmjLl26KCwsTO+8846ioqKsZWrUqKF//etfmjx5sjp37mzzs7EFRo8erf3792vkyJFyc3PT1KlT1atXL5u6nn/+ef3ud79TkyZNlJOT47CfjS1K7rXz+mHvbBn55i67KiuLSzU177zIrlBx65mcdevWad68eTb3onh6etqUv379ul1BoWZN+//zlyRXV9dSb5DHzTDxzttvKi8vr8LqdHNz06TJ0+36D37MmDEaMmSITpw4oYYNG9qse//999WuXbsyhwnp5lnHisI4LNnZH/+jR0a8qNzcihuD7u5u2vTBnDKFCgCAc5TpkqeVK1cW+S1tZGSkDMOwfvh47LHHdPDgQeXm5urEiROaMWOGTfnjx4/bfIvSvHlz7dixQzk5OUpNTVX//v1lGIbNrNytW7fWli1bdPXqVZ0/f15/+ctfbH5tqlq1alq2bJnOnz+vH3/8UbNmzdInn3xi8037/fffr/379+vatWvlGiYk6cb1SxUWJiTJyL9u99mQoKAg68PPz08Wi8X6/Nq1a/L399ff//53RUZGqnr16vrb3/6m8+fP64knnlCDBg1Uo0YNtW7dWh9++KHNfou6ZOXll1/W6NGj5ePjo9DQUP3lL3+xrr/9kqeCGde/+uorderUSTVq1FD37t0L3Xj/4osvKjAwUD4+Pho7dqxmzZqldu3aFXu8Fy5c0PDhw1WnTh15enqqWbNmio2Nta4/c+aMhg0bpoCAANWqVUuDBw/W8ePHJd38RnvVqlX69NNPZbFYZLFYtG3bNrv62VGys69UaJiQpLy8PLvPiAwcOFCBgYGFzmplZ2dr3bp1GjNmjF3j53a3X3Zy5MgR9ejRQ9WrV1fLli21efPmQtvMnDlTzZs3V40aNdS4cWPNnTvXei/VypUrtXDhQu3fv9/6tyxo8+2Xrxw4cEC9e/eWp6enatWqpaefflqXL1+2ro+Ojtajjz6qN954Q8HBwapVq5aeeeaZEu/b2r9/v3r16iUfHx/5+vqqY8eO2rdvn3V9QkKCevToIU9PT4WEhGjSpEnWyzxvnUenoO0V6WLmlQoNE5KUm5tXbmdESurrotgz9opy8uRJDR48WN7e3vL19dXQoUP1448/WtcXjKNbTZkyRZGRkdb18fHxevvtt61/94L3pluVND5KO9Zly5apWbNmql69uurWratf//rXZaq7pH1IN++BfO2119S4cWN5enqqbdu2+sc//iHp5v8BBV/8BQQEyGKxFPvlI4DKjZ8mQpFmzpypSZMmKSUlRf3799e1a9fUsWNHbdy4Ud9//72efvppjRgxQrt37y5xP2+++aY6deqkb7/9VuPHj9fvf/97HT58uMRtnnvuOb355pvat2+f3NzcNHr0aOu6NWvW6KWXXtKrr76qpKQkhYaGavny5SXub+7cuTp06JA+//xzpaSkaPny5dbLrLKzs9WrVy95e3tr+/bt2rlzp7y9vRUVFaXc3FzNmDFDQ4cOVVRUlNLT05Wenq7u3bvb2YtVg5ubm0aOHKmVK1faBPX169crNzdXw4cPNz1+ChTMV+Pq6qpdu3bp3Xff1cyZMwuV8/Hx0cqVK3Xo0CG9/fbb+utf/6q33npLkjRs2DBNnz5drVq1sv4thw0bVmgfBfPeBAQEaO/evVq/fr3i4uI0YcIEm3Jbt27VsWPHtHXrVq1atUorV64s8VLB4cOHq0GDBtq7d6+SkpI0a9Ys65m/AwcOqH///hoyZIi+++47rVu3Tjt37rTWuWHDBjVo0EDPP/+8te0wp7S+vp29Y+92BV+K/ec//1F8fLw2b96sY8eOFTnmivP222+rW7duiomJsf7dQ0JCCpUrbnyUdqz79u3TpEmT9Pzzzys1NVVffPGFevToUaa6S9qHJM2ZM0exsbFavny5Dh48qKlTp+q3v/2t4uPjFRISoo8++kiSlJqaqvT09CIntgVQ+ZXLTNm4+02ZMkVDhgyxWXbrmaaJEyfqiy++0Pr169W1a9di9/Pwww9r/Pjxkm6GlLfeekvbtm1TeHh4sdu89NJL6tmzpyRp1qxZeuSRR3Tt2jVVr15d//M//6MxY8boqaeekiTNmzdPX375pc23x7c7efKk2rdvr06dOkm6+c13gbVr18rFxUXvvfee9Vu92NhY+fv7a9u2berXr588PT2Vk5PDZTElGD16tF5//XXr9dDSzcudhgwZooCAAAUEBJgaPwXsma9GuvnhpUBYWJimT5+udevW6Q9/+IM8PT3l7e0tNze3Ev+W9s57ExAQoKVLl8rV1VXh4eF65JFH9NVXXxU7cebJkyf17LPPWsd+s2bNrOtef/11Pfnkk9YzfM2aNdM777yjnj17avny5WWaR6cqK22eJKn0vq5evbpNeXvH3u3i4uL03XffKS0tzfpB/IMPPlCrVq20d+9ede7cudTj8fPzk7u7u2rUqFHi37248VHasZ48eVJeXl4aOHCgfHx81LBhQ7Vv375MdZe0jytXrmjx4sXasmWL9efgGzdurJ07d+rPf/6zevbsab1UNjAwkHsogLsYZyhQpIIP3wVu3Lihl156SW3atFGtWrXk7e2tL7/8UidPnixxP23atLH+u+DSqoyMDLu3CQ4OliTrNqmpqerSpYtN+duf3+73v/+91q5dq3bt2ukPf/iDEhISrOuSkpJ09OhR+fj4yNvbW97e3qpZs6auXbumY8eOlbhf/Fd4eLi6d++u999/X5J07Ngx7dixw3p2yez4KWDPfDWS9I9//EO/+MUvFBQUJG9vb82dO9fuOm6ty555b1q1amXzy3XBwcElju1p06Zp7Nix6tOnj1555RWb8ZWUlKSVK1dax6C3t7f69++v/Px8paWllan9VVlp8yRJZe9re8begAEDrPtq1aqVdbuQkBCbb/Vbtmwpf39/paSkOPKwi1Xasfbt21cNGzZU48aNNWLECK1Zs0bZ2dllqqOkfRw6dEjXrl1T3759bdqwevVq3l+BewxnKFCk239d580339Rbb72lJUuWqHXr1vLy8tKUKVOUm5tb4n5uv5nbYrEoPz/f7m0Kzhrcuk1Z5jmRbv5nf+LECW3atElxcXF66KGH9Mwzz+iNN95Qfn6+OnbsqDVr1hTariJvCr4XjBkzRhMmTNCf/vQnxcbGqmHDhnrooYckmR8/BeyZr2bXrl16/PHHtXDhQvXv319+fn5au3at3nzzzTIdhz3z3khlH9sLFizQk08+qU2bNunzzz/X/PnztXbtWv3qV79Sfn6+fve732nSpEmFtgsNDS1T+6uy0uZJklTmvrZn7L333nu6evWqpP+Oi+LG0a3LHTl/UlFKO1Z3d3d988032rZtm7788kvNmzdPCxYs0N69e+0+W+Dj41PsPgpeD5s2bVL9+vVttmOuJ+DeQqCAXXbs2KHBgwfrt7/9raSb/1EdOXJEERERFdqOFi1aaM+ePRoxYoR12a03thanTp06io6OVnR0tB588EE9++yzeuONN9ShQwetW7dOgYGB8vX1LXJb5i+xz9ChQzV58mT97//+r1atWqWYmBjrB6c7HT+3zldTr149SYXnq/n666/VsGFDm5+dPnHihE0Ze/6WLVu21KpVq3TlyhVrsL593huzmjdvrubNm2vq1Kl64oknFBsbq1/96lfq0KGDDh48WOjDcFnbjtLZ09e3smfs3f5h+dbtTp06ZT1LcejQIWVmZlrHfZ06dQrNl5ScnGwTVu39uxdVzp5jdXNzU58+fdSnTx/Nnz9f/v7+2rJli4YMGWJ33cXto2/fvvLw8NDJkyetl7EW1W6p8KVpAO4uXPIEuzRt2lSbN29WQkKCUlJS9Lvf/U7nzp2r8HZMnDhRK1as0KpVq3TkyBG9+OKL+u6770r81Zt58+bp008/1dGjR3Xw4EFt3LjR+h/68OHDVbt2bQ0ePFg7duxQWlqa4uPjNXnyZOs3m2FhYfruu++Umpqqn3/+2aHfIN5LvL29NWzYMP3xj3/U2bNnbX6t5U7Hz63z1ezfv187duwoNF9N06ZNdfLkSa1du1bHjh3TO++8o48//timTFhYmNLS0pScnKyff/5ZOTk5heoaPny4qlevrlGjRun777/X1q1bC817U1ZXr17VhAkTtG3bNp04cUJff/219u7dax2HM2fOVGJiop555hklJyfryJEj+uc//6mJEyfatH379u06c+aMfv75Z1PtgH19fSt7xl5x27Vp00bDhw/XN998oz179mjkyJHq2bOn9ZLS3r17a9++fVq9erWOHDmi+fPnFwoYYWFh2r17t44fP66ff/652LNgRY2P0o5148aNeuedd5ScnKwTJ05o9erVys/PV4sWLeyuu6R9+Pj4aMaMGZo6dapWrVqlY8eO6dtvv9Wf/vQnrVq1SpLUsGFDWSwWbdy4UT/99FOJ98MBqLwIFLDL3Llz1aFDB/Xv31+RkZEKCgoq9HOHFWH48OGaPXu2ZsyYoQ4dOigtLU3R0dGFbqS8lbu7u2bPnq02bdqoR48ecnV11dq1ayXdnONk+/btCg0N1ZAhQxQREaHRo0fr6tWr1jMWMTExatGihTp16qQ6dero66+/rpBjvRuNGTNGFy5cUJ8+fWwuH7nT8VMwX01OTo66dOmisWPH6qWXXrIpM3jwYE2dOlUTJkxQu3btlJCQoLlz59qUeeyxxxQVFaVevXqpTp06Rf50bcG8N//5z3/UuXNn/frXv9ZDDz2kpUuXlq0zbuHq6qrz589r5MiRat68uYYOHaoBAwZo4cKFkm7eNxQfH68jR47owQcfVPv27TV37lzrPUTSzXl0jh8/riZNmnA53h2wp69vZc/YK0rBTxEHBASoR48e6tOnjxo3bqx169ZZy/Tv319z587VH/7wB3Xu3FmXLl3SyJEjbfYzY8YMubq6qmXLlqpTp06x9wQVNT5KO1Z/f39t2LBBvXv3VkREhN599119+OGH1vtA7Km7tH288MILmjdvnhYtWqSIiAj1799f/+///T81atRI0s2zOwsXLtSsWbNUt27dYn9tC7jbubhU7M99VzSLUd4TMtwDsrKy5Ofnp8zMzEKXxVy7dk1paWlq1KiRzYfayjyx3b2mb9++CgoK0gcffODsppSLyj6xHe59TGwHZyru/1kAjlPSZ117cA9FOXGvXkvNOy+ye7K5O+VazadKhIns7Gy9++676t+/v1xdXfXhhx8qLi7O7omm7kb+/v6aNHm63RPNOUKNGl6ECVjVq1tTmz6YU24TzRXF38+LMAEAdwkCRTlyr15LqgIf8iuSxWLRZ599phdffFE5OTlq0aKFPvroI/Xp08fZTStX/v7+fMCHU9WrW5MP+ACAIhEocFfx9PRUXFycs5sBAACA/8NN2QAAAABMI1AAAAAAMI1A4SD8WBYAAI7H/69A5UeguEMFM5pmZ2c7uSUAANx7cnNzJd2czwVA5cRN2XfI1dVV/v7+ysjIkHRzUqySZm0GAAD2yc/P108//aQaNWrIzY2PLEBlxavTAYKCgiTJGioAAIBjuLi4KDQ0lC/rgEqMQOEAFotFwcHBCgwM1PXrFTMzNgAAVYG7u7tcXLhCG6jMCBQO5OrqyjWeAAAAqFKI/AAAAABMI1AAAAAAMI1AAQAAAMA07qGwQ8GkOllZWU5uCQAAAOBYBZ9xzU4kSaCww6VLlyRJISEhTm4JAAAAUD4uXbokPz+/Mm9nMZjTvlT5+fk6e/asfHx8nPI72FlZWQoJCdGpU6fk6+tb4fXfK+hHx6Af7xx96Bj0o2PQj45BPzoG/egYZe1HwzB06dIl1atXz9TPNHOGwg4uLi5q0KCBs5shX19fXlwOQD86Bv145+hDx6AfHYN+dAz60THoR8coSz+aOTNRgJuyAQAAAJhGoAAAAABgGoHiLuDh4aH58+fLw8PD2U25q9GPjkE/3jn60DHoR8egHx2DfnQM+tExKrofuSkbAAAAgGmcoQAAAABgGoECAAAAgGkECgAAAACmESgquWXLlqlRo0aqXr26OnbsqB07dji7SZXGokWL1LlzZ/n4+CgwMFCPPvqoUlNTbcpER0fLYrHYPO6//36bMjk5OZo4caJq164tLy8v/fKXv9Tp06cr8lCcasGCBYX6KCgoyLreMAwtWLBA9erVk6enpyIjI3Xw4EGbfVT1PpSksLCwQv1osVj0zDPPSGIsFmf79u0aNGiQ6tWrJ4vFok8++cRmvaPG34ULFzRixAj5+fnJz89PI0aM0MWLF8v56CpOSf14/fp1zZw5U61bt5aXl5fq1aunkSNH6uzZszb7iIyMLDRGH3/8cZsyVbkfJce9jqt6Pxb1XmmxWPT6669by1T18WjPZ5zK9P5IoKjE1q1bpylTpui5557Tt99+qwcffFADBgzQyZMnnd20SiE+Pl7PPPOMdu3apc2bNysvL0/9+vXTlStXbMpFRUUpPT3d+vjss89s1k+ZMkUff/yx1q5dq507d+ry5csaOHCgbty4UZGH41StWrWy6aMDBw5Y17322mtavHixli5dqr179yooKEh9+/bVpUuXrGXoQ2nv3r02fbh582ZJ0m9+8xtrGcZiYVeuXFHbtm21dOnSItc7avw9+eSTSk5O1hdffKEvvvhCycnJGjFiRLkfX0UpqR+zs7P1zTffaO7cufrmm2+0YcMG/fDDD/rlL39ZqGxMTIzNGP3zn/9ss74q92MBR7yOq3o/3tp/6enpev/992WxWPTYY4/ZlKvK49GezziV6v3RQKXVpUsXY9y4cTbLwsPDjVmzZjmpRZVbRkaGIcmIj4+3Lhs1apQxePDgYre5ePGiUa1aNWPt2rXWZWfOnDFcXFyML774ojybW2nMnz/faNu2bZHr8vPzjaCgIOOVV16xLrt27Zrh5+dnvPvuu4Zh0IfFmTx5stGkSRMjPz/fMAzGoj0kGR9//LH1uaPG36FDhwxJxq5du6xlEhMTDUnG4cOHy/moKt7t/ViUPXv2GJKMEydOWJf17NnTmDx5crHb0I+OeR3Tj4UNHjzY6N27t80yxqOt2z/jVLb3R85QVFK5ublKSkpSv379bJb369dPCQkJTmpV5ZaZmSlJqlmzps3ybdu2KTAwUM2bN1dMTIwyMjKs65KSknT9+nWbfq5Xr57uu+++KtXPR44cUb169dSoUSM9/vjj+ve//y1JSktL07lz52z6x8PDQz179rT2D31YWG5urv72t79p9OjRslgs1uWMxbJx1PhLTEyUn5+funbtai1z//33y8/Pr8r2bWZmpiwWi/z9/W2Wr1mzRrVr11arVq00Y8YMm2866ceb7vR1TD/a+vHHH7Vp0yaNGTOm0DrG43/d/hmnsr0/upk/NJSnn3/+WTdu3FDdunVtltetW1fnzp1zUqsqL8MwNG3aNP3iF7/QfffdZ10+YMAA/eY3v1HDhg2VlpamuXPnqnfv3kpKSpKHh4fOnTsnd3d3BQQE2OyvKvVz165dtXr1ajVv3lw//vijXnzxRXXv3l0HDx609kFR4/DEiROSRB8W4ZNPPtHFixcVHR1tXcZYLDtHjb9z584pMDCw0P4DAwOrZN9eu3ZNs2bN0pNPPilfX1/r8uHDh6tRo0YKCgrS999/r9mzZ2v//v3Wy/foR8e8julHW6tWrZKPj4+GDBlis5zx+F9FfcapbO+PBIpK7tZvN6Wbg+r2ZZAmTJig7777Tjt37rRZPmzYMOu/77vvPnXq1EkNGzbUpk2bCr153aoq9fOAAQOs/27durW6deumJk2aaNWqVdabDc2Mw6rUh7dbsWKFBgwYoHr16lmXMRbNc8T4K6p8Vezb69ev6/HHH1d+fr6WLVtmsy4mJsb67/vuu0/NmjVTp06d9M0336hDhw6S6EdHvY6rej/e6v3339fw4cNVvXp1m+WMx/8q7jOOVHneH7nkqZKqXbu2XF1dC6XDjIyMQmm0qps4caL++c9/auvWrWrQoEGJZYODg9WwYUMdOXJEkhQUFKTc3FxduHDBplxV7mcvLy+1bt1aR44csf7aU0njkD60deLECcXFxWns2LEllmMsls5R4y8oKEg//vhjof3/9NNPVapvr1+/rqFDhyotLU2bN2+2OTtRlA4dOqhatWo2Y5R+tGXmdUw//teOHTuUmppa6vulVHXHY3GfcSrb+yOBopJyd3dXx44draf2CmzevFndu3d3UqsqF8MwNGHCBG3YsEFbtmxRo0aNSt3m/PnzOnXqlIKDgyVJHTt2VLVq1Wz6OT09Xd9//32V7eecnBylpKQoODjYerr51v7Jzc1VfHy8tX/oQ1uxsbEKDAzUI488UmI5xmLpHDX+unXrpszMTO3Zs8daZvfu3crMzKwyfVsQJo4cOaK4uDjVqlWr1G0OHjyo69evW8co/ViYmdcx/fhfK1asUMeOHdW2bdtSy1a18VjaZ5xK9/5o//3lqGhr1641qlWrZqxYscI4dOiQMWXKFMPLy8s4fvy4s5tWKfz+9783/Pz8jG3bthnp6enWR3Z2tmEYhnHp0iVj+vTpRkJCgpGWlmZs3brV6Natm1G/fn0jKyvLup9x48YZDRo0MOLi4oxvvvnG6N27t9G2bVsjLy/PWYdWoaZPn25s27bN+Pe//23s2rXLGDhwoOHj42MdZ6+88orh5+dnbNiwwThw4IDxxBNPGMHBwfRhEW7cuGGEhoYaM2fOtFnOWCzepUuXjG+//db49ttvDUnG4sWLjW+//db660OOGn9RUVFGmzZtjMTERCMxMdFo3bq1MXDgwAo/3vJSUj9ev37d+OUvf2k0aNDASE5Otnm/zMnJMQzDMI4ePWosXLjQ2Lt3r5GWlmZs2rTJCA8PN9q3b08//l8/OvJ1XJX7sUBmZqZRo0YNY/ny5YW2ZzyW/hnHMCrX+yOBopL705/+ZDRs2NBwd3c3OnToYPOTqFWdpCIfsbGxhmEYRnZ2ttGvXz+jTp06RrVq1YzQ0FBj1KhRxsmTJ232c/XqVWPChAlGzZo1DU9PT2PgwIGFytzLhg0bZgQHBxvVqlUz6tWrZwwZMsQ4ePCgdX1+fr4xf/58IygoyPDw8DB69OhhHDhwwGYfVb0PC/zrX/8yJBmpqak2yxmLxdu6dWuRr+NRo0YZhuG48Xf+/Hlj+PDhho+Pj+Hj42MMHz7cuHDhQgUdZfkrqR/T0tKKfb/cunWrYRiGcfLkSaNHjx5GzZo1DXd3d6NJkybGpEmTjPPnz9vUU5X70ZGv46rcjwX+/Oc/G56ensbFixcLbc94LP0zjmFUrvdHy/81GgAAAADKjHsoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKYRKAAATrFgwQK1a9fO2c0AANwhZsoGADicxWIpcf2oUaO0dOlS5eTkqFatWhXUKgBAeSBQAAAc7ty5c9Z/r1u3TvPmzVNqaqp1maenp/z8/JzRNACAg3HJEwDA4YKCgqwPPz8/WSyWQstuv+QpOjpajz76qF5++WXVrVtX/v7+WrhwofLy8vTss8+qZs2aatCggd5//32bus6cOaNhw4YpICBAtWrV0uDBg3X8+PGKPWAAqMIIFACASmPLli06e/astm/frsWLF2vBggUaOHCgAgICtHv3bo0bN07jxo3TqVOnJEnZ2dnq1auXvL29tX37du3cuVPe3t6KiopSbm6uk48GAKoGAgUAoNKoWbOm3nnnHbVo0UKjR49WixYtlJ2drT/+8Y9q1qyZZs+eLXd3d3399deSpLVr18rFxUXvvfeeWrdurYiICMXGxurkyZPatm2bcw8GAKoIN2c3AACAAq1atZKLy3+/66pbt67uu+8+63NXV1fVqlVLGRkZkqSkpCQdPXpUPj4+Nvu5du2ajh07VjGNBoAqjkABAKg0qlWrZvPcYrEUuSw/P1+SlJ+fr44dO2rNmjWF9lWnTp3yaygAwIpAAQC4a3Xo0EHr1q1TYGCgfH19nd0cAKiSuIcCAHDXGj58uGrXrq3Bgwdrx44dSktLU3x8vCZPnqzTp087u3kAUCUQKAAAd60aNWpo+/btCg0N1ZAhQxQREaHRo0fr6tWrnLEAgArCxHYAAAAATOMMBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwLT/D9blFiP5mbuxAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cv_cmap = matplotlib.colormaps[\"cividis\"]\n", + "plt.figure(figsize=(8, 3))\n", + "\n", + "for i, (train_mask, valid_mask) in enumerate(cv_folds):\n", + " idx = np.array([np.nan] * time_horizon)\n", + " idx[np.arange(*train_mask)] = 1\n", + " idx[np.arange(*valid_mask)] = 0\n", + " plt.scatter(\n", + " range(time_horizon),\n", + " [i + 0.5] * time_horizon,\n", + " c=idx,\n", + " marker=\"_\",\n", + " capstyle=\"butt\",\n", + " s=1,\n", + " lw=20,\n", + " cmap=cv_cmap,\n", + " vmin=-1.5,\n", + " vmax=1.5,\n", + " )\n", + "\n", + "idx = np.array([np.nan] * time_horizon)\n", + "idx[np.arange(*holdout)] = -1\n", + "plt.scatter(\n", + " range(time_horizon),\n", + " [n_folds + 0.5] * time_horizon,\n", + " c=idx,\n", + " marker=\"_\",\n", + " capstyle=\"butt\",\n", + " s=1,\n", + " lw=20,\n", + " cmap=cv_cmap,\n", + " vmin=-1.5,\n", + " vmax=1.5,\n", + ")\n", + "\n", + "plt.xlabel(\"Time\")\n", + "plt.yticks(\n", + " ticks=np.arange(n_folds + 1) + 0.5,\n", + " labels=[f\"Fold {i}\" for i in range(n_folds)] + [\"Holdout\"],\n", + ")\n", + "plt.ylim([len(cv_folds) + 1.2, -0.2])\n", + "\n", + "norm = matplotlib.colors.Normalize(vmin=-1.5, vmax=1.5)\n", + "plt.legend(\n", + " [\n", + " Patch(color=cv_cmap(norm(1))),\n", + " Patch(color=cv_cmap(norm(0))),\n", + " Patch(color=cv_cmap(norm(-1))),\n", + " ],\n", + " [\"Training set\", \"Validation set\", \"Held-out test set\"],\n", + " ncol=3,\n", + " loc=\"best\",\n", + ")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "2b5dffdb-01ae-4703-950d-1ee7ea0b77dd", + "metadata": {}, + "source": [ + "### Launch a Dask client on Kubernetes\n", + "Let us set up a Dask cluster using the `KubeCluster` class." + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "580f18c9-1e99-4c43-a244-8db3bfe3c1ab", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9a276ff2156246e0a62c3d40e862ddac", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "cluster = KubeCluster(\n",
+    "    name=\"rapids-dask\",\n",
+    "    image=rapids_image,\n",
+    "    worker_command=\"dask-cuda-worker\",\n",
+    "    n_workers=n_workers,\n",
+    "    resources={\"limits\": {\"nvidia.com/gpu\": \"1\"}},\n",
+    "    env={\"EXTRA_PIP_PACKAGES\": \"optuna gcsfs\"},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 70,
+   "id": "54e8254c-f29e-467d-95b0-f111122deffc",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "faa42202d7cd441d9a82072fa9d25ac5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/html": [
+       "
\n", + "
\n", + "
\n", + "
\n", + "

KubeCluster

\n", + "

rapids-dask

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", + " \n", + " Workers: 0\n", + "
\n", + " Total threads: 0\n", + " \n", + " Total memory: 0 B\n", + "
\n", + "\n", + "
\n", + " \n", + "

Scheduler Info

\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

Scheduler

\n", + "

Scheduler-ca08b554-394b-419d-b380-7b3ccdfd882f

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Comm: tcp://10.36.0.26:8786\n", + " \n", + " Workers: 0\n", + "
\n", + " Dashboard: http://10.36.0.26:8787/status\n", + " \n", + " Total threads: 0\n", + "
\n", + " Started: Just now\n", + " \n", + " Total memory: 0 B\n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "

Workers

\n", + "
\n", + "\n", + " \n", + "\n", + "
\n", + "
\n", + "\n", + "
\n", + "
\n", + "
" + ], + "text/plain": [ + "KubeCluster(rapids-dask, 'tcp://rapids-dask-scheduler.kubeflow-user-example-com:8786', workers=0, threads=0, memory=0 B)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cluster" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "id": "ad8be605-68bc-45d6-962e-dd3b4c5e8658", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
\n", + "

Client

\n", + "

Client-dcc90a0f-5e32-11ee-8529-fa610e3dbb88

\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "
Connection method: Cluster objectCluster type: dask_kubernetes.KubeCluster
\n", + " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", + "
\n", + "\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "

Cluster Info

\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

KubeCluster

\n", + "

rapids-dask

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", + " \n", + " Workers: 1\n", + "
\n", + " Total threads: 1\n", + " \n", + " Total memory: 83.48 GiB\n", + "
\n", + "\n", + "
\n", + " \n", + "

Scheduler Info

\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

Scheduler

\n", + "

Scheduler-ca08b554-394b-419d-b380-7b3ccdfd882f

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Comm: tcp://10.36.0.26:8786\n", + " \n", + " Workers: 1\n", + "
\n", + " Dashboard: http://10.36.0.26:8787/status\n", + " \n", + " Total threads: 1\n", + "
\n", + " Started: Just now\n", + " \n", + " Total memory: 83.48 GiB\n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "

Workers

\n", + "
\n", + "\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: rapids-dask-default-worker-4f5ea8bc10

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://10.36.2.23:34385\n", + " \n", + " Total threads: 1\n", + "
\n", + " Dashboard: http://10.36.2.23:8788/status\n", + " \n", + " Memory: 83.48 GiB\n", + "
\n", + " Nanny: tcp://10.36.2.23:44783\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-h2_cxa0d\n", + "
\n", + " GPU: NVIDIA A100-SXM4-40GB\n", + " \n", + " GPU memory: 40.00 GiB\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 73, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client = Client(cluster)\n", + "client" + ] + }, + { + "cell_type": "markdown", + "id": "fb7527b9-0039-43e5-bb25-985501480313", + "metadata": {}, + "source": [ + "### Define the custom evaluation metric\n", + "\n", + "The M5 forecasting competition defines a custom metric called WRMSSE as follows:\n", + "\n", + "$$\n", + "WRMSSE = \\sum w_i \\cdot RMSSE_i\n", + "$$\n", + "\n", + "i.e. WRMSEE is a weighted sum of RMSSE for all product items $i$. RMSSE is in turn defined to be\n", + "\n", + "$$\n", + "RMSSE = \\sqrt{\\frac{1/h \\cdot \\sum_t{\\left(Y_t - \\hat{Y}_t\\right)}^2}{1/(n-1)\\sum_t{(Y_t - Y_{t-1})}^2}}\n", + "$$\n", + "\n", + "where the squared error of the prediction (forecast) is normalized by the speed at which the sales amount changes per unit in the training data.\n", + "\n", + "Here is the implementation of the WRMSSE using cuDF. We use the product weights $w_i$ as computed in the first preprocessing notebook." + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "a3d7f300-310d-4c09-b30a-4d0c54903d23", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def wrmsse(product_weights, df, pred_sales, train_mask, valid_mask):\n", + " \"\"\"Compute WRMSSE metric\"\"\"\n", + " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", + " df_valid = df[(df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])]\n", + "\n", + " # Compute denominator: 1/(n-1) * sum( (y(t) - y(t-1))**2 )\n", + " diff = (\n", + " df_train.sort_values([\"item_id\", \"day_id\"])\n", + " .groupby([\"item_id\"])[[\"sales\"]]\n", + " .diff(1)\n", + " )\n", + " x = (\n", + " df_train[[\"item_id\", \"day_id\"]]\n", + " .join(diff, how=\"left\")\n", + " .rename(columns={\"sales\": \"diff\"})\n", + " .sort_values([\"item_id\", \"day_id\"])\n", + " )\n", + " x[\"diff\"] = x[\"diff\"] ** 2\n", + " xx = x.groupby([\"item_id\"])[[\"diff\"]].agg([\"sum\", \"count\"]).sort_index()\n", + " xx.columns = xx.columns.map(\"_\".join)\n", + " xx[\"denominator\"] = xx[\"diff_sum\"] / xx[\"diff_count\"]\n", + " t = xx.reset_index()\n", + "\n", + " # Compute numerator: 1/h * sum( (y(t) - y_pred(t))**2 )\n", + " X_valid = df_valid.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " if \"dept_id\" in X_valid.columns:\n", + " X_valid = X_valid.drop(columns=[\"dept_id\"])\n", + " df_pred = cudf.DataFrame(\n", + " {\n", + " \"item_id\": df_valid[\"item_id\"].copy(),\n", + " \"pred_sales\": pred_sales,\n", + " \"sales\": df_valid[\"sales\"].copy(),\n", + " }\n", + " )\n", + " df_pred[\"diff\"] = (df_pred[\"sales\"] - df_pred[\"pred_sales\"]) ** 2\n", + " yy = df_pred.groupby([\"item_id\"])[[\"diff\"]].agg([\"sum\", \"count\"]).sort_index()\n", + " yy.columns = yy.columns.map(\"_\".join)\n", + " yy[\"numerator\"] = yy[\"diff_sum\"] / yy[\"diff_count\"]\n", + "\n", + " zz = yy[[\"numerator\"]].join(xx[[\"denominator\"]], how=\"left\")\n", + " zz = zz.join(product_weights, how=\"left\").sort_index()\n", + " # Filter out zero denominator.\n", + " # This can occur if the product was never on sale during the period in the training set\n", + " zz = zz[zz[\"denominator\"] != 0]\n", + " zz[\"rmsse\"] = np.sqrt(zz[\"numerator\"] / zz[\"denominator\"])\n", + " t = zz[\"rmsse\"].multiply(zz[\"weights\"])\n", + " return zz[\"rmsse\"].multiply(zz[\"weights\"]).sum()" + ] + }, + { + "cell_type": "markdown", + "id": "17ed80bd-b6db-459d-b28a-7cf59ba2bdad", + "metadata": {}, + "source": [ + "### Define the training and hyperparameter search pipeline using Optuna\n", + "Optuna lets us define the training procedure iteratively, i.e. as if we were to write an ordinary function to train a single model. Instead of a fixed hyperparameter combination, the function now takes in a `trial` object which yields different hyperparameter combinations.\n", + "\n", + "In this example, we partition the training data according to the store and then fit a separate XGBoost model per data segment." + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "5f538176-2b69-4224-8570-4c0aedfde4d4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def objective(trial):\n", + " fs = gcsfs.GCSFileSystem()\n", + " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " \"lambda\": trial.suggest_float(\"lambda\", 1e-8, 100.0, log=True),\n", + " \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 100.0, log=True),\n", + " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n", + " \"max_depth\": trial.suggest_int(\"max_depth\", 2, 6, step=1),\n", + " \"min_child_weight\": trial.suggest_float(\n", + " \"min_child_weight\", 1e-8, 100, log=True\n", + " ),\n", + " \"gamma\": trial.suggest_float(\"gamma\", 1e-8, 1.0, log=True),\n", + " \"tweedie_variance_power\": trial.suggest_float(\"tweedie_variance_power\", 1, 2),\n", + " }\n", + " scores = [[] for store in STORES]\n", + "\n", + " for store_id, store in enumerate(STORES):\n", + " print(f\"Processing store {store}...\")\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " for train_mask, valid_mask in cv_folds:\n", + " df_train = df[\n", + " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", + " ]\n", + " df_valid = df[\n", + " (df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])\n", + " ]\n", + "\n", + " X_train, y_train = (\n", + " df_train.drop(\n", + " columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]\n", + " ),\n", + " df_train[\"sales\"],\n", + " )\n", + " X_valid = df_valid.drop(\n", + " columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]\n", + " )\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " pred_sales = clf.predict(X_valid)\n", + " scores[store_id].append(\n", + " wrmsse(product_weights, df, pred_sales, train_mask, valid_mask)\n", + " )\n", + " del df_train, df_valid, X_train, y_train, clf\n", + " gc.collect()\n", + " del df\n", + " gc.collect()\n", + "\n", + " # We can sum WRMSSE scores over data segments because data segments contain disjoint sets of time series\n", + " return np.array(scores).sum(axis=0).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "f2e4df5d-17a5-4b1c-8454-b5c9c1e631cf", + "metadata": {}, + "source": [ + "Using the Dask cluster client, we execute multiple training jobs in parallel. Optuna keeps track of the progress in the hyperparameter search using in-memory Dask storage." + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "6ba27561-5a31-4568-b341-1befc29516e8", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1321/3389696366.py:7: ExperimentalWarning: DaskStorage is experimental (supported from v3.1.0). The interface can change in the future.\n", + " dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing hyperparameter combinations 0..2\n", + "Best cross-validation metric: 10.027767173304472, Time elapsed = 331.6198390149948\n", + "Testing hyperparameter combinations 2..4\n", + "Best cross-validation metric: 9.426913749927916, Time elapsed = 640.7606940959959\n", + "Testing hyperparameter combinations 4..6\n", + "Best cross-validation metric: 9.426913749927916, Time elapsed = 958.0816706369951\n", + "Testing hyperparameter combinations 6..8\n", + "Best cross-validation metric: 9.426913749927916, Time elapsed = 1295.700604706988\n", + "Testing hyperparameter combinations 8..9\n", + "Best cross-validation metric: 8.915009508695244, Time elapsed = 1476.1182343699911\n", + "Total time elapsed = 1476.1219055669935\n" + ] + } + ], + "source": [ + "##### Number of hyperparameter combinations to try in parallel\n", + "n_trials = 9 # Using a small n_trials so that the demo can finish quickly\n", + "# n_trials = 100\n", + "\n", + "# Optimize in parallel on your Dask cluster\n", + "backend_storage = optuna.storages.InMemoryStorage()\n", + "dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n", + "study = optuna.create_study(\n", + " direction=\"minimize\",\n", + " sampler=optuna.samplers.RandomSampler(seed=0),\n", + " storage=dask_storage,\n", + ")\n", + "futures = []\n", + "for i in range(0, n_trials, n_workers):\n", + " iter_range = (i, min([i + n_workers, n_trials]))\n", + " futures.append(\n", + " {\n", + " \"range\": iter_range,\n", + " \"futures\": [\n", + " client.submit(\n", + " # Work around bug https://github.com/optuna/optuna/issues/4859\n", + " lambda objective, n_trials: (\n", + " study.sampler.reseed_rng(),\n", + " study.optimize(objective, n_trials),\n", + " ),\n", + " objective,\n", + " n_trials=1,\n", + " pure=False,\n", + " )\n", + " for _ in range(*iter_range)\n", + " ],\n", + " }\n", + " )\n", + "\n", + "tstart = time.perf_counter()\n", + "for partition in futures:\n", + " iter_range = partition[\"range\"]\n", + " print(f\"Testing hyperparameter combinations {iter_range[0]}..{iter_range[1]}\")\n", + " _ = wait(partition[\"futures\"])\n", + " for fut in partition[\"futures\"]:\n", + " _ = fut.result() # Ensure that the training job was successful\n", + " tnow = time.perf_counter()\n", + " print(\n", + " f\"Best cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\"\n", + " )\n", + "tend = time.perf_counter()\n", + "print(f\"Total time elapsed = {tend - tstart}\")" + ] + }, + { + "cell_type": "markdown", + "id": "baa92e3e-73e6-4de5-92b0-f89aee77cbc8", + "metadata": {}, + "source": [ + "Once the hyperparameter search is complete, we fetch the optimal hyperparameter combination using the attributes of the `study` object." + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "2a01c20d-9581-4c36-bcb6-c9440f79b652", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'lambda': 2.6232990699579064e-06,\n", + " 'alpha': 0.004085800094564677,\n", + " 'colsample_bytree': 0.4064535567263888,\n", + " 'max_depth': 6,\n", + " 'min_child_weight': 9.652128310148716e-08,\n", + " 'gamma': 3.4446109254037165e-07,\n", + " 'tweedie_variance_power': 1.0914258082324833}" + ] + }, + "execution_count": 77, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "study.best_params" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "6de40f4c-2b29-49da-881d-76c9be3424cf", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "FrozenTrial(number=8, state=TrialState.COMPLETE, values=[8.915009508695244], datetime_start=datetime.datetime(2023, 9, 28, 19, 35, 29, 888497), datetime_complete=datetime.datetime(2023, 9, 28, 19, 38, 30, 299541), params={'lambda': 2.6232990699579064e-06, 'alpha': 0.004085800094564677, 'colsample_bytree': 0.4064535567263888, 'max_depth': 6, 'min_child_weight': 9.652128310148716e-08, 'gamma': 3.4446109254037165e-07, 'tweedie_variance_power': 1.0914258082324833}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'lambda': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'alpha': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'colsample_bytree': FloatDistribution(high=1.0, log=False, low=0.2, step=None), 'max_depth': IntDistribution(high=6, log=False, low=2, step=1), 'min_child_weight': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'gamma': FloatDistribution(high=1.0, log=True, low=1e-08, step=None), 'tweedie_variance_power': FloatDistribution(high=2.0, log=False, low=1.0, step=None)}, trial_id=8, value=None)" + ] + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "study.best_trial" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "e0e1e8a4-846a-4d70-bc84-b9e20bebe067", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'lambda': 2.6232990699579064e-06,\n", + " 'alpha': 0.004085800094564677,\n", + " 'colsample_bytree': 0.4064535567263888,\n", + " 'max_depth': 6,\n", + " 'min_child_weight': 9.652128310148716e-08,\n", + " 'gamma': 3.4446109254037165e-07,\n", + " 'tweedie_variance_power': 1.0914258082324833}" + ] + }, + "execution_count": 79, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Make a deep copy to preserve the dictionary after deleting the Dask cluster\n", + "best_params = copy.deepcopy(study.best_params)\n", + "best_params" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "354ab00c-e340-4de3-a3cf-308b476cf78a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with fs.open(f\"{bucket_name}/params.json\", \"w\") as f:\n", + " json.dump(best_params, f)" + ] + }, + { + "cell_type": "markdown", + "id": "a47ec686-1b1b-49ac-8b87-012e0fcbead7", + "metadata": {}, + "source": [ + "### Train the final XGBoost model and evaluate\n", + "Using the optimal hyperparameters found in the search, fit a new model using the whole training data. As in the previous section, we fit a separate XGBoost model per data segment." + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "fafe5c72-576e-45f2-9f9d-3e5cf72874fc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with fs.open(f\"{bucket_name}/params.json\", \"r\") as f:\n", + " best_params = json.load(f)\n", + "with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "f1783769-b19e-4515-8ff0-5e1b662bfcee", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def final_train(best_params):\n", + " fs = gcsfs.GCSFileSystem()\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " }\n", + " params.update(best_params)\n", + " model = {}\n", + " train_mask = [0, 1914]\n", + "\n", + " for store in STORES:\n", + " print(f\"Processing store {store}...\")\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + "\n", + " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", + " X_train, y_train = (\n", + " df_train.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", + " df_train[\"sales\"],\n", + " )\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " model[store] = clf\n", + " del df\n", + " gc.collect()\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "3fcaadeb-fa8f-4266-9280-e0f25cf1afca", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1...\n", + "Processing store CA_2...\n", + "Processing store CA_3...\n", + "Processing store CA_4...\n", + "Processing store TX_1...\n", + "Processing store TX_2...\n", + "Processing store TX_3...\n", + "Processing store WI_1...\n", + "Processing store WI_2...\n", + "Processing store WI_3...\n" + ] + } + ], + "source": [ + "model = final_train(best_params)" + ] + }, + { + "cell_type": "markdown", + "id": "57da12bb-5801-46b5-9eef-16f62ed79737", + "metadata": {}, + "source": [ + "Let's now evaluate the final model using the held-out test set:" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "ab4c5ede-cf84-4a62-a1d9-ce18d9f1d4a3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WRMSSE metric on the held-out test set: 9.478942050051291\n" + ] + } + ], + "source": [ + "test_wrmsse = 0\n", + "for store in STORES:\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " df_test = df[(df[\"day_id\"] >= holdout[0]) & (df[\"day_id\"] < holdout[1])]\n", + " X_test = df_test.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " pred_sales = model[store].predict(X_test)\n", + " test_wrmsse += wrmsse(\n", + " product_weights, df, pred_sales, train_mask=[0, 1914], valid_mask=holdout\n", + " )\n", + "print(f\"WRMSSE metric on the held-out test set: {test_wrmsse}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": "987923e1-d26f-44f7-8d5a-559f28f38bb3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Save the model to the Cloud Storage\n", + "with fs.open(f\"{bucket_name}/final_model.pkl\", \"wb\") as f:\n", + " pickle.dump(model, f)" + ] + }, + { + "cell_type": "markdown", + "id": "e8e55536-4540-495c-a930-d54afb7eecaa", + "metadata": {}, + "source": [ + "## Create an ensemble model using a different strategy for segmenting sales data\n", + "It is common to create an ensemble model where multiple machine learning methods are used to obtain better predictive performance. Prediction is made from an ensemble model by averaging the prediction output of the constituent models.\n", + "\n", + "In this example, we will create a second model by segmenting the sales data in a different way. Instead of splitting by stores, we will split the data by both stores and product categories." + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "f8180260-cae0-4af3-8fdc-ebd654e7683d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def objective_alt(trial):\n", + " fs = gcsfs.GCSFileSystem()\n", + " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " \"lambda\": trial.suggest_float(\"lambda\", 1e-8, 100.0, log=True),\n", + " \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 100.0, log=True),\n", + " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n", + " \"max_depth\": trial.suggest_int(\"max_depth\", 2, 6, step=1),\n", + " \"min_child_weight\": trial.suggest_float(\n", + " \"min_child_weight\", 1e-8, 100, log=True\n", + " ),\n", + " \"gamma\": trial.suggest_float(\"gamma\", 1e-8, 1.0, log=True),\n", + " \"tweedie_variance_power\": trial.suggest_float(\"tweedie_variance_power\", 1, 2),\n", + " }\n", + " scores = [[] for i in range(len(STORES) * len(DEPTS))]\n", + "\n", + " for store_id, store in enumerate(STORES):\n", + " for dept_id, dept in enumerate(DEPTS):\n", + " print(f\"Processing store {store}, department {dept}...\")\n", + " with fs.open(\n", + " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", + " ) as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " for train_mask, valid_mask in cv_folds:\n", + " df_train = df[\n", + " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", + " ]\n", + " df_valid = df[\n", + " (df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])\n", + " ]\n", + "\n", + " X_train, y_train = (\n", + " df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", + " df_train[\"sales\"],\n", + " )\n", + " X_valid = df_valid.drop(\n", + " columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]\n", + " )\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " sales_pred = clf.predict(X_valid)\n", + " scores[store_id * len(DEPTS) + dept_id].append(\n", + " wrmsse(product_weights, df, sales_pred, train_mask, valid_mask)\n", + " )\n", + " del df_train, df_valid, X_train, y_train, clf\n", + " gc.collect()\n", + " del df\n", + " gc.collect()\n", + "\n", + " # We can sum WRMSSE scores over data segments because data segments contain disjoint sets of time series\n", + " return np.array(scores).sum(axis=0).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "id": "2c65c62f-97ee-4eb7-8334-6fa82eea83fd", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1321/491731696.py:7: ExperimentalWarning: DaskStorage is experimental (supported from v3.1.0). The interface can change in the future.\n", + " dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing hyperparameter combinations 0..2\n", + "Best cross-validation metric: 9.896445497438858, Time elapsed = 802.2191872399999\n", + "Testing hyperparameter combinations 2..4\n", + "Best cross-validation metric: 9.896445497438858, Time elapsed = 1494.0718872279976\n", + "Testing hyperparameter combinations 4..6\n", + "Best cross-validation metric: 9.835407407395302, Time elapsed = 2393.3159628150024\n", + "Testing hyperparameter combinations 6..8\n", + "Best cross-validation metric: 9.330048901795887, Time elapsed = 3092.471466117\n", + "Testing hyperparameter combinations 8..9\n", + "Best cross-validation metric: 9.330048901795887, Time elapsed = 3459.9082761530008\n", + "Total time elapsed = 3459.911843854992\n" + ] + } + ], + "source": [ + "##### Number of hyperparameter combinations to try in parallel\n", + "n_trials = 9 # Using a small n_trials so that the demo can finish quickly\n", + "# n_trials = 100\n", + "\n", + "# Optimize in parallel on your Dask cluster\n", + "backend_storage = optuna.storages.InMemoryStorage()\n", + "dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n", + "study = optuna.create_study(\n", + " direction=\"minimize\",\n", + " sampler=optuna.samplers.RandomSampler(seed=0),\n", + " storage=dask_storage,\n", + ")\n", + "futures = []\n", + "for i in range(0, n_trials, n_workers):\n", + " iter_range = (i, min([i + n_workers, n_trials]))\n", + " futures.append(\n", + " {\n", + " \"range\": iter_range,\n", + " \"futures\": [\n", + " client.submit(\n", + " # Work around bug https://github.com/optuna/optuna/issues/4859\n", + " lambda objective, n_trials: (\n", + " study.sampler.reseed_rng(),\n", + " study.optimize(objective, n_trials),\n", + " ),\n", + " objective_alt,\n", + " n_trials=1,\n", + " pure=False,\n", + " )\n", + " for _ in range(*iter_range)\n", + " ],\n", + " }\n", + " )\n", + "\n", + "tstart = time.perf_counter()\n", + "for partition in futures:\n", + " iter_range = partition[\"range\"]\n", + " print(f\"Testing hyperparameter combinations {iter_range[0]}..{iter_range[1]}\")\n", + " _ = wait(partition[\"futures\"])\n", + " for fut in partition[\"futures\"]:\n", + " _ = fut.result() # Ensure that the training job was successful\n", + " tnow = time.perf_counter()\n", + " print(\n", + " f\"Best cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\"\n", + " )\n", + "tend = time.perf_counter()\n", + "print(f\"Total time elapsed = {tend - tstart}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "b728d1d2-7180-428a-bbdd-249e6471c070", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'lambda': 0.028794929327421122,\n", + " 'alpha': 3.3150619134761685e-07,\n", + " 'colsample_bytree': 0.42330433646728755,\n", + " 'max_depth': 2,\n", + " 'min_child_weight': 0.09713314395591004,\n", + " 'gamma': 0.0016337599227941016,\n", + " 'tweedie_variance_power': 1.1915217521234043}" + ] + }, + "execution_count": 88, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Make a deep copy to preserve the dictionary after deleting the Dask cluster\n", + "best_params_alt = copy.deepcopy(study.best_params)\n", + "best_params_alt" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "id": "8c65edfa-ed67-4c49-b823-80cb72a7f6e5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with fs.open(f\"{bucket_name}/params_alt.json\", \"w\") as f:\n", + " json.dump(best_params_alt, f)" + ] + }, + { + "cell_type": "markdown", + "id": "b2d7ba79-72fc-4298-b92e-6eb07b266bfd", + "metadata": {}, + "source": [ + "Using the optimal hyperparameters found in the search, fit a new model using the whole training data." + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "id": "896d9342-794d-4b55-9b64-f644e528a9c0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def final_train_alt(best_params):\n", + " fs = gcsfs.GCSFileSystem()\n", + " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " }\n", + " params.update(best_params)\n", + " model = {}\n", + " train_mask = [0, 1914]\n", + "\n", + " for store_id, store in enumerate(STORES):\n", + " for dept_id, dept in enumerate(DEPTS):\n", + " print(f\"Processing store {store}, department {dept}...\")\n", + " with fs.open(\n", + " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", + " ) as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " for train_mask, valid_mask in cv_folds:\n", + " df_train = df[\n", + " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", + " ]\n", + " X_train, y_train = (\n", + " df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", + " df_train[\"sales\"],\n", + " )\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " model[(store, dept)] = clf\n", + " del df\n", + " gc.collect()\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "id": "0abb7b4e-9865-42ad-824a-cf8b64c7c4f3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with fs.open(f\"{bucket_name}/params_alt.json\", \"r\") as f:\n", + " best_params_alt = json.load(f)\n", + "with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "id": "c7b5a383-1358-4201-9588-d799cf21822a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1, department HOBBIES_1...\n", + "Processing store CA_1, department HOBBIES_2...\n", + "Processing store CA_1, department HOUSEHOLD_1...\n", + "Processing store CA_1, department HOUSEHOLD_2...\n", + "Processing store CA_1, department FOODS_1...\n", + "Processing store CA_1, department FOODS_2...\n", + "Processing store CA_1, department FOODS_3...\n", + "Processing store CA_2, department HOBBIES_1...\n", + "Processing store CA_2, department HOBBIES_2...\n", + "Processing store CA_2, department HOUSEHOLD_1...\n", + "Processing store CA_2, department HOUSEHOLD_2...\n", + "Processing store CA_2, department FOODS_1...\n", + "Processing store CA_2, department FOODS_2...\n", + "Processing store CA_2, department FOODS_3...\n", + "Processing store CA_3, department HOBBIES_1...\n", + "Processing store CA_3, department HOBBIES_2...\n", + "Processing store CA_3, department HOUSEHOLD_1...\n", + "Processing store CA_3, department HOUSEHOLD_2...\n", + "Processing store CA_3, department FOODS_1...\n", + "Processing store CA_3, department FOODS_2...\n", + "Processing store CA_3, department FOODS_3...\n", + "Processing store CA_4, department HOBBIES_1...\n", + "Processing store CA_4, department HOBBIES_2...\n", + "Processing store CA_4, department HOUSEHOLD_1...\n", + "Processing store CA_4, department HOUSEHOLD_2...\n", + "Processing store CA_4, department FOODS_1...\n", + "Processing store CA_4, department FOODS_2...\n", + "Processing store CA_4, department FOODS_3...\n", + "Processing store TX_1, department HOBBIES_1...\n", + "Processing store TX_1, department HOBBIES_2...\n", + "Processing store TX_1, department HOUSEHOLD_1...\n", + "Processing store TX_1, department HOUSEHOLD_2...\n", + "Processing store TX_1, department FOODS_1...\n", + "Processing store TX_1, department FOODS_2...\n", + "Processing store TX_1, department FOODS_3...\n", + "Processing store TX_2, department HOBBIES_1...\n", + "Processing store TX_2, department HOBBIES_2...\n", + "Processing store TX_2, department HOUSEHOLD_1...\n", + "Processing store TX_2, department HOUSEHOLD_2...\n", + "Processing store TX_2, department FOODS_1...\n", + "Processing store TX_2, department FOODS_2...\n", + "Processing store TX_2, department FOODS_3...\n", + "Processing store TX_3, department HOBBIES_1...\n", + "Processing store TX_3, department HOBBIES_2...\n", + "Processing store TX_3, department HOUSEHOLD_1...\n", + "Processing store TX_3, department HOUSEHOLD_2...\n", + "Processing store TX_3, department FOODS_1...\n", + "Processing store TX_3, department FOODS_2...\n", + "Processing store TX_3, department FOODS_3...\n", + "Processing store WI_1, department HOBBIES_1...\n", + "Processing store WI_1, department HOBBIES_2...\n", + "Processing store WI_1, department HOUSEHOLD_1...\n", + "Processing store WI_1, department HOUSEHOLD_2...\n", + "Processing store WI_1, department FOODS_1...\n", + "Processing store WI_1, department FOODS_2...\n", + "Processing store WI_1, department FOODS_3...\n", + "Processing store WI_2, department HOBBIES_1...\n", + "Processing store WI_2, department HOBBIES_2...\n", + "Processing store WI_2, department HOUSEHOLD_1...\n", + "Processing store WI_2, department HOUSEHOLD_2...\n", + "Processing store WI_2, department FOODS_1...\n", + "Processing store WI_2, department FOODS_2...\n", + "Processing store WI_2, department FOODS_3...\n", + "Processing store WI_3, department HOBBIES_1...\n", + "Processing store WI_3, department HOBBIES_2...\n", + "Processing store WI_3, department HOUSEHOLD_1...\n", + "Processing store WI_3, department HOUSEHOLD_2...\n", + "Processing store WI_3, department FOODS_1...\n", + "Processing store WI_3, department FOODS_2...\n", + "Processing store WI_3, department FOODS_3...\n" + ] + } + ], + "source": [ + "model_alt = final_train_alt(best_params_alt)" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "9b72487c-e311-40e8-ae52-b6b23a9d391e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Save the model to the Cloud Storage\n", + "with fs.open(f\"{bucket_name}/final_model_alt.pkl\", \"wb\") as f:\n", + " pickle.dump(model_alt, f)" + ] + }, + { + "cell_type": "markdown", + "id": "b4eb7307-96a3-4f1b-9356-13027b2779a3", + "metadata": {}, + "source": [ + "Now consider an ensemble consisting of the two models `model` and `model_alt`. We evaluate the ensemble by computing the WRMSSE metric for the average of the predictions of the two models." + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "9c746944-5e79-492e-b81f-8f9a556d0df9", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1...\n", + "Processing store CA_2...\n", + "Processing store CA_3...\n", + "Processing store CA_4...\n", + "Processing store TX_1...\n", + "Processing store TX_2...\n", + "Processing store TX_3...\n", + "Processing store WI_1...\n", + "Processing store WI_2...\n", + "Processing store WI_3...\n", + "WRMSSE metric on the held-out test set: 10.69187847848366\n" + ] + } + ], + "source": [ + "test_wrmsse = 0\n", + "for store in STORES:\n", + " print(f\"Processing store {store}...\")\n", + " # Prediction from Model 1\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " df_test = df[(df[\"day_id\"] >= holdout[0]) & (df[\"day_id\"] < holdout[1])]\n", + " X_test = df_test.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " df_test[\"pred1\"] = model[store].predict(X_test)\n", + "\n", + " # Prediction from Model 2\n", + " df_test[\"pred2\"] = [np.nan] * len(df_test)\n", + " df_test[\"pred2\"] = df_test[\"pred2\"].astype(\"float32\")\n", + " for dept in DEPTS:\n", + " with fs.open(\n", + " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", + " ) as f:\n", + " df2 = cudf.DataFrame(pd.read_pickle(f))\n", + " df2_test = df2[(df2[\"day_id\"] >= holdout[0]) & (df2[\"day_id\"] < holdout[1])]\n", + " X_test = df2_test.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " assert np.sum(df_test[\"dept_id\"] == dept) == len(X_test)\n", + " df_test[\"pred2\"][df_test[\"dept_id\"] == dept] = model_alt[(store, dept)].predict(\n", + " X_test\n", + " )\n", + "\n", + " # Average prediction\n", + " df_test[\"avg_pred\"] = (df_test[\"pred1\"] + df_test[\"pred2\"]) / 2.0\n", + "\n", + " test_wrmsse += wrmsse(\n", + " product_weights,\n", + " df,\n", + " df_test[\"avg_pred\"],\n", + " train_mask=[0, 1914],\n", + " valid_mask=holdout,\n", + " )\n", + "print(f\"WRMSSE metric on the held-out test set: {test_wrmsse}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "af6cc91f-036b-4503-a784-0e413c681727", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Close the Dask cluster to clean up\n", + "cluster.close()" + ] + }, + { + "cell_type": "markdown", + "id": "727454aa-06fd-4724-9ee0-7874f44bb88b", + "metadata": {}, + "source": [ + "## Conclusion" + ] + }, + { + "cell_type": "markdown", + "id": "19524812-c005-4d39-bc62-64a681fb3164", + "metadata": {}, + "source": [ + "We demonstrated an end-to-end workflow where we take a real-world time-series data and train a forecasting model using Google Kubernetes Engine (GKE). We were able to speed up the hyperparameter optimization (HPO) process by dispatching parallel training jobs to NVIDIA GPUs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b83283f-0e56-4b97-b7aa-024a386fc9b2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}