From bcc97768a30ff8ad04b4185c358acb5e97f14d8f Mon Sep 17 00:00:00 2001 From: Hyunsu Philip Cho Date: Thu, 27 Jul 2023 15:53:22 -0700 Subject: [PATCH 01/11] Add workflow example for time series forecasting --- extensions/rapids_notebook_files.py | 2 +- source/examples/index.md | 1 + .../preprocessing_part1.ipynb | 3201 +++++++++++++++++ .../preprocessing_part2.ipynb | 1259 +++++++ .../preprocessing_part3.ipynb | 839 +++++ .../preprocessing_part4.ipynb | 623 ++++ .../preprocessing_part5.ipynb | 595 +++ .../preprocessing_part6.ipynb | 427 +++ .../start_here.ipynb | 305 ++ .../training_and_evaluation.ipynb | 1682 +++++++++ 10 files changed, 8933 insertions(+), 1 deletion(-) create mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb create mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb create mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb create mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb create mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb create mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb create mode 100644 source/examples/time-series-forecasting-with-hpo/start_here.ipynb create mode 100644 source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb diff --git a/extensions/rapids_notebook_files.py b/extensions/rapids_notebook_files.py index eef81d94..66d68ef8 100644 --- a/extensions/rapids_notebook_files.py +++ b/extensions/rapids_notebook_files.py @@ -12,7 +12,7 @@ def template_func(app, match): def walk_files(app, dir, outdir): - outdir.mkdir(parents=True, exist_ok=False) + outdir.mkdir(parents=True, exist_ok=True) related_notebook_files = {} for page in dir.glob("*"): if page.is_dir(): diff --git a/source/examples/index.md b/source/examples/index.md index 852bdc2d..263c86b8 100644 --- 
a/source/examples/index.md +++ b/source/examples/index.md @@ -14,4 +14,5 @@ rapids-ec2-mnmg/notebook rapids-autoscaling-multi-tenant-kubernetes/notebook xgboost-randomforest-gpu-hpo-dask/notebook rapids-azureml-hpo/notebook +time-series-forecasting-with-hpo/start_here ``` diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb new file mode 100644 index 00000000..985f1c98 --- /dev/null +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb @@ -0,0 +1,3201 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "70b2ef46-9ac6-47db-a66f-353bb4a27722", + "metadata": {}, + "source": [ + "# Data preprocesing, Part 1" + ] + }, + { + "cell_type": "markdown", + "id": "62f9ff3a-b875-4798-8988-7238c1c651a6", + "metadata": {}, + "source": [ + "## Import modules and define utility functions" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "fce9c954-bb2c-4c2d-bb6e-11b100cac88f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import cudf\n", + "import numpy as np\n", + "import pandas as pd\n", + "import gc\n", + "import pathlib\n", + "import gcsfs" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "45c07679-7305-4a8a-a826-efa42ec65ba2", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def sizeof_fmt(num, suffix=\"B\"):\n", + " for unit in [\"\", \"Ki\", \"Mi\", \"Gi\", \"Ti\", \"Pi\", \"Ei\", \"Zi\"]:\n", + " if abs(num) < 1024.0:\n", + " return f\"{num:3.1f}{unit}{suffix}\"\n", + " num /= 1024.0\n", + " return \"%.1f%s%s\" % (num, \"Yi\", suffix)\n", + "\n", + "def report_dataframe_size(df, name):\n", + " print(\"{} takes up {} memory on GPU\".format(name, sizeof_fmt(grid_df.memory_usage(index=True).sum())))" + ] + }, + { + "cell_type": "markdown", + "id": "b981cb8f-f6f8-4a32-af08-d883ade14c0f", + "metadata": {}, + "source": [ + "## Load Data" + ] + }, + { + 
"cell_type": "code", + "execution_count": 3, + "id": "1fda1c8f-2697-4e04-9484-2d7072c7e904", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "TARGET = \"sales\" # Our main target\n", + "END_TRAIN = 1941 # Last day in train set" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d3ea8261-4b3d-470d-be49-b161c8d72b04", + "metadata": {}, + "outputs": [], + "source": [ + "raw_data_dir = \"./data/\"\n", + "processed_data_dir = \"./processed_data/\"\n", + "\n", + "pathlib.Path(processed_data_dir).mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1a852fe7-e6c6-45f4-9767-af7b78bc510b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "train_df = cudf.read_csv(raw_data_dir + \"sales_train_evaluation.csv\")\n", + "prices_df = cudf.read_csv(raw_data_dir + \"sell_prices.csv\")\n", + "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(columns={\"d\": \"day_id\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e5d049a2-2797-45b3-9b32-367b06505477", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idd_1d_2d_3d_4...d_1932d_1933d_1934d_1935d_1936d_1937d_1938d_1939d_1940d_1941
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CA0000...2400003301
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CA0000...0121100000
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CA0000...1020002301
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CA0000...1104013026
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CA0000...0002100210
..................................................................
30485FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WI0022...1030110011
30486FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WI0000...0000001010
30487FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WI0602...0012010102
30488FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WI0000...1114601110
30489FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WI0000...1205402251
\n", + "

30490 rows × 1947 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id cat_id \\\n", + "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", + "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", + "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", + "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", + "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", + "... ... ... ... ... \n", + "30485 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", + "30486 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", + "30487 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", + "30488 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", + "30489 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", + "\n", + " store_id state_id d_1 d_2 d_3 d_4 ... d_1932 d_1933 d_1934 \\\n", + "0 CA_1 CA 0 0 0 0 ... 2 4 0 \n", + "1 CA_1 CA 0 0 0 0 ... 0 1 2 \n", + "2 CA_1 CA 0 0 0 0 ... 1 0 2 \n", + "3 CA_1 CA 0 0 0 0 ... 1 1 0 \n", + "4 CA_1 CA 0 0 0 0 ... 0 0 0 \n", + "... ... ... ... ... ... ... ... ... ... ... \n", + "30485 WI_3 WI 0 0 2 2 ... 1 0 3 \n", + "30486 WI_3 WI 0 0 0 0 ... 0 0 0 \n", + "30487 WI_3 WI 0 6 0 2 ... 0 0 1 \n", + "30488 WI_3 WI 0 0 0 0 ... 1 1 1 \n", + "30489 WI_3 WI 0 0 0 0 ... 1 2 0 \n", + "\n", + " d_1935 d_1936 d_1937 d_1938 d_1939 d_1940 d_1941 \n", + "0 0 0 0 3 3 0 1 \n", + "1 1 1 0 0 0 0 0 \n", + "2 0 0 0 2 3 0 1 \n", + "3 4 0 1 3 0 2 6 \n", + "4 2 1 0 0 2 1 0 \n", + "... ... ... ... ... ... ... ... 
\n", + "30485 0 1 1 0 0 1 1 \n", + "30486 0 0 0 1 0 1 0 \n", + "30487 2 0 1 0 1 0 2 \n", + "30488 4 6 0 1 1 1 0 \n", + "30489 5 4 0 2 2 5 1 \n", + "\n", + "[30490 rows x 1947 columns]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_df" + ] + }, + { + "cell_type": "markdown", + "id": "d9053e74-cddf-4ca2-b8cb-f6da58266e2a", + "metadata": {}, + "source": [ + "The columns `d_1`, `d_2`, ..., `d_1941` indicate the sales data at days 1, 2, ..., 1941 from 2011-01-29." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "efeb85c5-c72e-4cd2-be2d-17d4a240aef4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idwm_yr_wksell_price
0CA_1HOBBIES_1_001113259.58
1CA_1HOBBIES_1_001113269.58
2CA_1HOBBIES_1_001113278.26
3CA_1HOBBIES_1_001113288.26
4CA_1HOBBIES_1_001113298.26
...............
6841116WI_3FOODS_3_827116171.00
6841117WI_3FOODS_3_827116181.00
6841118WI_3FOODS_3_827116191.00
6841119WI_3FOODS_3_827116201.00
6841120WI_3FOODS_3_827116211.00
\n", + "

6841121 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id wm_yr_wk sell_price\n", + "0 CA_1 HOBBIES_1_001 11325 9.58\n", + "1 CA_1 HOBBIES_1_001 11326 9.58\n", + "2 CA_1 HOBBIES_1_001 11327 8.26\n", + "3 CA_1 HOBBIES_1_001 11328 8.26\n", + "4 CA_1 HOBBIES_1_001 11329 8.26\n", + "... ... ... ... ...\n", + "6841116 WI_3 FOODS_3_827 11617 1.00\n", + "6841117 WI_3 FOODS_3_827 11618 1.00\n", + "6841118 WI_3 FOODS_3_827 11619 1.00\n", + "6841119 WI_3 FOODS_3_827 11620 1.00\n", + "6841120 WI_3 FOODS_3_827 11621 1.00\n", + "\n", + "[6841121 rows x 4 columns]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "77187dd8-92a7-400b-b5d8-4db31817e4df", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
datewm_yr_wkweekdaywdaymonthyearday_idevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WI
02011-01-2911101Saturday112011d_1<NA><NA><NA><NA>000
12011-01-3011101Sunday212011d_2<NA><NA><NA><NA>000
22011-01-3111101Monday312011d_3<NA><NA><NA><NA>000
32011-02-0111101Tuesday422011d_4<NA><NA><NA><NA>110
42011-02-0211101Wednesday522011d_5<NA><NA><NA><NA>101
.............................................
19642016-06-1511620Wednesday562016d_1965<NA><NA><NA><NA>011
19652016-06-1611620Thursday662016d_1966<NA><NA><NA><NA>000
19662016-06-1711620Friday762016d_1967<NA><NA><NA><NA>000
19672016-06-1811621Saturday162016d_1968<NA><NA><NA><NA>000
19682016-06-1911621Sunday262016d_1969NBAFinalsEndSportingFather's dayCultural000
\n", + "

1969 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " date wm_yr_wk weekday wday month year day_id \\\n", + "0 2011-01-29 11101 Saturday 1 1 2011 d_1 \n", + "1 2011-01-30 11101 Sunday 2 1 2011 d_2 \n", + "2 2011-01-31 11101 Monday 3 1 2011 d_3 \n", + "3 2011-02-01 11101 Tuesday 4 2 2011 d_4 \n", + "4 2011-02-02 11101 Wednesday 5 2 2011 d_5 \n", + "... ... ... ... ... ... ... ... \n", + "1964 2016-06-15 11620 Wednesday 5 6 2016 d_1965 \n", + "1965 2016-06-16 11620 Thursday 6 6 2016 d_1966 \n", + "1966 2016-06-17 11620 Friday 7 6 2016 d_1967 \n", + "1967 2016-06-18 11621 Saturday 1 6 2016 d_1968 \n", + "1968 2016-06-19 11621 Sunday 2 6 2016 d_1969 \n", + "\n", + " event_name_1 event_type_1 event_name_2 event_type_2 snap_CA snap_TX \\\n", + "0 0 0 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 1 1 \n", + "4 1 0 \n", + "... ... ... ... ... ... ... \n", + "1964 0 1 \n", + "1965 0 0 \n", + "1966 0 0 \n", + "1967 0 0 \n", + "1968 NBAFinalsEnd Sporting Father's day Cultural 0 0 \n", + "\n", + " snap_WI \n", + "0 0 \n", + "1 0 \n", + "2 0 \n", + "3 0 \n", + "4 1 \n", + "... ... \n", + "1964 1 \n", + "1965 0 \n", + "1966 0 \n", + "1967 0 \n", + "1968 0 \n", + "\n", + "[1969 rows x 14 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "calendar_df" + ] + }, + { + "cell_type": "markdown", + "id": "dd65fd8d-ba0f-4120-8ebd-4bc9d1513b74", + "metadata": {}, + "source": [ + "## Reformat sales times series data" + ] + }, + { + "cell_type": "markdown", + "id": "666b07b0-bc62-4c9b-86bc-5287b1d40de2", + "metadata": {}, + "source": [ + "Pivot the columns `d_1`, `d_2`, ..., `d_1941` into separate rows using `cudf.melt`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4908f0c5-b83b-4c8d-81e1-7853b5bbc7af", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsales
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CAd_10
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CAd_10
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CAd_10
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CAd_10
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CAd_10
...........................
59181085FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WId_19411
59181086FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WId_19410
59181087FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WId_19412
59181088FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WId_19410
59181089FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WId_19411
\n", + "

59181090 rows × 8 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id cat_id \\\n", + "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", + "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", + "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", + "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", + "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", + "... ... ... ... ... \n", + "59181085 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", + "59181086 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", + "59181087 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", + "59181088 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", + "59181089 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", + "\n", + " store_id state_id day_id sales \n", + "0 CA_1 CA d_1 0 \n", + "1 CA_1 CA d_1 0 \n", + "2 CA_1 CA d_1 0 \n", + "3 CA_1 CA d_1 0 \n", + "4 CA_1 CA d_1 0 \n", + "... ... ... ... ... \n", + "59181085 WI_3 WI d_1941 1 \n", + "59181086 WI_3 WI d_1941 0 \n", + "59181087 WI_3 WI d_1941 2 \n", + "59181088 WI_3 WI d_1941 0 \n", + "59181089 WI_3 WI d_1941 1 \n", + "\n", + "[59181090 rows x 8 columns]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index_columns = [\"id\", \"item_id\", \"dept_id\", \"cat_id\", \"store_id\", \"state_id\"]\n", + "grid_df = cudf.melt(train_df, id_vars=index_columns, var_name=\"day_id\", value_name=TARGET)\n", + "grid_df" + ] + }, + { + "cell_type": "markdown", + "id": "fe51bf83-6831-4d4a-99bc-f6d5f6ec9e15", + "metadata": {}, + "source": [ + "For each time series, add 28 rows that corresponds to the future forecast horizon:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "5700f3c3-f294-4963-8cea-9235ab55b0f4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsales
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CAd_10.0
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CAd_10.0
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CAd_10.0
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CAd_10.0
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CAd_10.0
...........................
60034805FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WId_1969NaN
60034806FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WId_1969NaN
60034807FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WId_1969NaN
60034808FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WId_1969NaN
60034809FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WId_1969NaN
\n", + "

60034810 rows × 8 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id cat_id \\\n", + "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", + "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", + "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", + "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", + "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", + "... ... ... ... ... \n", + "60034805 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", + "60034806 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", + "60034807 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", + "60034808 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", + "60034809 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", + "\n", + " store_id state_id day_id sales \n", + "0 CA_1 CA d_1 0.0 \n", + "1 CA_1 CA d_1 0.0 \n", + "2 CA_1 CA d_1 0.0 \n", + "3 CA_1 CA d_1 0.0 \n", + "4 CA_1 CA d_1 0.0 \n", + "... ... ... ... ... 
\n", + "60034805 WI_3 WI d_1969 NaN \n", + "60034806 WI_3 WI d_1969 NaN \n", + "60034807 WI_3 WI d_1969 NaN \n", + "60034808 WI_3 WI d_1969 NaN \n", + "60034809 WI_3 WI d_1969 NaN \n", + "\n", + "[60034810 rows x 8 columns]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "add_grid = cudf.DataFrame()\n", + "for i in range(1, 29):\n", + " temp_df = train_df[index_columns]\n", + " temp_df = temp_df.drop_duplicates()\n", + " temp_df[\"day_id\"] = \"d_\" + str(END_TRAIN + i)\n", + " temp_df[TARGET] = np.nan # Sales amount at time (n + i) is unknown\n", + " add_grid = cudf.concat([add_grid, temp_df])\n", + "add_grid[\"day_id\"] = add_grid[\"day_id\"].astype(\"category\") # The day_id column is categorical, after cudf.melt\n", + "\n", + "grid_df = cudf.concat([grid_df, add_grid])\n", + "grid_df = grid_df.reset_index(drop=True)\n", + "grid_df[\"sales\"] = grid_df[\"sales\"].astype(np.float32) # Use float32 type for sales column, to conserve memory\n", + "grid_df" + ] + }, + { + "cell_type": "markdown", + "id": "4af8ec5b-5ad3-43c0-9e49-136f2e061019", + "metadata": {}, + "source": [ + "### Free up GPU memory" + ] + }, + { + "cell_type": "markdown", + "id": "8cac1630-1062-4d3e-be9e-77f5315d903b", + "metadata": {}, + "source": [ + "GPU memory is a precious resource, so let's try to free up some memory. 
First, delete temporary variables we no longer need:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "65e025ab-7443-4ba2-b47d-1b6494a57afe", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "8184" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Use xdel magic to scrub extra references from Jupyter notebook\n", + "%xdel temp_df\n", + "%xdel add_grid\n", + "%xdel train_df\n", + "\n", + "# Invoke the garbage collector explicitly to free up memory\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "15bc2d78-417f-4920-83ab-ea3d32ba6744", + "metadata": {}, + "source": [ + "Second, let's reduce the footprint of `grid_df` by converting strings into categoricals:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "7d2ef775-f68d-47d2-8f6f-be1efffc831d", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 5.2GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "7c3fa45a-4f89-48c5-9fd7-7ac31e065ba7", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "id object\n", + "item_id object\n", + "dept_id object\n", + "cat_id object\n", + "store_id object\n", + "state_id object\n", + "day_id category\n", + "sales float32\n", + "dtype: object" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df.dtypes" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "dd0acbcd-c192-4f01-800f-ea6ac441b1ab", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 802.6MiB memory on GPU\n" + ] + } + ], + "source": [ + "for col in index_columns:\n", + " grid_df[col] 
= grid_df[col].astype(\"category\")\n", + " gc.collect()\n", + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "96945261-2069-4923-a443-07766774465a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "id category\n", + "item_id category\n", + "dept_id category\n", + "cat_id category\n", + "store_id category\n", + "state_id category\n", + "day_id category\n", + "sales float32\n", + "dtype: object" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df.dtypes" + ] + }, + { + "cell_type": "markdown", + "id": "48321684-4e32-4d6b-a0b1-6495b773b4ff", + "metadata": {}, + "source": [ + "## Identify the release week of each product" + ] + }, + { + "cell_type": "markdown", + "id": "8c3b2631-02b1-4942-b881-16c43321aad6", + "metadata": {}, + "source": [ + "Each row in the `prices_df` table contains the price of a product sold at a store for a given week." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "dd624c0c-0460-4e58-9bd1-0a4148e21a46", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idwm_yr_wksell_price
0CA_1HOBBIES_1_001113259.58
1CA_1HOBBIES_1_001113269.58
2CA_1HOBBIES_1_001113278.26
3CA_1HOBBIES_1_001113288.26
4CA_1HOBBIES_1_001113298.26
...............
6841116WI_3FOODS_3_827116171.00
6841117WI_3FOODS_3_827116181.00
6841118WI_3FOODS_3_827116191.00
6841119WI_3FOODS_3_827116201.00
6841120WI_3FOODS_3_827116211.00
\n", + "

6841121 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id wm_yr_wk sell_price\n", + "0 CA_1 HOBBIES_1_001 11325 9.58\n", + "1 CA_1 HOBBIES_1_001 11326 9.58\n", + "2 CA_1 HOBBIES_1_001 11327 8.26\n", + "3 CA_1 HOBBIES_1_001 11328 8.26\n", + "4 CA_1 HOBBIES_1_001 11329 8.26\n", + "... ... ... ... ...\n", + "6841116 WI_3 FOODS_3_827 11617 1.00\n", + "6841117 WI_3 FOODS_3_827 11618 1.00\n", + "6841118 WI_3 FOODS_3_827 11619 1.00\n", + "6841119 WI_3 FOODS_3_827 11620 1.00\n", + "6841120 WI_3 FOODS_3_827 11621 1.00\n", + "\n", + "[6841121 rows x 4 columns]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df" + ] + }, + { + "cell_type": "markdown", + "id": "87cf77f2-3eb5-440d-850d-a80eee2fbce9", + "metadata": {}, + "source": [ + "Notice that not all products were sold over every week. Some products were sold only during some weeks. Let's use the groupby operation to identify the first week in which each product went on the shelf." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "b4bd0430-bffa-4d62-b15a-bf71498c19b5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idrelease_week
0CA_4FOODS_3_52911421
1TX_1HOUSEHOLD_1_40911230
2WI_2FOODS_3_14511214
3CA_4HOUSEHOLD_1_49411106
4WI_3HOBBIES_1_09311223
............
30485CA_3HOUSEHOLD_1_36911205
30486CA_2FOODS_3_10911101
30487CA_4FOODS_2_11911101
30488CA_4HOUSEHOLD_2_38411110
30489WI_3HOBBIES_1_13511328
\n", + "

30490 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id release_week\n", + "0 CA_4 FOODS_3_529 11421\n", + "1 TX_1 HOUSEHOLD_1_409 11230\n", + "2 WI_2 FOODS_3_145 11214\n", + "3 CA_4 HOUSEHOLD_1_494 11106\n", + "4 WI_3 HOBBIES_1_093 11223\n", + "... ... ... ...\n", + "30485 CA_3 HOUSEHOLD_1_369 11205\n", + "30486 CA_2 FOODS_3_109 11101\n", + "30487 CA_4 FOODS_2_119 11101\n", + "30488 CA_4 HOUSEHOLD_2_384 11110\n", + "30489 WI_3 HOBBIES_1_135 11328\n", + "\n", + "[30490 rows x 3 columns]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "release_df = prices_df.groupby([\"store_id\", \"item_id\"])[\"wm_yr_wk\"].agg(\"min\").reset_index()\n", + "release_df.columns = [\"store_id\", \"item_id\", \"release_week\"]\n", + "release_df" + ] + }, + { + "cell_type": "markdown", + "id": "b987d3f4-7885-4f59-86e0-bb78a9eea106", + "metadata": {}, + "source": [ + "Now that we've computed the release week for each product, let's merge it back to `grid_df`:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "9b1baea3-8507-452d-ae2e-bc0024d3c926", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_week
0FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_13.011101
1FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_20.011101
2FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_30.011101
3FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_41.011101
4FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_54.011101
..............................
60034805HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1965NaN11101
60034806HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1966NaN11101
60034807HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1967NaN11101
60034808HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1968NaN11101
60034809HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1969NaN11101
\n", + "

60034810 rows × 9 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "0 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "1 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "2 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "3 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "4 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "60034805 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034806 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034807 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034808 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034809 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week \n", + "0 FOODS CA_1 CA d_1 3.0 11101 \n", + "1 FOODS CA_1 CA d_2 0.0 11101 \n", + "2 FOODS CA_1 CA d_3 0.0 11101 \n", + "3 FOODS CA_1 CA d_4 1.0 11101 \n", + "4 FOODS CA_1 CA d_5 4.0 11101 \n", + "... ... ... ... ... ... ... 
\n", + "60034805 HOUSEHOLD WI_3 WI d_1965 NaN 11101 \n", + "60034806 HOUSEHOLD WI_3 WI d_1966 NaN 11101 \n", + "60034807 HOUSEHOLD WI_3 WI d_1967 NaN 11101 \n", + "60034808 HOUSEHOLD WI_3 WI d_1968 NaN 11101 \n", + "60034809 HOUSEHOLD WI_3 WI d_1969 NaN 11101 \n", + "\n", + "[60034810 rows x 9 columns]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = grid_df.merge(release_df, on=[\"store_id\", \"item_id\"], how=\"left\")\n", + "grid_df = grid_df.sort_values(index_columns + [\"day_id\"]).reset_index(drop=True)\n", + "grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "9e2c3c9b-c753-4b49-876d-447e73a55c81", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "138" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "del release_df # No longer needed\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "b973c4df-67a6-414b-8c5d-883ab276bbd9", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 1.2GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "markdown", + "id": "01e39a12-a2b8-49ce-9a08-42fa08f2d1a8", + "metadata": {}, + "source": [ + "## Filter out entries with zero sales" + ] + }, + { + "cell_type": "markdown", + "id": "8733da5f-44bf-4b23-9645-0711bb68f6ef", + "metadata": {}, + "source": [ + "We can further save space by dropping rows from `grid_df` that correspond to zero sales. 
Since each product doesn't go on the shelf until its release week, its sales must be zero during any week that's prior to the release week.\n", + "\n", + "To make use of this insight, we bring in the `wm_yr_wk` column from `calendar_df`:" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "565fee38-25a2-4ed7-a160-5be11bd68df4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
0FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15371.01110111511
1FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15380.01110111511
2FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15392.01110111511
3FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15400.01110111511
4FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15410.01110111512
.................................
60034805HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_520.01110111108
60034806HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_530.01110111108
60034807HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_540.01110111108
60034808HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_550.01110111108
60034809HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_490.01110111107
\n", + "

60034810 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "0 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "1 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "2 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "3 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "4 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "60034805 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034806 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034807 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034808 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034809 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", + "0 FOODS CA_1 CA d_1537 1.0 11101 11511 \n", + "1 FOODS CA_1 CA d_1538 0.0 11101 11511 \n", + "2 FOODS CA_1 CA d_1539 2.0 11101 11511 \n", + "3 FOODS CA_1 CA d_1540 0.0 11101 11511 \n", + "4 FOODS CA_1 CA d_1541 0.0 11101 11512 \n", + "... ... ... ... ... ... ... ... 
\n", + "60034805 HOUSEHOLD WI_3 WI d_52 0.0 11101 11108 \n", + "60034806 HOUSEHOLD WI_3 WI d_53 0.0 11101 11108 \n", + "60034807 HOUSEHOLD WI_3 WI d_54 0.0 11101 11108 \n", + "60034808 HOUSEHOLD WI_3 WI d_55 0.0 11101 11108 \n", + "60034809 HOUSEHOLD WI_3 WI d_49 0.0 11101 11107 \n", + "\n", + "[60034810 rows x 10 columns]" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = grid_df.merge(calendar_df[[\"wm_yr_wk\", \"day_id\"]], on=[\"day_id\"], how=\"left\")\n", + "grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "4cdfb852-2d25-4b60-8470-71edb3c63d14", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 1.7GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "markdown", + "id": "cb5059b7-6a1f-464b-b027-ce3b389eb15e", + "metadata": {}, + "source": [ + "The `wm_yr_wk` column identifies the week that contains the day given by the `day_id` column. Now let's filter all rows in `grid_df` for which `wm_yr_wk` is less than `release_week`:" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "79b8cf35-776e-4295-8b5b-9e293c2ff325", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
9990FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_10.01110211101
9991FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_20.01110211101
9992FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_30.01110211101
9993FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_40.01110211101
9994FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_50.01110211101
.................................
60032955HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_200.01110611103
60032956HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_210.01110611103
60032957HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_220.01110611104
60032958HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_230.01110611104
60032959HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_240.01110611104
\n", + "

12299413 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "9990 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", + "9991 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", + "9992 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", + "9993 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", + "9994 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "60032955 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60032956 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60032957 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60032958 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60032959 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", + "9990 FOODS TX_3 TX d_1 0.0 11102 11101 \n", + "9991 FOODS TX_3 TX d_2 0.0 11102 11101 \n", + "9992 FOODS TX_3 TX d_3 0.0 11102 11101 \n", + "9993 FOODS TX_3 TX d_4 0.0 11102 11101 \n", + "9994 FOODS TX_3 TX d_5 0.0 11102 11101 \n", + "... ... ... ... ... ... ... ... \n", + "60032955 HOUSEHOLD WI_2 WI d_20 0.0 11106 11103 \n", + "60032956 HOUSEHOLD WI_2 WI d_21 0.0 11106 11103 \n", + "60032957 HOUSEHOLD WI_2 WI d_22 0.0 11106 11104 \n", + "60032958 HOUSEHOLD WI_2 WI d_23 0.0 11106 11104 \n", + "60032959 HOUSEHOLD WI_2 WI d_24 0.0 11106 11104 \n", + "\n", + "[12299413 rows x 10 columns]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = grid_df[grid_df[\"wm_yr_wk\"] < grid_df[\"release_week\"]]\n", + "df" + ] + }, + { + "cell_type": "markdown", + "id": "3056f24d-8bef-457e-aee0-82b43459ba92", + "metadata": {}, + "source": [ + "As we suspected, the sales amount is zero during weeks that come before the release week." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "e82c98f0-1c6d-49ae-abd4-e37c3a07bc16", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "assert (df[\"sales\"] == 0).all()" + ] + }, + { + "cell_type": "markdown", + "id": "49d6e163-e634-43c6-a021-55d5604944ff", + "metadata": {}, + "source": [ + "For the purpose of our data analysis, we can safely drop the rows with zero sales:" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "dd694a7a-86c0-44c9-b29c-8387f047500c", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
0FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15371.01110111511
1FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15380.01110111511
2FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15392.01110111511
3FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15400.01110111511
4FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15410.01110111512
.................................
47735392HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_520.01110111108
47735393HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_530.01110111108
47735394HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_540.01110111108
47735395HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_550.01110111108
47735396HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_490.01110111107
\n", + "

47735397 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "0 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "1 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "2 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "3 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "4 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735393 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735394 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", + "0 FOODS CA_1 CA d_1537 1.0 11101 11511 \n", + "1 FOODS CA_1 CA d_1538 0.0 11101 11511 \n", + "2 FOODS CA_1 CA d_1539 2.0 11101 11511 \n", + "3 FOODS CA_1 CA d_1540 0.0 11101 11511 \n", + "4 FOODS CA_1 CA d_1541 0.0 11101 11512 \n", + "... ... ... ... ... ... ... ... 
\n", + "47735392 HOUSEHOLD WI_3 WI d_52 0.0 11101 11108 \n", + "47735393 HOUSEHOLD WI_3 WI d_53 0.0 11101 11108 \n", + "47735394 HOUSEHOLD WI_3 WI d_54 0.0 11101 11108 \n", + "47735395 HOUSEHOLD WI_3 WI d_55 0.0 11101 11108 \n", + "47735396 HOUSEHOLD WI_3 WI d_49 0.0 11101 11107 \n", + "\n", + "[47735397 rows x 10 columns]" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = grid_df[grid_df[\"wm_yr_wk\"] >= grid_df[\"release_week\"]].reset_index(drop=True)\n", + "grid_df[\"wm_yr_wk\"] = grid_df[\"wm_yr_wk\"].astype(np.int32) # Convert wm_yr_wk column to int32, to conserve memory\n", + "grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "b46d629b-9996-443f-95d5-9da888b00fa0", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 1.2GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "markdown", + "id": "d323fd3a-21e4-4df6-8d32-78948acbf6fd", + "metadata": {}, + "source": [ + "## Assign weights for product items" + ] + }, + { + "cell_type": "markdown", + "id": "4b55b355-ffc3-429d-98bd-30c356f70c06", + "metadata": {}, + "source": [ + "When we assess the accuracy of our machine learning model, we should assign a weight for each product item, to indicate the relative importance of the item. For the M5 competition, the weights are computed from the total sales amount (in US dollars) in the lastest 28 days." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "0b216249-dd2d-40e5-bbb3-6a41773c7b98", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sales_usd_sum
item_id
FOODS_1_0013516.80
FOODS_1_00212418.80
FOODS_1_0035943.20
FOODS_1_00454184.82
FOODS_1_00517877.00
......
HOUSEHOLD_2_5126034.40
HOUSEHOLD_2_5132668.80
HOUSEHOLD_2_5149574.60
HOUSEHOLD_2_515630.40
HOUSEHOLD_2_5162574.00
\n", + "

3049 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " sales_usd_sum\n", + "item_id \n", + "FOODS_1_001 3516.80\n", + "FOODS_1_002 12418.80\n", + "FOODS_1_003 5943.20\n", + "FOODS_1_004 54184.82\n", + "FOODS_1_005 17877.00\n", + "... ...\n", + "HOUSEHOLD_2_512 6034.40\n", + "HOUSEHOLD_2_513 2668.80\n", + "HOUSEHOLD_2_514 9574.60\n", + "HOUSEHOLD_2_515 630.40\n", + "HOUSEHOLD_2_516 2574.00\n", + "\n", + "[3049 rows x 1 columns]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Convert day_id to integers\n", + "grid_df[\"day_id_int\"] = grid_df[\"day_id\"].to_pandas().apply(lambda x: x[2:]).astype(int)\n", + "\n", + "# Compute the total sales over the latest 28 days, per product item\n", + "last28 = grid_df[(grid_df[\"day_id_int\"] >= 1914) & (grid_df[\"day_id_int\"] < 1942)]\n", + "last28 = last28[[\"item_id\", \"wm_yr_wk\", \"sales\"]].merge(prices_df[[\"item_id\", \"wm_yr_wk\", \"sell_price\"]], on=[\"item_id\", \"wm_yr_wk\"])\n", + "last28[\"sales_usd\"] = last28[\"sales\"] * last28[\"sell_price\"]\n", + "total_sales_usd = last28.groupby(\"item_id\")[[\"sales_usd\"]].agg([\"sum\"]).sort_index()\n", + "total_sales_usd.columns = total_sales_usd.columns.map(\"_\".join)\n", + "total_sales_usd" + ] + }, + { + "cell_type": "markdown", + "id": "ae0934ba-37de-4391-bdb4-7c10685fec6c", + "metadata": {}, + "source": [ + "To obtain weights, we normalize the sales amount for one item by the total sales for all items." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "bfab66af-8b70-41fb-bc99-b43567b3f87a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
weights
item_id
FOODS_1_0010.000090
FOODS_1_0020.000318
FOODS_1_0030.000152
FOODS_1_0040.001389
FOODS_1_0050.000458
......
HOUSEHOLD_2_5120.000155
HOUSEHOLD_2_5130.000068
HOUSEHOLD_2_5140.000245
HOUSEHOLD_2_5150.000016
HOUSEHOLD_2_5160.000066
\n", + "

3049 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " weights\n", + "item_id \n", + "FOODS_1_001 0.000090\n", + "FOODS_1_002 0.000318\n", + "FOODS_1_003 0.000152\n", + "FOODS_1_004 0.001389\n", + "FOODS_1_005 0.000458\n", + "... ...\n", + "HOUSEHOLD_2_512 0.000155\n", + "HOUSEHOLD_2_513 0.000068\n", + "HOUSEHOLD_2_514 0.000245\n", + "HOUSEHOLD_2_515 0.000016\n", + "HOUSEHOLD_2_516 0.000066\n", + "\n", + "[3049 rows x 1 columns]" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "weights = total_sales_usd / total_sales_usd.sum()\n", + "weights = weights.rename(columns={\"sales_usd_sum\": \"weights\"})\n", + "weights" + ] + }, + { + "cell_type": "markdown", + "id": "2e717a7d-9022-4c44-9d09-2916fa45afb6", + "metadata": {}, + "source": [ + "## Persist the processed data to disk" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "cb06bbd9-1e00-44cb-b197-2098b7c67069", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# No longer needed\n", + "del grid_df[\"day_id_int\"]\n", + "\n", + "# Persist grid_df to disk\n", + "grid_df.to_pandas().to_pickle(processed_data_dir + \"grid_df_part1.pkl\")\n", + "weights.to_pandas().to_pickle(processed_data_dir + \"product_weights.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "126198cf-e07f-46e6-bc3b-cc6170a57c08", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb 
b/source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb new file mode 100644 index 00000000..2142a6bf --- /dev/null +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb @@ -0,0 +1,1259 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a17c37ef-5035-4540-8066-bb9741496639", + "metadata": {}, + "source": [ + "# Data preprocessing, Part 2" + ] + }, + { + "cell_type": "markdown", + "id": "879d5849-97b4-4a05-9f26-7035a2ce220c", + "metadata": {}, + "source": [ + "## Import modules" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "96a23fd4-d1e3-4487-ae73-38d456fa2408", + "metadata": {}, + "outputs": [], + "source": [ + "import cudf\n", + "import numpy as np\n", + "import pandas as pd\n", + "import gc" + ] + }, + { + "cell_type": "markdown", + "id": "ed21ca9c-5a4e-4788-a2a2-2db5b9f5cab8", + "metadata": {}, + "source": [ + "## Load data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "239019da-02cb-4e53-bfe1-ca0bfef1f201", + "metadata": {}, + "outputs": [], + "source": [ + "raw_data_dir = \"./data/\"\n", + "processed_data_dir = \"./processed_data/\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "eb0d40de-c3f3-40fe-9651-63c0cdf5b077", + "metadata": {}, + "outputs": [], + "source": [ + "prices_df = cudf.read_csv(raw_data_dir + \"sell_prices.csv\")\n", + "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(columns={\"d\": \"day_id\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5946cf28-0188-4c4d-b98f-6bc855e8a92d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idwm_yr_wksell_price
0CA_1HOBBIES_1_001113259.58
1CA_1HOBBIES_1_001113269.58
2CA_1HOBBIES_1_001113278.26
3CA_1HOBBIES_1_001113288.26
4CA_1HOBBIES_1_001113298.26
...............
6841116WI_3FOODS_3_827116171.00
6841117WI_3FOODS_3_827116181.00
6841118WI_3FOODS_3_827116191.00
6841119WI_3FOODS_3_827116201.00
6841120WI_3FOODS_3_827116211.00
\n", + "

6841121 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id wm_yr_wk sell_price\n", + "0 CA_1 HOBBIES_1_001 11325 9.58\n", + "1 CA_1 HOBBIES_1_001 11326 9.58\n", + "2 CA_1 HOBBIES_1_001 11327 8.26\n", + "3 CA_1 HOBBIES_1_001 11328 8.26\n", + "4 CA_1 HOBBIES_1_001 11329 8.26\n", + "... ... ... ... ...\n", + "6841116 WI_3 FOODS_3_827 11617 1.00\n", + "6841117 WI_3 FOODS_3_827 11618 1.00\n", + "6841118 WI_3 FOODS_3_827 11619 1.00\n", + "6841119 WI_3 FOODS_3_827 11620 1.00\n", + "6841120 WI_3 FOODS_3_827 11621 1.00\n", + "\n", + "[6841121 rows x 4 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d6d8e3e5-426e-413c-8461-b01dbead0de0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
datewm_yr_wkweekdaywdaymonthyearday_idevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WI
02011-01-2911101Saturday112011d_1<NA><NA><NA><NA>000
12011-01-3011101Sunday212011d_2<NA><NA><NA><NA>000
22011-01-3111101Monday312011d_3<NA><NA><NA><NA>000
32011-02-0111101Tuesday422011d_4<NA><NA><NA><NA>110
42011-02-0211101Wednesday522011d_5<NA><NA><NA><NA>101
.............................................
19642016-06-1511620Wednesday562016d_1965<NA><NA><NA><NA>011
19652016-06-1611620Thursday662016d_1966<NA><NA><NA><NA>000
19662016-06-1711620Friday762016d_1967<NA><NA><NA><NA>000
19672016-06-1811621Saturday162016d_1968<NA><NA><NA><NA>000
19682016-06-1911621Sunday262016d_1969NBAFinalsEndSportingFather's dayCultural000
\n", + "

1969 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " date wm_yr_wk weekday wday month year day_id \\\n", + "0 2011-01-29 11101 Saturday 1 1 2011 d_1 \n", + "1 2011-01-30 11101 Sunday 2 1 2011 d_2 \n", + "2 2011-01-31 11101 Monday 3 1 2011 d_3 \n", + "3 2011-02-01 11101 Tuesday 4 2 2011 d_4 \n", + "4 2011-02-02 11101 Wednesday 5 2 2011 d_5 \n", + "... ... ... ... ... ... ... ... \n", + "1964 2016-06-15 11620 Wednesday 5 6 2016 d_1965 \n", + "1965 2016-06-16 11620 Thursday 6 6 2016 d_1966 \n", + "1966 2016-06-17 11620 Friday 7 6 2016 d_1967 \n", + "1967 2016-06-18 11621 Saturday 1 6 2016 d_1968 \n", + "1968 2016-06-19 11621 Sunday 2 6 2016 d_1969 \n", + "\n", + " event_name_1 event_type_1 event_name_2 event_type_2 snap_CA snap_TX \\\n", + "0 0 0 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 1 1 \n", + "4 1 0 \n", + "... ... ... ... ... ... ... \n", + "1964 0 1 \n", + "1965 0 0 \n", + "1966 0 0 \n", + "1967 0 0 \n", + "1968 NBAFinalsEnd Sporting Father's day Cultural 0 0 \n", + "\n", + " snap_WI \n", + "0 0 \n", + "1 0 \n", + "2 0 \n", + "3 0 \n", + "4 1 \n", + "... ... \n", + "1964 1 \n", + "1965 0 \n", + "1966 0 \n", + "1967 0 \n", + "1968 0 \n", + "\n", + "[1969 rows x 14 columns]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "calendar_df" + ] + }, + { + "cell_type": "markdown", + "id": "8fa43d3f-aa95-4d98-9265-0d56f0f39eba", + "metadata": {}, + "source": [ + "## Generate price-related features" + ] + }, + { + "cell_type": "markdown", + "id": "ed99b0f4-ab4a-485f-a75e-211cb6bacd00", + "metadata": {}, + "source": [ + "Let us engineer additional features that are related to the sale price. We consider the distribution of the price of a given product over time and ask how the current price compares to the historical trend." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "380f03c7-c8e7-42a5-ac43-f4aec0883880", + "metadata": {}, + "outputs": [], + "source": [ + "# Highest price over all weeks\n", + "prices_df[\"price_max\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].transform(\"max\")\n", + "# Lowest price over all weeks\n", + "prices_df[\"price_min\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].transform(\"min\")\n", + "# Standard deviation of the price\n", + "prices_df[\"price_std\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].transform(\"std\")\n", + "# Mean (average) price over all weeks\n", + "prices_df[\"price_mean\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].transform(\"mean\")" + ] + }, + { + "cell_type": "markdown", + "id": "03eb645e-3a5e-41e5-866c-b9904cdd3953", + "metadata": {}, + "source": [ + "We also consider the ratio of the current price to the max price." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "743d269e-14ea-448d-9dc3-9214c5a42ea7", + "metadata": {}, + "outputs": [], + "source": [ + "prices_df[\"price_norm\"] = prices_df[\"sell_price\"] / prices_df[\"price_max\"]" + ] + }, + { + "cell_type": "markdown", + "id": "4c256268-3c1b-4904-926c-b0b5bf13e98a", + "metadata": {}, + "source": [ + "Some items have a very stable price, whereas other items respond to inflation quickly and rise in price. To capture the price elasticity, we count the number of unique price values for a given product over time." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6668ab89-155a-4eb0-8f5c-e2f7895afdf9", + "metadata": {}, + "outputs": [], + "source": [ + "prices_df[\"price_nunique\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].transform(\"nunique\")" + ] + }, + { + "cell_type": "markdown", + "id": "4da0b2ad-3b36-4bc1-a469-08ae9c8abb1b", + "metadata": {}, + "source": [ + "We also consider, for a given price, how many other items are being sold at the exact same price." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7b8ef0a3-7736-4f95-8c8a-624f3f7a4faa", + "metadata": {}, + "outputs": [], + "source": [ + "prices_df[\"item_nunique\"] = prices_df.groupby([\"store_id\", \"sell_price\"])[\"item_id\"].transform(\"nunique\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "d0bebd07-4238-410d-9f16-f4cbeef90819", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idwm_yr_wksell_priceprice_maxprice_minprice_stdprice_meanprice_normprice_nuniqueitem_nunique
0CA_1HOBBIES_1_001113259.589.588.260.1521398.2857141.00000033
1CA_1HOBBIES_1_001113269.589.588.260.1521398.2857141.00000033
2CA_1HOBBIES_1_001113278.269.588.260.1521398.2857140.86221335
3CA_1HOBBIES_1_001113288.269.588.260.1521398.2857140.86221335
4CA_1HOBBIES_1_001113298.269.588.260.1521398.2857140.86221335
....................................
6841116WI_3FOODS_3_827116171.001.001.000.0000001.0000001.0000001142
6841117WI_3FOODS_3_827116181.001.001.000.0000001.0000001.0000001142
6841118WI_3FOODS_3_827116191.001.001.000.0000001.0000001.0000001142
6841119WI_3FOODS_3_827116201.001.001.000.0000001.0000001.0000001142
6841120WI_3FOODS_3_827116211.001.001.000.0000001.0000001.0000001142
\n", + "

6841121 rows × 11 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id wm_yr_wk sell_price price_max price_min \\\n", + "0 CA_1 HOBBIES_1_001 11325 9.58 9.58 8.26 \n", + "1 CA_1 HOBBIES_1_001 11326 9.58 9.58 8.26 \n", + "2 CA_1 HOBBIES_1_001 11327 8.26 9.58 8.26 \n", + "3 CA_1 HOBBIES_1_001 11328 8.26 9.58 8.26 \n", + "4 CA_1 HOBBIES_1_001 11329 8.26 9.58 8.26 \n", + "... ... ... ... ... ... ... \n", + "6841116 WI_3 FOODS_3_827 11617 1.00 1.00 1.00 \n", + "6841117 WI_3 FOODS_3_827 11618 1.00 1.00 1.00 \n", + "6841118 WI_3 FOODS_3_827 11619 1.00 1.00 1.00 \n", + "6841119 WI_3 FOODS_3_827 11620 1.00 1.00 1.00 \n", + "6841120 WI_3 FOODS_3_827 11621 1.00 1.00 1.00 \n", + "\n", + " price_std price_mean price_norm price_nunique item_nunique \n", + "0 0.152139 8.285714 1.000000 3 3 \n", + "1 0.152139 8.285714 1.000000 3 3 \n", + "2 0.152139 8.285714 0.862213 3 5 \n", + "3 0.152139 8.285714 0.862213 3 5 \n", + "4 0.152139 8.285714 0.862213 3 5 \n", + "... ... ... ... ... ... \n", + "6841116 0.000000 1.000000 1.000000 1 142 \n", + "6841117 0.000000 1.000000 1.000000 1 142 \n", + "6841118 0.000000 1.000000 1.000000 1 142 \n", + "6841119 0.000000 1.000000 1.000000 1 142 \n", + "6841120 0.000000 1.000000 1.000000 1 142 \n", + "\n", + "[6841121 rows x 11 columns]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df" + ] + }, + { + "cell_type": "markdown", + "id": "98d24d6e-91b2-43a8-8051-2898ce047254", + "metadata": {}, + "source": [ + "Another useful way to put prices in context is to compare the price of a product to its historical price a week ago, a month ago, or an year ago." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7d95506a-aeaa-445b-8202-d158ce22c182", + "metadata": {}, + "outputs": [], + "source": [ + "# Add \"month\" and \"year\" columns to prices_df\n", + "week_to_month_map = calendar_df[[\"wm_yr_wk\", \"month\", \"year\"]].drop_duplicates(subset=[\"wm_yr_wk\"])\n", + "prices_df = prices_df.merge(week_to_month_map, on=[\"wm_yr_wk\"], how=\"left\")\n", + "\n", + "# Sort by wm_yr_wk. The rows will also be sorted in ascending months and years.\n", + "prices_df = prices_df.sort_values([\"store_id\", \"item_id\", \"wm_yr_wk\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "646edece-e9be-4f35-a252-9710383e5a5b", + "metadata": {}, + "outputs": [], + "source": [ + "# Compare with the price in the previous week\n", + "prices_df[\"price_momentum\"] = prices_df[\"sell_price\"] / prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].shift(1)\n", + "# Compare with the average price in the same month\n", + "prices_df[\"price_momentum_m\"] = prices_df[\"sell_price\"] / prices_df.groupby([\"store_id\", \"item_id\", \"month\"])[\"sell_price\"].transform(\"mean\")\n", + "# Compare with the average price in the same year\n", + "prices_df[\"price_momentum_y\"] = prices_df[\"sell_price\"] / prices_df.groupby([\"store_id\", \"item_id\", \"year\"])[\"sell_price\"].transform(\"mean\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "6205fac5-4bf9-4c9b-9eae-05341cda3834", + "metadata": {}, + "outputs": [], + "source": [ + "# Remove \"month\" and \"year\" columns, as we don't need them any more\n", + "del prices_df[\"month\"], prices_df[\"year\"]\n", + "\n", + "# Convert float64 columns into float32 type to save memory\n", + "columns = [\n", + " \"sell_price\", \"price_max\", \"price_min\", \"price_std\", \"price_mean\",\n", + " \"price_norm\", \"price_momentum\", \"price_momentum_m\", \"price_momentum_y\"\n", + "]\n", + "for col in columns:\n", + "    
prices_df[col] = prices_df[col].astype(np.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "dc7e75b8-33d0-42bf-b14a-d398012c2ce6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "store_id             object\n", + "item_id              object\n", + "wm_yr_wk              int64\n", + "sell_price          float32\n", + "price_max           float32\n", + "price_min           float32\n", + "price_std           float32\n", + "price_mean          float32\n", + "price_norm          float32\n", + "price_nunique         int32\n", + "item_nunique          int32\n", + "price_momentum      float32\n", + "price_momentum_m    float32\n", + "price_momentum_y    float32\n", + "dtype: object" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df.dtypes" + ] + }, + { + "cell_type": "markdown", + "id": "dab2dc14-6273-4598-98ca-734459b66aa5", + "metadata": {}, + "source": [ + "## Bring in price-related features into `grid_df`" + ] + }, + { + "cell_type": "markdown", + "id": "cf3bb251-871c-452a-b24b-c33d56559160", + "metadata": {}, + "source": [ + "We load `grid_df` from the Part 1 notebook and bring in columns from `prices_df`." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "0c5564da-8608-4640-805f-d14960dbe760", + "metadata": {}, + "outputs": [], + "source": [ + "grid_df = cudf.from_pandas(pd.read_pickle(processed_data_dir + \"grid_df_part1.pkl\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a7901ee0-1a60-42cc-a983-fe6838b7a612", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idsell_priceprice_maxprice_minprice_stdprice_meanprice_normprice_nuniqueitem_nuniqueprice_momentumprice_momentum_mprice_momentum_y
0FOODS_1_001_CA_2_evaluationd_10402.242.242.001.095719e-012.1693621.02611.01.0198681.0
1FOODS_1_001_CA_2_evaluationd_10412.242.242.001.095719e-012.1693621.02611.01.0198681.0
2FOODS_1_001_CA_2_evaluationd_10422.242.242.001.095719e-012.1693621.02611.01.0198681.0
3FOODS_1_001_CA_2_evaluationd_10432.242.242.001.095719e-012.1693621.02611.01.0198681.0
4FOODS_1_001_CA_2_evaluationd_10442.242.242.001.095719e-012.1693621.02611.01.0249581.0
..........................................
47735392HOUSEHOLD_2_516_WI_2_evaluationd_8845.945.945.943.648122e-145.9400001.01471.01.0000001.0
47735393HOUSEHOLD_2_516_WI_2_evaluationd_8855.945.945.943.648122e-145.9400001.01471.01.0000001.0
47735394HOUSEHOLD_2_516_WI_2_evaluationd_8865.945.945.943.648122e-145.9400001.01471.01.0000001.0
47735395HOUSEHOLD_2_516_WI_2_evaluationd_8875.945.945.943.648122e-145.9400001.01471.01.0000001.0
47735396HOUSEHOLD_2_516_WI_2_evaluationd_8885.945.945.943.648122e-145.9400001.01471.01.0000001.0
\n", + "

47735397 rows × 13 columns

\n", + "
" + ], + "text/plain": [ + " id day_id sell_price price_max \\\n", + "0 FOODS_1_001_CA_2_evaluation d_1040 2.24 2.24 \n", + "1 FOODS_1_001_CA_2_evaluation d_1041 2.24 2.24 \n", + "2 FOODS_1_001_CA_2_evaluation d_1042 2.24 2.24 \n", + "3 FOODS_1_001_CA_2_evaluation d_1043 2.24 2.24 \n", + "4 FOODS_1_001_CA_2_evaluation d_1044 2.24 2.24 \n", + "... ... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_2_evaluation d_884 5.94 5.94 \n", + "47735393 HOUSEHOLD_2_516_WI_2_evaluation d_885 5.94 5.94 \n", + "47735394 HOUSEHOLD_2_516_WI_2_evaluation d_886 5.94 5.94 \n", + "47735395 HOUSEHOLD_2_516_WI_2_evaluation d_887 5.94 5.94 \n", + "47735396 HOUSEHOLD_2_516_WI_2_evaluation d_888 5.94 5.94 \n", + "\n", + " price_min price_std price_mean price_norm price_nunique \\\n", + "0 2.00 1.095719e-01 2.169362 1.0 2 \n", + "1 2.00 1.095719e-01 2.169362 1.0 2 \n", + "2 2.00 1.095719e-01 2.169362 1.0 2 \n", + "3 2.00 1.095719e-01 2.169362 1.0 2 \n", + "4 2.00 1.095719e-01 2.169362 1.0 2 \n", + "... ... ... ... ... ... \n", + "47735392 5.94 3.648122e-14 5.940000 1.0 1 \n", + "47735393 5.94 3.648122e-14 5.940000 1.0 1 \n", + "47735394 5.94 3.648122e-14 5.940000 1.0 1 \n", + "47735395 5.94 3.648122e-14 5.940000 1.0 1 \n", + "47735396 5.94 3.648122e-14 5.940000 1.0 1 \n", + "\n", + " item_nunique price_momentum price_momentum_m price_momentum_y \n", + "0 61 1.0 1.019868 1.0 \n", + "1 61 1.0 1.019868 1.0 \n", + "2 61 1.0 1.019868 1.0 \n", + "3 61 1.0 1.019868 1.0 \n", + "4 61 1.0 1.024958 1.0 \n", + "... ... ... ... ... 
\n", + "47735392 47 1.0 1.000000 1.0 \n", + "47735393 47 1.0 1.000000 1.0 \n", + "47735394 47 1.0 1.000000 1.0 \n", + "47735395 47 1.0 1.000000 1.0 \n", + "47735396 47 1.0 1.000000 1.0 \n", + "\n", + "[47735397 rows x 13 columns]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# After merging price_df, keep columns id and day_id from grid_df and drop all other columns from grid_df\n", + "original_columns = list(grid_df)\n", + "grid_df = grid_df.merge(prices_df, on=[\"store_id\", \"item_id\", \"wm_yr_wk\"], how=\"left\")\n", + "columns_to_keep = [\"id\", \"day_id\"] + [col for col in list(grid_df) if col not in original_columns]\n", + "grid_df = grid_df[[\"id\", \"day_id\"] + columns_to_keep]\n", + "grid_df" + ] + }, + { + "cell_type": "markdown", + "id": "26dca23e-ff21-4f37-b93a-8c4edfdafe29", + "metadata": {}, + "source": [ + "We persist the combined table to disk." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "ee4b6223-c10e-409d-9d27-987e7c626b4c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "grid_df.to_pandas().to_pickle(processed_data_dir + \"grid_df_part2.pkl\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb new file mode 100644 index 00000000..72d168b4 --- /dev/null +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb @@ -0,0 +1,839 @@ +{ + "cells": [ + { + "cell_type": 
"markdown", + "id": "aaaa81be-da18-4108-aca2-1dc8e28ac34b", + "metadata": {}, + "source": [ + "# Data preprocessing, Part 3" + ] + }, + { + "cell_type": "markdown", + "id": "52de27ff-2ff7-4da8-a836-96cb109940ae", + "metadata": {}, + "source": [ + "## Import modules" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "29dd3d72-b823-4eaa-afb5-1e06c7311278", + "metadata": {}, + "outputs": [], + "source": [ + "import cudf\n", + "import numpy as np\n", + "import pandas as pd\n", + "import cupy as cp\n", + "import gc" + ] + }, + { + "cell_type": "markdown", + "id": "4b598eb9-d289-48f5-8e4d-30b84e7000d9", + "metadata": {}, + "source": [ + "## Load data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d9f842e3-51b8-4a06-9ad8-2c5832682c97", + "metadata": {}, + "outputs": [], + "source": [ + "raw_data_dir = \"./data/\"\n", + "processed_data_dir = \"./processed_data/\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "97f8792c-d41a-4e1d-ac9f-7a498937ead6", + "metadata": {}, + "outputs": [], + "source": [ + "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(columns={\"d\": \"day_id\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8a519b84-528d-45eb-97e4-08cf53775e74", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_id
0FOODS_1_001_CA_1_evaluationd_1537
1FOODS_1_001_CA_1_evaluationd_1538
2FOODS_1_001_CA_1_evaluationd_1539
3FOODS_1_001_CA_1_evaluationd_1540
4FOODS_1_001_CA_1_evaluationd_1541
.........
47735392HOUSEHOLD_2_516_WI_3_evaluationd_52
47735393HOUSEHOLD_2_516_WI_3_evaluationd_53
47735394HOUSEHOLD_2_516_WI_3_evaluationd_54
47735395HOUSEHOLD_2_516_WI_3_evaluationd_55
47735396HOUSEHOLD_2_516_WI_3_evaluationd_49
\n", + "

47735397 rows × 2 columns

\n", + "
" + ], + "text/plain": [ + " id day_id\n", + "0 FOODS_1_001_CA_1_evaluation d_1537\n", + "1 FOODS_1_001_CA_1_evaluation d_1538\n", + "2 FOODS_1_001_CA_1_evaluation d_1539\n", + "3 FOODS_1_001_CA_1_evaluation d_1540\n", + "4 FOODS_1_001_CA_1_evaluation d_1541\n", + "... ... ...\n", + "47735392 HOUSEHOLD_2_516_WI_3_evaluation d_52\n", + "47735393 HOUSEHOLD_2_516_WI_3_evaluation d_53\n", + "47735394 HOUSEHOLD_2_516_WI_3_evaluation d_54\n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_55\n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_49\n", + "\n", + "[47735397 rows x 2 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = cudf.from_pandas(pd.read_pickle(processed_data_dir + \"grid_df_part1.pkl\"))\n", + "grid_df = grid_df[[\"id\", \"day_id\"]]\n", + "grid_df" + ] + }, + { + "cell_type": "markdown", + "id": "9fbca2d4-fea9-4bfe-88f9-3c6f0fd8dc50", + "metadata": {}, + "source": [ + "## Generate date-related features" + ] + }, + { + "cell_type": "markdown", + "id": "dff9fb9f-4f72-445a-8e49-0136c98481d1", + "metadata": {}, + "source": [ + "We first identify the date in each row of `grid_df` using information from `calendar_df`." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ed28f3b3-59ef-4bc0-9e39-6639e6d3c7c7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_iddateevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WI
0FOODS_1_001_TX_3_evaluationd_15222015-03-30<NA><NA><NA><NA>000
1FOODS_1_001_TX_3_evaluationd_15232015-03-31<NA><NA><NA><NA>000
2FOODS_1_001_TX_3_evaluationd_15242015-04-01<NA><NA><NA><NA>110
3FOODS_1_001_TX_3_evaluationd_15252015-04-02<NA><NA><NA><NA>101
4FOODS_1_001_TX_3_evaluationd_15262015-04-03<NA><NA><NA><NA>111
.................................
47735392HOUSEHOLD_2_516_WI_3_evaluationd_522011-03-21<NA><NA><NA><NA>000
47735393HOUSEHOLD_2_516_WI_3_evaluationd_532011-03-22<NA><NA><NA><NA>000
47735394HOUSEHOLD_2_516_WI_3_evaluationd_542011-03-23<NA><NA><NA><NA>000
47735395HOUSEHOLD_2_516_WI_3_evaluationd_552011-03-24<NA><NA><NA><NA>000
47735396HOUSEHOLD_2_516_WI_3_evaluationd_492011-03-18<NA><NA><NA><NA>000
\n", + "

47735397 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id day_id date event_name_1 \\\n", + "0 FOODS_1_001_TX_3_evaluation d_1522 2015-03-30 \n", + "1 FOODS_1_001_TX_3_evaluation d_1523 2015-03-31 \n", + "2 FOODS_1_001_TX_3_evaluation d_1524 2015-04-01 \n", + "3 FOODS_1_001_TX_3_evaluation d_1525 2015-04-02 \n", + "4 FOODS_1_001_TX_3_evaluation d_1526 2015-04-03 \n", + "... ... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_3_evaluation d_52 2011-03-21 \n", + "47735393 HOUSEHOLD_2_516_WI_3_evaluation d_53 2011-03-22 \n", + "47735394 HOUSEHOLD_2_516_WI_3_evaluation d_54 2011-03-23 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_55 2011-03-24 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_49 2011-03-18 \n", + "\n", + " event_type_1 event_name_2 event_type_2 snap_CA snap_TX snap_WI \n", + "0 0 0 0 \n", + "1 0 0 0 \n", + "2 1 1 0 \n", + "3 1 0 1 \n", + "4 1 1 1 \n", + "... ... ... ... ... ... ... \n", + "47735392 0 0 0 \n", + "47735393 0 0 0 \n", + "47735394 0 0 0 \n", + "47735395 0 0 0 \n", + "47735396 0 0 0 \n", + "\n", + "[47735397 rows x 10 columns]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Bring in the following columns from calendar_df into grid_df\n", + "icols = [\n", + " \"date\",\n", + " \"day_id\",\n", + " \"event_name_1\",\n", + " \"event_type_1\",\n", + " \"event_name_2\",\n", + " \"event_type_2\",\n", + " \"snap_CA\",\n", + " \"snap_TX\",\n", + " \"snap_WI\",\n", + "]\n", + "grid_df = grid_df.merge(calendar_df[icols], on=[\"day_id\"], how=\"left\")\n", + "grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "dac7b398-05a9-48cc-bcb1-f53987c4d872", + "metadata": {}, + "outputs": [], + "source": [ + "# Convert columns into categorical type to save memory\n", + "for col in [\"event_name_1\", \"event_type_1\", \"event_name_2\", \"event_type_2\",\n", + " \"snap_CA\", \"snap_TX\", \"snap_WI\"]:\n", + " grid_df[col] = grid_df[col].astype(\"category\")\n", + "# Convert \"date\" column 
into timestamp type\n", + "grid_df[\"date\"] = cudf.to_datetime(grid_df[\"date\"])" + ] + }, + { + "cell_type": "markdown", + "id": "9963b630-c88a-4cb9-a04e-7b0594fc0cca", + "metadata": {}, + "source": [ + "Using the `date` column, we can generate related features, such as day, week, or month." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "185fa77b-6cd9-47aa-a30c-1fe7e704fbb3", + "metadata": {}, + "outputs": [], + "source": [ + "grid_df[\"tm_d\"] = grid_df[\"date\"].dt.day.astype(np.int8)\n", + "grid_df[\"tm_w\"] = grid_df[\"date\"].dt.isocalendar().week.astype(np.int8)\n", + "grid_df[\"tm_m\"] = grid_df[\"date\"].dt.month.astype(np.int8)\n", + "grid_df[\"tm_y\"] = grid_df[\"date\"].dt.year\n", + "grid_df[\"tm_y\"] = (grid_df[\"tm_y\"] - grid_df[\"tm_y\"].min()).astype(np.int8)\n", + "grid_df[\"tm_wm\"] = cp.ceil(grid_df[\"tm_d\"].to_cupy() / 7).astype(np.int8) # which week in the month?\n", + "grid_df[\"tm_dw\"] = grid_df[\"date\"].dt.dayofweek.astype(np.int8) # which day in the week?\n", + "grid_df[\"tm_w_end\"] = (grid_df[\"tm_dw\"] >= 5).astype(np.int8) # whether today is in the weekend\n", + "del grid_df[\"date\"] # no longer needed" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "6674ebe5-ef74-4e81-9ec1-41446707b3e3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WItm_dtm_wtm_mtm_ytm_wmtm_dwtm_w_end
0FOODS_1_001_TX_3_evaluationd_1522<NA><NA><NA><NA>000301434500
1FOODS_1_001_TX_3_evaluationd_1523<NA><NA><NA><NA>000311434510
2FOODS_1_001_TX_3_evaluationd_1524<NA><NA><NA><NA>11011444120
3FOODS_1_001_TX_3_evaluationd_1525<NA><NA><NA><NA>10121444130
4FOODS_1_001_TX_3_evaluationd_1526<NA><NA><NA><NA>11131444140
...................................................
47735392HOUSEHOLD_2_516_WI_3_evaluationd_52<NA><NA><NA><NA>000211230300
47735393HOUSEHOLD_2_516_WI_3_evaluationd_53<NA><NA><NA><NA>000221230410
47735394HOUSEHOLD_2_516_WI_3_evaluationd_54<NA><NA><NA><NA>000231230420
47735395HOUSEHOLD_2_516_WI_3_evaluationd_55<NA><NA><NA><NA>000241230430
47735396HOUSEHOLD_2_516_WI_3_evaluationd_49<NA><NA><NA><NA>000181130340
\n", + "

47735397 rows × 16 columns

\n", + "
" + ], + "text/plain": [ + " id day_id event_name_1 event_type_1 \\\n", + "0 FOODS_1_001_TX_3_evaluation d_1522 \n", + "1 FOODS_1_001_TX_3_evaluation d_1523 \n", + "2 FOODS_1_001_TX_3_evaluation d_1524 \n", + "3 FOODS_1_001_TX_3_evaluation d_1525 \n", + "4 FOODS_1_001_TX_3_evaluation d_1526 \n", + "... ... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_3_evaluation d_52 \n", + "47735393 HOUSEHOLD_2_516_WI_3_evaluation d_53 \n", + "47735394 HOUSEHOLD_2_516_WI_3_evaluation d_54 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_55 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_49 \n", + "\n", + " event_name_2 event_type_2 snap_CA snap_TX snap_WI tm_d tm_w tm_m \\\n", + "0 0 0 0 30 14 3 \n", + "1 0 0 0 31 14 3 \n", + "2 1 1 0 1 14 4 \n", + "3 1 0 1 2 14 4 \n", + "4 1 1 1 3 14 4 \n", + "... ... ... ... ... ... ... ... ... \n", + "47735392 0 0 0 21 12 3 \n", + "47735393 0 0 0 22 12 3 \n", + "47735394 0 0 0 23 12 3 \n", + "47735395 0 0 0 24 12 3 \n", + "47735396 0 0 0 18 11 3 \n", + "\n", + " tm_y tm_wm tm_dw tm_w_end \n", + "0 4 5 0 0 \n", + "1 4 5 1 0 \n", + "2 4 1 2 0 \n", + "3 4 1 3 0 \n", + "4 4 1 4 0 \n", + "... ... ... ... ... \n", + "47735392 0 3 0 0 \n", + "47735393 0 4 1 0 \n", + "47735394 0 4 2 0 \n", + "47735395 0 4 3 0 \n", + "47735396 0 3 4 0 \n", + "\n", + "[47735397 rows x 16 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df" + ] + }, + { + "cell_type": "markdown", + "id": "4dab0da0-77e1-4462-bae7-d8a8f6f35aa7", + "metadata": {}, + "source": [ + "Now we can persist the table to the disk." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "52065927-8a8b-4b3c-8fde-355b2961645f", + "metadata": {}, + "outputs": [], + "source": [ + "grid_df.to_pandas().to_pickle(processed_data_dir + \"grid_df_part3.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b94402c-293f-4e8e-bddc-8416809da94f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb new file mode 100644 index 00000000..c4f92f32 --- /dev/null +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb @@ -0,0 +1,623 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "24f14a4c-4adf-4fe8-9f20-5a95f4a40fca", + "metadata": {}, + "source": [ + "# Data preprocessing, Part 4" + ] + }, + { + "cell_type": "markdown", + "id": "7c34d719-2223-4593-8eaf-95bad7b3d1d4", + "metadata": {}, + "source": [ + "## Import modules" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5b6b0e92-e176-4b22-be2c-374ab2a1c037", + "metadata": {}, + "outputs": [], + "source": [ + "import cudf\n", + "import numpy as np\n", + "import pandas as pd\n", + "import gc" + ] + }, + { + "cell_type": "markdown", + "id": "509f32d0-9adb-4eda-9a16-02b37d1e2b8f", + "metadata": {}, + "source": [ + "## Load data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d013b288-07ff-4414-b5c4-985584158ed7", + "metadata": {}, + "outputs": [], + "source": [ + 
"raw_data_dir = \"./data/\"\n", + "processed_data_dir = \"./processed_data/\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a8ec4415-1bca-4542-a28e-df810b1bc8bf", + "metadata": {}, + "outputs": [], + "source": [ + "grid_df = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part1.pkl\"))\n", + "grid_df = grid_df[[\"id\", \"day_id\", \"sales\"]]\n", + "SHIFT_DAY = 28" + ] + }, + { + "cell_type": "markdown", + "id": "7b17f489-a11c-4152-bedd-da22a5a08782", + "metadata": {}, + "source": [ + "## Generate lag features" + ] + }, + { + "cell_type": "markdown", + "id": "35c3ce2b-164c-464d-a3d1-c84311cac9ae", + "metadata": {}, + "source": [ + "**Lag features** are the value of the target variable at prior timestamps. Lag features are useful because what happens in the past often influences what would happen in the future. In our example, we generate lag features by reading the sales amount at X days prior, where X = 28, 29, ..., 42." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7cb4b9a2-1332-41ef-8a17-ebe31037d712", + "metadata": {}, + "outputs": [], + "source": [ + "LAG_DAYS = [col for col in range(SHIFT_DAY, SHIFT_DAY + 15)]\n", + "\n", + "# Need to first ensure that rows in each time series are sorted by day_id\n", + "grid_df = grid_df.sort_values([\"id\", \"day_id\"])\n", + "\n", + "grid_df = grid_df.assign(\n", + " **{\n", + " f\"sales_lag_{l}\": grid_df.groupby([\"id\"])[\"sales\"].shift(l)\n", + " for l in LAG_DAYS\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "92377656-ce8e-4101-88ba-260c14b9583b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idsalessales_lag_28sales_lag_29sales_lag_30sales_lag_31sales_lag_32sales_lag_33sales_lag_34sales_lag_35sales_lag_36sales_lag_37sales_lag_38sales_lag_39sales_lag_40sales_lag_41sales_lag_42
13225FOODS_1_001_CA_1_evaluationd_13.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
13226FOODS_1_001_CA_1_evaluationd_20.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
13227FOODS_1_001_CA_1_evaluationd_30.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
13228FOODS_1_001_CA_1_evaluationd_41.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
13229FOODS_1_001_CA_1_evaluationd_54.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
.........................................................
47734672HOUSEHOLD_2_516_WI_3_evaluationd_1965<NA>0.00.00.00.01.00.00.00.00.00.00.00.00.00.00.0
47734673HOUSEHOLD_2_516_WI_3_evaluationd_1966<NA>0.00.00.00.00.01.00.00.00.00.00.00.00.00.00.0
47734674HOUSEHOLD_2_516_WI_3_evaluationd_1967<NA>0.00.00.00.00.00.01.00.00.00.00.00.00.00.00.0
47734675HOUSEHOLD_2_516_WI_3_evaluationd_1968<NA>0.00.00.00.00.00.00.01.00.00.00.00.00.00.00.0
47734676HOUSEHOLD_2_516_WI_3_evaluationd_1969<NA>0.00.00.00.00.00.00.00.01.00.00.00.00.00.00.0
\n", + "

47735397 rows × 18 columns

\n", + "
" + ], + "text/plain": [ + " id day_id sales sales_lag_28 \\\n", + "13225 FOODS_1_001_CA_1_evaluation d_1 3.0 \n", + "13226 FOODS_1_001_CA_1_evaluation d_2 0.0 \n", + "13227 FOODS_1_001_CA_1_evaluation d_3 0.0 \n", + "13228 FOODS_1_001_CA_1_evaluation d_4 1.0 \n", + "13229 FOODS_1_001_CA_1_evaluation d_5 4.0 \n", + "... ... ... ... ... \n", + "47734672 HOUSEHOLD_2_516_WI_3_evaluation d_1965 0.0 \n", + "47734673 HOUSEHOLD_2_516_WI_3_evaluation d_1966 0.0 \n", + "47734674 HOUSEHOLD_2_516_WI_3_evaluation d_1967 0.0 \n", + "47734675 HOUSEHOLD_2_516_WI_3_evaluation d_1968 0.0 \n", + "47734676 HOUSEHOLD_2_516_WI_3_evaluation d_1969 0.0 \n", + "\n", + " sales_lag_29 sales_lag_30 sales_lag_31 sales_lag_32 sales_lag_33 \\\n", + "13225 \n", + "13226 \n", + "13227 \n", + "13228 \n", + "13229 \n", + "... ... ... ... ... ... \n", + "47734672 0.0 0.0 0.0 1.0 0.0 \n", + "47734673 0.0 0.0 0.0 0.0 1.0 \n", + "47734674 0.0 0.0 0.0 0.0 0.0 \n", + "47734675 0.0 0.0 0.0 0.0 0.0 \n", + "47734676 0.0 0.0 0.0 0.0 0.0 \n", + "\n", + " sales_lag_34 sales_lag_35 sales_lag_36 sales_lag_37 sales_lag_38 \\\n", + "13225 \n", + "13226 \n", + "13227 \n", + "13228 \n", + "13229 \n", + "... ... ... ... ... ... \n", + "47734672 0.0 0.0 0.0 0.0 0.0 \n", + "47734673 0.0 0.0 0.0 0.0 0.0 \n", + "47734674 1.0 0.0 0.0 0.0 0.0 \n", + "47734675 0.0 1.0 0.0 0.0 0.0 \n", + "47734676 0.0 0.0 1.0 0.0 0.0 \n", + "\n", + " sales_lag_39 sales_lag_40 sales_lag_41 sales_lag_42 \n", + "13225 \n", + "13226 \n", + "13227 \n", + "13228 \n", + "13229 \n", + "... ... ... ... ... 
\n", + "47734672 0.0 0.0 0.0 0.0 \n", + "47734673 0.0 0.0 0.0 0.0 \n", + "47734674 0.0 0.0 0.0 0.0 \n", + "47734675 0.0 0.0 0.0 0.0 \n", + "47734676 0.0 0.0 0.0 0.0 \n", + "\n", + "[47735397 rows x 18 columns]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df" + ] + }, + { + "cell_type": "markdown", + "id": "2778fa50-d38b-4eb9-9ce4-a52a6455ac91", + "metadata": {}, + "source": [ + "## Compute rolling window statistics" + ] + }, + { + "cell_type": "markdown", + "id": "aabcb64a-2a43-4f8f-a1f8-5425c9fec2d3", + "metadata": {}, + "source": [ + "In the previous cell, we used the value of sales at a single timestamp to generate lag features. To capture richer information about the past, let us also get the distribution of the sales value over multiple timestamps, by computing **rolling window statistics**. Rolling window statistics are statistics (e.g. mean, standard deviation) over a time duration in the past. Rolling windows statistics complement lag features and provide more information about the past behavior of the target variable.\n", + "\n", + "Read more about lag features and rolling window statistics in [Introduction to feature engineering for time series forecasting](https://medium.com/data-science-at-microsoft/introduction-to-feature-engineering-for-time-series-forecasting-620aa55fcab0)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9cd8023a-4727-4365-8bc5-693d6dcd4979", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shift size: 28\n", + " Window size: 7\n", + " Window size: 14\n", + " Window size: 30\n", + " Window size: 60\n", + " Window size: 180\n" + ] + } + ], + "source": [ + "# Shift by 28 days and apply windows of various sizes\n", + "print(f\"Shift size: {SHIFT_DAY}\")\n", + "for i in [7, 14, 30, 60, 180]:\n", + " print(f\" Window size: {i}\")\n", + " grid_df[f\"rolling_mean_{i}\"] = (\n", + " grid_df.groupby([\"id\"])[\"sales\"].shift(SHIFT_DAY).rolling(i).mean().astype(np.float32)\n", + " )\n", + " grid_df[f\"rolling_std_{i}\"] = (\n", + " grid_df.groupby([\"id\"])[\"sales\"].shift(SHIFT_DAY).rolling(i).std().astype(np.float32)\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6b4197ca-23a4-4b17-981f-1d35b5f6b671", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Index(['id', 'day_id', 'sales', 'sales_lag_28', 'sales_lag_29', 'sales_lag_30',\n", + " 'sales_lag_31', 'sales_lag_32', 'sales_lag_33', 'sales_lag_34',\n", + " 'sales_lag_35', 'sales_lag_36', 'sales_lag_37', 'sales_lag_38',\n", + " 'sales_lag_39', 'sales_lag_40', 'sales_lag_41', 'sales_lag_42',\n", + " 'rolling_mean_7', 'rolling_std_7', 'rolling_mean_14', 'rolling_std_14',\n", + " 'rolling_mean_30', 'rolling_std_30', 'rolling_mean_60',\n", + " 'rolling_std_60', 'rolling_mean_180', 'rolling_std_180'],\n", + " dtype='object')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df.columns" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f6b439b6-3b44-44f4-84bd-4c99edc59448", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "id category\n", + "day_id category\n", + "sales float32\n", + "sales_lag_28 float32\n", + "sales_lag_29 float32\n", + "sales_lag_30 
float32\n", + "sales_lag_31 float32\n", + "sales_lag_32 float32\n", + "sales_lag_33 float32\n", + "sales_lag_34 float32\n", + "sales_lag_35 float32\n", + "sales_lag_36 float32\n", + "sales_lag_37 float32\n", + "sales_lag_38 float32\n", + "sales_lag_39 float32\n", + "sales_lag_40 float32\n", + "sales_lag_41 float32\n", + "sales_lag_42 float32\n", + "rolling_mean_7 float32\n", + "rolling_std_7 float32\n", + "rolling_mean_14 float32\n", + "rolling_std_14 float32\n", + "rolling_mean_30 float32\n", + "rolling_std_30 float32\n", + "rolling_mean_60 float32\n", + "rolling_std_60 float32\n", + "rolling_mean_180 float32\n", + "rolling_std_180 float32\n", + "dtype: object" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df.dtypes" + ] + }, + { + "cell_type": "markdown", + "id": "79a2ae07-fa57-4579-9443-0cc2bb15a922", + "metadata": {}, + "source": [ + "Once lag features and rolling window statistics are computed, persist them to the disk." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a6cd2826-0d33-42b8-8616-227a717639d0", + "metadata": {}, + "outputs": [], + "source": [ + "grid_df.to_pandas().to_pickle(processed_data_dir + \"lags_df_28.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4979567-1c59-47cf-83c2-21d61d06b701", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb new file mode 100644 index 00000000..a17b3d95 --- /dev/null +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb @@ -0,0 +1,595 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0929302e-1b8d-49ed-b4ba-be26f7579ec0", + "metadata": {}, + "source": [ + "# Data preprocessing, Part 5" + ] + }, + { + "cell_type": "markdown", + "id": "5ccecbb0-e81e-41e8-bb71-539d3e02bb73", + "metadata": {}, + "source": [ + "## Import modules" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c6032bfd-3dd7-4d5c-a043-51945b068826", + "metadata": {}, + "outputs": [], + "source": [ + "import cudf\n", + "import numpy as np\n", + "import pandas as pd\n", + "import gc" + ] + }, + { + "cell_type": "markdown", + "id": "a14ddfe9-82b5-41ab-8612-a3e626e1792c", + "metadata": {}, + "source": [ + "## Load data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a4c537de-a406-42d4-8ad8-b2f9aad6bb80", + "metadata": {}, + "outputs": [], + "source": [ + 
"raw_data_dir = \"./data/\"\n", + "processed_data_dir = \"./processed_data/\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "882bfc6a-24f5-4690-a5ec-543f578a0843", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
0FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15371.01110111511
1FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15380.01110111511
2FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15392.01110111511
3FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15400.01110111511
4FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15410.01110111512
.................................
47735392HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_520.01110111108
47735393HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_530.01110111108
47735394HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_540.01110111108
47735395HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_550.01110111108
47735396HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_490.01110111107
\n", + "

47735397 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "0 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "1 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "2 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "3 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "4 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735393 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735394 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", + "0 FOODS CA_1 CA d_1537 1.0 11101 11511 \n", + "1 FOODS CA_1 CA d_1538 0.0 11101 11511 \n", + "2 FOODS CA_1 CA d_1539 2.0 11101 11511 \n", + "3 FOODS CA_1 CA d_1540 0.0 11101 11511 \n", + "4 FOODS CA_1 CA d_1541 0.0 11101 11512 \n", + "... ... ... ... ... ... ... ... \n", + "47735392 HOUSEHOLD WI_3 WI d_52 0.0 11101 11108 \n", + "47735393 HOUSEHOLD WI_3 WI d_53 0.0 11101 11108 \n", + "47735394 HOUSEHOLD WI_3 WI d_54 0.0 11101 11108 \n", + "47735395 HOUSEHOLD WI_3 WI d_55 0.0 11101 11108 \n", + "47735396 HOUSEHOLD WI_3 WI d_49 0.0 11101 11107 \n", + "\n", + "[47735397 rows x 10 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part1.pkl\"))\n", + "grid_df" + ] + }, + { + "cell_type": "markdown", + "id": "f338a3cd-2504-435a-8ac7-7fb5f71cea76", + "metadata": {}, + "source": [ + "## Target encoding" + ] + }, + { + "cell_type": "markdown", + "id": "0db75d1c-dcf6-4404-a4fe-b954d3388add", + "metadata": {}, + "source": [ + "Categorical variables present challenges to many machine learning algorithms such as XGBoost. 
One way to overcome the challenge is to use **target encoding**, where we encode categorical variables by replacing them with a statistic for the target variable. In this example, we will use the mean and the standard deviation.\n", + "\n", + "Read more about target encoding in [Target-encoding Categorical Variables](https://towardsdatascience.com/dealing-with-categorical-variables-by-using-target-encoder-a0f1733a4c69)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4dba104c-8d86-47c5-a044-eab5796a6f2e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Encoding columns ['store_id', 'dept_id']\n", + "Encoding columns ['item_id', 'state_id']\n" + ] + } + ], + "source": [ + "icols = [\n", + " [\"store_id\", \"dept_id\"],\n", + " [\"item_id\", \"state_id\"]\n", + "]\n", + "new_columns = []\n", + "\n", + "for col in icols:\n", + " print(f\"Encoding columns {col}\")\n", + " col_name = \"_\" + \"_\".join(col) + \"_\"\n", + " grid_df[\"enc\" + col_name + \"mean\"] = grid_df.groupby(col)[\"sales\"].transform(\"mean\").astype(np.float32)\n", + " grid_df[\"enc\" + col_name + \"std\"] = grid_df.groupby(col)[\"sales\"].transform(\"std\").astype(np.float32)\n", + " new_columns.extend([\"enc\" + col_name + \"mean\", \"enc\" + col_name + \"std\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0cd78500-2c68-449b-8c8c-93eb57a860f5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idenc_store_id_dept_id_meanenc_store_id_dept_id_stdenc_item_id_state_id_meanenc_item_id_state_id_std
0FOODS_1_001_CA_1_evaluationd_15371.6131123.2166720.8733901.666305
1FOODS_1_001_CA_1_evaluationd_15381.6131123.2166720.8733901.666305
2FOODS_1_001_CA_1_evaluationd_15391.6131123.2166720.8733901.666305
3FOODS_1_001_CA_1_evaluationd_15401.6131123.2166720.8733901.666305
4FOODS_1_001_CA_1_evaluationd_15411.6131123.2166720.8733901.666305
.....................
47735392HOUSEHOLD_2_516_WI_3_evaluationd_520.2614860.6663800.0832760.301445
47735393HOUSEHOLD_2_516_WI_3_evaluationd_530.2614860.6663800.0832760.301445
47735394HOUSEHOLD_2_516_WI_3_evaluationd_540.2614860.6663800.0832760.301445
47735395HOUSEHOLD_2_516_WI_3_evaluationd_550.2614860.6663800.0832760.301445
47735396HOUSEHOLD_2_516_WI_3_evaluationd_490.2614860.6663800.0832760.301445
\n", + "

47735397 rows × 6 columns

\n", + "
" + ], + "text/plain": [ + " id day_id enc_store_id_dept_id_mean \\\n", + "0 FOODS_1_001_CA_1_evaluation d_1537 1.613112 \n", + "1 FOODS_1_001_CA_1_evaluation d_1538 1.613112 \n", + "2 FOODS_1_001_CA_1_evaluation d_1539 1.613112 \n", + "3 FOODS_1_001_CA_1_evaluation d_1540 1.613112 \n", + "4 FOODS_1_001_CA_1_evaluation d_1541 1.613112 \n", + "... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_3_evaluation d_52 0.261486 \n", + "47735393 HOUSEHOLD_2_516_WI_3_evaluation d_53 0.261486 \n", + "47735394 HOUSEHOLD_2_516_WI_3_evaluation d_54 0.261486 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_55 0.261486 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_49 0.261486 \n", + "\n", + " enc_store_id_dept_id_std enc_item_id_state_id_mean \\\n", + "0 3.216672 0.873390 \n", + "1 3.216672 0.873390 \n", + "2 3.216672 0.873390 \n", + "3 3.216672 0.873390 \n", + "4 3.216672 0.873390 \n", + "... ... ... \n", + "47735392 0.666380 0.083276 \n", + "47735393 0.666380 0.083276 \n", + "47735394 0.666380 0.083276 \n", + "47735395 0.666380 0.083276 \n", + "47735396 0.666380 0.083276 \n", + "\n", + " enc_item_id_state_id_std \n", + "0 1.666305 \n", + "1 1.666305 \n", + "2 1.666305 \n", + "3 1.666305 \n", + "4 1.666305 \n", + "... ... 
\n", + "47735392 0.301445 \n", + "47735393 0.301445 \n", + "47735394 0.301445 \n", + "47735395 0.301445 \n", + "47735396 0.301445 \n", + "\n", + "[47735397 rows x 6 columns]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = grid_df[[\"id\", \"day_id\"] + new_columns]\n", + "grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "44fc899c-9567-4d70-a5d8-cd37648900b3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "id category\n", + "day_id category\n", + "enc_store_id_dept_id_mean float32\n", + "enc_store_id_dept_id_std float32\n", + "enc_item_id_state_id_mean float32\n", + "enc_item_id_state_id_std float32\n", + "dtype: object" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df.dtypes" + ] + }, + { + "cell_type": "markdown", + "id": "8ebcfb76-d807-4fb1-93ab-e30b2e38c173", + "metadata": {}, + "source": [ + "Once we computed the target encoding, we persist the table to the disk." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c5cfa397-a31f-4f28-aecd-255bf94ea90b", + "metadata": {}, + "outputs": [], + "source": [ + "grid_df.to_pandas().to_pickle(processed_data_dir + \"target_encoding_df.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "438484e3-deb4-437f-89cd-7603c1f8b9f7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb new file mode 100644 index 00000000..028d28e3 --- /dev/null +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb @@ -0,0 +1,427 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a7f7926c-e8f7-41f0-98a4-1211da546adc", + "metadata": {}, + "source": [ + "# Data preprocessing, Part 6" + ] + }, + { + "cell_type": "markdown", + "id": "d21c08f5-07b8-4f01-b584-d0c6d2bdc86e", + "metadata": {}, + "source": [ + "## Import modules" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "594c2c7f-3fff-4cba-86e7-8f3031ea21d9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import cudf\n", + "import numpy as np\n", + "import pandas as pd\n", + "import gc\n", + "import glob\n", + "import pathlib\n", + "import gcsfs" + ] + }, + { + "cell_type": "markdown", + "id": "a33cb0b4-9e25-4965-8313-42707160e4fd", + "metadata": {}, + "source": [ + "Enter the name of the Cloud Storage bucket you used in `start_here.ipynb`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9fc13697-f9f2-4931-8fee-85148c600d89", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "bucket_name = \"\"" + ] + }, + { + "cell_type": "markdown", + "id": "23de1e65-8e82-4339-baa6-696abd247f22", + "metadata": {}, + "source": [ + "## Filter by store and product department and create data segments" + ] + }, + { + "cell_type": "markdown", + "id": "bfe723f6-7c62-4de5-99c6-c3a31253be61", + "metadata": {}, + "source": [ + "After combining all columns produced in the previous notebooks, we filter the rows in the data set by `store_id` and `dept_id` and create a segment. Each segment is saved as a pickle file and then uploaded to Cloud Storage." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dc332147-56fb-4f11-95f1-12102aa6f1cf", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "processed_data_dir = \"./processed_data/\"\n", + "segmented_data_dir = \"./segmented_data/\"\n", + "pathlib.Path(segmented_data_dir).mkdir(exist_ok=True)\n", + "\n", + "STORES = [\"CA_1\", \"CA_2\", \"CA_3\", \"CA_4\", \"TX_1\", \"TX_2\", \"TX_3\", \"WI_1\", \"WI_2\", \"WI_3\"]\n", + "DEPTS = [\"HOBBIES_1\", \"HOBBIES_2\", \"HOUSEHOLD_1\", \"HOUSEHOLD_2\", \"FOODS_1\", \"FOODS_2\", \"FOODS_3\"]\n", + "\n", + "grid2_colnm = [\"sell_price\", \"price_max\", \"price_min\", \"price_std\",\n", + " \"price_mean\", \"price_norm\", \"price_nunique\", \"item_nunique\",\n", + " \"price_momentum\", \"price_momentum_m\", \"price_momentum_y\"]\n", + "\n", + "grid3_colnm = [\"event_name_1\", \"event_type_1\", \"event_name_2\",\n", + " \"event_type_2\", \"snap_CA\", \"snap_TX\", \"snap_WI\", \"tm_d\", \"tm_w\", \"tm_m\",\n", + " \"tm_y\", \"tm_wm\", \"tm_dw\", \"tm_w_end\"]\n", + "\n", + "lag_colnm = [\"sales_lag_28\", \"sales_lag_29\", \"sales_lag_30\",\n", + " \"sales_lag_31\", \"sales_lag_32\", \"sales_lag_33\", \"sales_lag_34\",\n", + " \"sales_lag_35\", \"sales_lag_36\", 
\"sales_lag_37\", \"sales_lag_38\",\n", + " \"sales_lag_39\", \"sales_lag_40\", \"sales_lag_41\", \"sales_lag_42\",\n", + " \"rolling_mean_7\", \"rolling_std_7\", \"rolling_mean_14\", \"rolling_std_14\",\n", + " \"rolling_mean_30\", \"rolling_std_30\", \"rolling_mean_60\",\n", + " \"rolling_std_60\", \"rolling_mean_180\", \"rolling_std_180\"]\n", + "\n", + "target_enc_colnm = [\n", + " \"enc_store_id_dept_id_mean\", \"enc_store_id_dept_id_std\",\n", + " \"enc_item_id_state_id_mean\", \"enc_item_id_state_id_std\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "36fbfd16-83e6-42ab-a337-5dd9a009cd7a", + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_data(store, dept=None):\n", + " \"\"\"\n", + " Filter and clean data according to stores and product departments\n", + "\n", + " Parameters\n", + " ----------\n", + " store: Filter data by retaining rows whose store_id matches this parameter.\n", + " dept: Filter data by retaining rows whose dept_id matches this parameter.\n", + " This parameter can be set to None to indicate that we shouldn't filter by dept_id.\n", + " \"\"\"\n", + " if store is None:\n", + " raise ValueError(f\"store parameter must not be None\")\n", + "\n", + " grid1 = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part1.pkl\"))\n", + "\n", + " if dept is None:\n", + " grid1 = grid1[grid1[\"store_id\"] == store]\n", + " else:\n", + " grid1 = grid1[(grid1[\"store_id\"] == store) & (grid1[\"dept_id\"] == dept)].drop(columns=[\"dept_id\"])\n", + " grid1 = grid1.drop(columns=[\"release_week\", \"wm_yr_wk\", \"store_id\", \"state_id\"])\n", + "\n", + " grid2 = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part2.pkl\"))[[\"id\", \"day_id\"] + grid2_colnm]\n", + " grid_df = grid1.merge(grid2, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del grid1, grid2\n", + "\n", + " grid3 = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part3.pkl\"))[[\"id\", \"day_id\"] + 
grid3_colnm]\n", + " grid_df = grid_df.merge(grid3, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del grid3\n", + "\n", + " lag_df = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"lags_df_28.pkl\"))[[\"id\", \"day_id\"] + lag_colnm]\n", + "\n", + " grid_df = grid_df.merge(lag_df, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del lag_df\n", + "\n", + " target_enc_df = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"target_encoding_df.pkl\"))[[\"id\", \"day_id\"] + target_enc_colnm]\n", + "\n", + " grid_df = grid_df.merge(target_enc_df, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del target_enc_df\n", + " gc.collect()\n", + "\n", + " grid_df = grid_df.drop(columns=[\"id\"])\n", + " grid_df[\"day_id\"] = grid_df[\"day_id\"].to_pandas().astype(\"str\").apply(lambda x: x[2:]).astype(np.int16)\n", + "\n", + " return grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9e61b762-2722-4b83-9220-326006880acd", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1...\n", + "Processing store CA_2...\n", + "Processing store CA_3...\n", + "Processing store CA_4...\n", + "Processing store TX_1...\n", + "Processing store TX_2...\n", + "Processing store TX_3...\n", + "Processing store WI_1...\n", + "Processing store WI_2...\n", + "Processing store WI_3...\n", + "Processing (store CA_1, department HOBBIES_1)...\n", + "Processing (store CA_1, department HOBBIES_2)...\n", + "Processing (store CA_1, department HOUSEHOLD_1)...\n", + "Processing (store CA_1, department HOUSEHOLD_2)...\n", + "Processing (store CA_1, department FOODS_1)...\n", + "Processing (store CA_1, department FOODS_2)...\n", + "Processing (store CA_1, department FOODS_3)...\n", + "Processing (store CA_2, department HOBBIES_1)...\n", + "Processing (store CA_2, department HOBBIES_2)...\n", + "Processing (store CA_2, department HOUSEHOLD_1)...\n", + "Processing (store CA_2, 
department HOUSEHOLD_2)...\n", + "Processing (store CA_2, department FOODS_1)...\n", + "Processing (store CA_2, department FOODS_2)...\n", + "Processing (store CA_2, department FOODS_3)...\n", + "Processing (store CA_3, department HOBBIES_1)...\n", + "Processing (store CA_3, department HOBBIES_2)...\n", + "Processing (store CA_3, department HOUSEHOLD_1)...\n", + "Processing (store CA_3, department HOUSEHOLD_2)...\n", + "Processing (store CA_3, department FOODS_1)...\n", + "Processing (store CA_3, department FOODS_2)...\n", + "Processing (store CA_3, department FOODS_3)...\n", + "Processing (store CA_4, department HOBBIES_1)...\n", + "Processing (store CA_4, department HOBBIES_2)...\n", + "Processing (store CA_4, department HOUSEHOLD_1)...\n", + "Processing (store CA_4, department HOUSEHOLD_2)...\n", + "Processing (store CA_4, department FOODS_1)...\n", + "Processing (store CA_4, department FOODS_2)...\n", + "Processing (store CA_4, department FOODS_3)...\n", + "Processing (store TX_1, department HOBBIES_1)...\n", + "Processing (store TX_1, department HOBBIES_2)...\n", + "Processing (store TX_1, department HOUSEHOLD_1)...\n", + "Processing (store TX_1, department HOUSEHOLD_2)...\n", + "Processing (store TX_1, department FOODS_1)...\n", + "Processing (store TX_1, department FOODS_2)...\n", + "Processing (store TX_1, department FOODS_3)...\n", + "Processing (store TX_2, department HOBBIES_1)...\n", + "Processing (store TX_2, department HOBBIES_2)...\n", + "Processing (store TX_2, department HOUSEHOLD_1)...\n", + "Processing (store TX_2, department HOUSEHOLD_2)...\n", + "Processing (store TX_2, department FOODS_1)...\n", + "Processing (store TX_2, department FOODS_2)...\n", + "Processing (store TX_2, department FOODS_3)...\n", + "Processing (store TX_3, department HOBBIES_1)...\n", + "Processing (store TX_3, department HOBBIES_2)...\n", + "Processing (store TX_3, department HOUSEHOLD_1)...\n", + "Processing (store TX_3, department HOUSEHOLD_2)...\n", + "Processing 
(store TX_3, department FOODS_1)...\n", + "Processing (store TX_3, department FOODS_2)...\n", + "Processing (store TX_3, department FOODS_3)...\n", + "Processing (store WI_1, department HOBBIES_1)...\n", + "Processing (store WI_1, department HOBBIES_2)...\n", + "Processing (store WI_1, department HOUSEHOLD_1)...\n", + "Processing (store WI_1, department HOUSEHOLD_2)...\n", + "Processing (store WI_1, department FOODS_1)...\n", + "Processing (store WI_1, department FOODS_2)...\n", + "Processing (store WI_1, department FOODS_3)...\n", + "Processing (store WI_2, department HOBBIES_1)...\n", + "Processing (store WI_2, department HOBBIES_2)...\n", + "Processing (store WI_2, department HOUSEHOLD_1)...\n", + "Processing (store WI_2, department HOUSEHOLD_2)...\n", + "Processing (store WI_2, department FOODS_1)...\n", + "Processing (store WI_2, department FOODS_2)...\n", + "Processing (store WI_2, department FOODS_3)...\n", + "Processing (store WI_3, department HOBBIES_1)...\n", + "Processing (store WI_3, department HOBBIES_2)...\n", + "Processing (store WI_3, department HOUSEHOLD_1)...\n", + "Processing (store WI_3, department HOUSEHOLD_2)...\n", + "Processing (store WI_3, department FOODS_1)...\n", + "Processing (store WI_3, department FOODS_2)...\n", + "Processing (store WI_3, department FOODS_3)...\n" + ] + } + ], + "source": [ + "# First save the segment to the disk\n", + "for store in STORES:\n", + " print(f\"Processing store {store}...\")\n", + " grid_df = prepare_data(store=store)\n", + " grid_df.to_pandas().to_pickle(segmented_data_dir + f\"combined_df_store_{store}.pkl\")\n", + " del grid_df\n", + " gc.collect()\n", + "\n", + "for store in STORES:\n", + " for dept in DEPTS:\n", + " print(f\"Processing (store {store}, department {dept})...\")\n", + " grid_df = prepare_data(store=store, dept=dept)\n", + " grid_df.to_pandas().to_pickle(segmented_data_dir + f\"combined_df_store_{store}_dept_{dept}.pkl\")\n", + " del grid_df\n", + " gc.collect()" + ] + }, + { + 
"cell_type": "code", + "execution_count": 7, + "id": "6e12341b-c879-4012-8c73-8316278b9a6f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Uploading ./segmented_data/combined_df_store_WI_1_dept_FOODS_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_2_dept_FOODS_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_2_dept_HOBBIES_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_2_dept_HOBBIES_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_1_dept_FOODS_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_1_dept_FOODS_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_3_dept_HOBBIES_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_3_dept_HOUSEHOLD_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_2_dept_HOBBIES_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_4_dept_HOBBIES_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_3_dept_FOODS_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_1_dept_HOBBIES_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_1_dept_HOUSEHOLD_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_4_dept_FOODS_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_1_dept_HOUSEHOLD_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_3_dept_HOUSEHOLD_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_3_dept_FOODS_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_1_dept_HOBBIES_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_1_dept_FOODS_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_4.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_1_dept_HOBBIES_1.pkl...\n", + "Uploading 
./segmented_data/combined_df_store_WI_2_dept_HOUSEHOLD_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_3_dept_FOODS_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_1_dept_FOODS_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_1_dept_HOBBIES_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_2_dept_HOBBIES_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_2_dept_FOODS_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_3_dept_HOUSEHOLD_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_3_dept_HOBBIES_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_2_dept_FOODS_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_2_dept_HOUSEHOLD_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_3_dept_HOBBIES_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_4_dept_FOODS_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_4_dept_FOODS_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_3_dept_HOBBIES_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_1_dept_HOUSEHOLD_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_3_dept_HOUSEHOLD_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_4_dept_HOBBIES_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_1_dept_HOUSEHOLD_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_3_dept_FOODS_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_3_dept_HOUSEHOLD_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_2_dept_HOUSEHOLD_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_2_dept_HOBBIES_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_1_dept_HOUSEHOLD_2.pkl...\n", 
+ "Uploading ./segmented_data/combined_df_store_CA_1_dept_FOODS_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_3_dept_FOODS_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_1_dept_HOUSEHOLD_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_2_dept_HOUSEHOLD_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_4_dept_HOUSEHOLD_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_1_dept_FOODS_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_2_dept_HOUSEHOLD_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_2_dept_HOBBIES_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_1_dept_FOODS_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_3_dept_FOODS_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_2_dept_FOODS_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_2_dept_FOODS_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_2_dept_HOUSEHOLD_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_2_dept_FOODS_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_2_dept_FOODS_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_1_dept_HOBBIES_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_3_dept_HOBBIES_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_3_dept_FOODS_3.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_4_dept_HOUSEHOLD_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_CA_3_dept_FOODS_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_3_dept_HOUSEHOLD_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_2_dept_FOODS_2.pkl...\n", + "Uploading 
./segmented_data/combined_df_store_WI_1_dept_FOODS_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_2_dept_FOODS_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_TX_3_dept_HOBBIES_2.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_1_dept_HOBBIES_1.pkl...\n", + "Uploading ./segmented_data/combined_df_store_WI_3_dept_FOODS_3.pkl...\n" + ] + } + ], + "source": [ + "# Then copy the segment to Cloud Storage\n", + "fs = gcsfs.GCSFileSystem()\n", + "\n", + "for e in glob.glob(segmented_data_dir + \"*\"):\n", + " print(f\"Uploading {e}...\")\n", + " basename = pathlib.Path(e).name\n", + " fs.put_file(e, f\"{bucket_name}/{basename}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "58eb85e7-3d1d-4f72-8c9a-57b672344d73", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Also upload the product weights\n", + "fs = gcsfs.GCSFileSystem()\n", + "fs.put_file(processed_data_dir + \"product_weights.pkl\", f\"{bucket_name}/product_weights.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b1cf03ee-3e08-429d-bff0-8023ee6de2af", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/source/examples/time-series-forecasting-with-hpo/start_here.ipynb b/source/examples/time-series-forecasting-with-hpo/start_here.ipynb new file mode 100644 index 00000000..e23c4c7e --- /dev/null +++ b/source/examples/time-series-forecasting-with-hpo/start_here.ipynb @@ -0,0 +1,305 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": 
"671dd603-6b51-46b2-98b3-2b05c7c92c38", + "metadata": {}, + "source": [ + "# Perform time series forecasting on Google Kubernetes Engine with NVIDIA GPUs" + ] + }, + { + "cell_type": "markdown", + "id": "e18814f8-5374-468f-9877-28275dbf20d6", + "metadata": {}, + "source": [ + "In this example, we will be looking at a real-world example of **time series forecasting** with data from [the M5 Forecasting Competition](https://www.kaggle.com/competitions/m5-forecasting-accuracy). Walmart provides historical sales data from multiple stores in three states, and our job is to predict the sales in a future 28-day period." + ] + }, + { + "cell_type": "markdown", + "id": "0af8dc16-70a9-41cc-b846-88e83a2aa660", + "metadata": {}, + "source": [ + "## Prerequisites" + ] + }, + { + "cell_type": "markdown", + "id": "9168f224-8549-41a7-a6a7-2023c83bb466", + "metadata": {}, + "source": [ + "### Prepare GKE cluster" + ] + }, + { + "cell_type": "markdown", + "id": "fde94ab8-123c-4a72-95db-86a2098247bb", + "metadata": {}, + "source": [ + "To run the example, you will need a working Google Kubernetes Engine (GKE) cluster with access to NVIDIA GPUs. Use the following resources to set up a cluster:\n", + "\n", + "* [Set up a GKE cluster with access to NVIDIA GPUs](https://docs.rapids.ai/deployment/stable/cloud/gcp/gke/)\n", + "* [Install the Dask-Kubernetes operator](https://kubernetes.dask.org/en/latest/operator_installation.html)\n", + "* [Install Kubeflow](https://www.kubeflow.org/docs/started/installing-kubeflow/)\n", + "\n", + "Kubeflow is not strictly necessary, but we highly recommend it, as Kubeflow gives you a nice notebook environment to run this notebook within the k8s cluster. (You may choose any method; we tested this example after installing Kubeflow from manifests.) 
When creating the notebook environment, use the following configuration:\n", + "\n", + "* 2 CPUs, 16 GiB of memory\n", + "* 1 NVIDIA GPU\n", + "* 40 GiB disk volume\n", + "\n", + "After uploading all the notebooks in the example, run this notebook (`start_here.ipynb`) in the notebook environment.\n", + "\n", + "Note: We will use the worker pods to speed up the training stage. The preprocessing steps will run solely on the scheduler node." + ] + }, + { + "cell_type": "markdown", + "id": "79b63586-119e-47ee-ba74-063cfca71fe0", + "metadata": {}, + "source": [ + "### Prepare a bucket in Google Cloud Storage" + ] + }, + { + "cell_type": "markdown", + "id": "f1c4ada1-60c5-4456-b82d-953ce499a2fa", + "metadata": {}, + "source": [ + "Create a new bucket in Google Cloud Storage. Make sure that the worker pods in the k8s cluster have read/write access to this bucket. This can be done in one of the following methods:\n", + "\n", + "1. Option 1: Specify an additional scope when provisioning the GKE cluster.\n", + "\n", + "    When you are provisioning a new GKE cluster, add the `storage-rw` scope.\n", + "    This option is only available if you are creating a new cluster from scratch. If you are using an existing GKE cluster, see Option 2.\n", + "\n", + "    Example:\n", + "```\n", + "gcloud container clusters create my_new_cluster --accelerator type=nvidia-tesla-t4 \\\n", + "    --machine-type n1-standard-32 --zone us-central1-c --release-channel stable \\\n", + "    --num-nodes 5 --scopes=gke-default,storage-rw\n", + "```\n", + "\n", + "2. Option 2: Grant bucket access to the associated service account.\n", + "\n", + "    Find out which service account is associated with your GKE cluster. 
You can grant the bucket access to the service account as follows: Navigate to the Cloud Storage console, open the Bucket Details page for the bucket, open the Permissions tab, and click on Grant Access.\n", + "    \n", + "Enter the name of the bucket that your cluster has read-write access to:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "45bc19b6-f5f9-4f55-827b-db9484c125ce", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "bucket_name = \"\"" + ] + }, + { + "cell_type": "markdown", + "id": "704898b3-e2b0-4b40-bcd3-178a0bdeee6f", + "metadata": {}, + "source": [ + "### Install Python packages in the notebook environment" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "cdf7b111-3aba-4fae-b805-fb3063d5a621", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install kaggle gcsfs dask-kubernetes optuna" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "13737fe2-3df5-4614-8675-fb20bccf7a19", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Test if the bucket is accessible\n", + "import gcsfs\n", + "\n", + "fs = gcsfs.GCSFileSystem()\n", + "fs.ls(f\"{bucket_name}/\")" + ] + }, + { + "cell_type": "markdown", + "id": "42f2274c-90a0-484f-ad71-33ddba170f09", + "metadata": {}, + "source": [ + "## Obtain the time series data set from Kaggle" + ] + }, + { + "cell_type": "markdown", + "id": "21cde5ee-dad0-4ee2-9c26-7a36b6bcee49", + "metadata": {}, + "source": [ + "If you do not yet have an account with Kaggle, create one now. Then follow instructions in [Public API Documentation of Kaggle](https://www.kaggle.com/docs/api) to obtain the API key. This step is needed to obtain the training data from the M5 Forecasting Competition. 
Once you obtained the API key, fill in the following:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8c2eb62a-a0f5-4801-9254-367fcef05e05", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "kaggle_username = \"\"\n", + "kaggle_api_key = \"\"" + ] + }, + { + "cell_type": "markdown", + "id": "64fd487f-0de1-4c41-ba02-7189493739f3", + "metadata": {}, + "source": [ + "Now we are ready to download the data set:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3c6fd3fe-b2d3-4bff-88f5-b86de68878a6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%env KAGGLE_USERNAME=$kaggle_username\n", + "%env KAGGLE_KEY=$kaggle_api_key\n", + "\n", + "!kaggle competitions download -c m5-forecasting-accuracy" + ] + }, + { + "cell_type": "markdown", + "id": "b7141384-e825-4387-a9f5-42b932e27e65", + "metadata": {}, + "source": [ + "Let's unzip the ZIP archive and see what's inside." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b7bf2248-cd3b-47c5-8923-ec9a6e868a49", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Archive: m5-forecasting-accuracy.zip\n", + " inflating: data/calendar.csv \n", + " inflating: data/sales_train_evaluation.csv \n", + " inflating: data/sales_train_validation.csv \n", + " inflating: data/sample_submission.csv \n", + " inflating: data/sell_prices.csv \n" + ] + } + ], + "source": [ + "!unzip m5-forecasting-accuracy.zip -d data/" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "673468ae-9f6d-499f-86c4-230021bbf1b0", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-rw-r--r-- 1 root users 102K Jun 1 2020 data/calendar.csv\n", + "-rw-r--r-- 1 root users 117M Jun 1 2020 data/sales_train_evaluation.csv\n", + "-rw-r--r-- 1 root users 115M Jun 1 2020 data/sales_train_validation.csv\n", + "-rw-r--r-- 1 root users 5.0M Jun 
1 2020 data/sample_submission.csv\n", + "-rw-r--r-- 1 root users 194M Jun 1 2020 data/sell_prices.csv\n" + ] + } + ], + "source": [ + "!ls -lh data/*.csv" + ] + }, + { + "cell_type": "markdown", + "id": "f304ea68-381f-45b4-9e27-201a35e31239", + "metadata": {}, + "source": [ + "# Next steps" + ] + }, + { + "cell_type": "markdown", + "id": "d9903e47-2a83-40d1-b65b-d5818e9f0647", + "metadata": {}, + "source": [ + "We are now ready to run the preprocessing steps. You should run the six notebooks in order, to process the raw data into a form that can be used for model training:\n", + "\n", + "* `preprocessing_part1.ipynb`\n", + "* `preprocessing_part2.ipynb`\n", + "* `preprocessing_part3.ipynb`\n", + "* `preprocessing_part4.ipynb`\n", + "* `preprocessing_part5.ipynb`\n", + "* `preprocessing_part6.ipynb`\n", + "* `training_and_evaluation.ipynb`\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb b/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb new file mode 100644 index 00000000..db16e6d6 --- /dev/null +++ b/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb @@ -0,0 +1,1682 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "328140a4-b668-4a5b-8de1-24c4392d22f1", + "metadata": {}, + "source": [ + "# Train an XGBoost model with retail sales forecasting with hyperparameter search" + ] + }, + { + "cell_type": "markdown", + "id": "770519b0-25d2-4786-933f-f0e16b5c4b18", + "metadata": {}, + "source": [ + "Now that we finished processing 
the data, we are now ready to train a model to forecast future sales. We will leverage the worker pods to run multiple training jobs in parallel, speeding up the hyperparameter search." + ] + }, + { + "cell_type": "markdown", + "id": "da47c5a5-3777-4bf7-8312-1da91a0d81a3", + "metadata": {}, + "source": [ + "## Import modules and define constants" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c391837b-6102-4d45-820e-b83091285a71", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import cudf\n", + "import gcsfs\n", + "import xgboost as xgb\n", + "import pandas as pd\n", + "import numpy as np\n", + "import optuna\n", + "import gc\n", + "import time\n", + "import pickle \n", + "import copy\n", + "import json\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.patches import Patch\n", + "import matplotlib\n", + "\n", + "from dask.distributed import wait\n", + "from dask_kubernetes.operator import KubeCluster\n", + "from dask.distributed import Client" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "1954ff43-dcd5-4628-80f2-64f85f253342", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Choose the same RAPIDS image you used for launching the notebook session\n", + "rapids_image = \"rapidsai/rapidsai-core-nightly:23.08-cuda11.8-runtime-ubuntu22.04-py3.10\"\n", + "# Use the number of worker nodes in your Kubernetes cluster.\n", + "n_workers = 3\n", + "# Bucket that contains the processed data pickles, refer to start_here.ipynb\n", + "bucket_name = \"\"\n", + "\n", + "# List of stores and product departments\n", + "STORES = [\"CA_1\", \"CA_2\", \"CA_3\", \"CA_4\", \"TX_1\", \"TX_2\", \"TX_3\", \"WI_1\", \"WI_2\", \"WI_3\"]\n", + "DEPTS = [\"HOBBIES_1\", \"HOBBIES_2\", \"HOUSEHOLD_1\", \"HOUSEHOLD_2\", \"FOODS_1\", \"FOODS_2\", \"FOODS_3\"]" + ] + }, + { + "cell_type": "markdown", + "id": "084ac137-1fb0-43e3-9a61-7aa259d05ca8", + "metadata": {}, + "source": [ + "## Define 
cross-validation folds" + ] + }, + { + "cell_type": "markdown", + "id": "cee3f346-afbd-41ea-a91e-12294594b6a6", + "metadata": {}, + "source": [ + "[**Cross-validation**](https://en.wikipedia.org/wiki/Cross-validation_(statistics)) is a statistical method for estimating how well a machine learning model generalizes to an independent data set. The method is also useful for evaluating the choice of a given combination of model hyperparameters.\n", + "\n", + "To estimate the capacity to generalize, we define multiple cross-validation **folds** consisting of multiple pairs of `(training set, validation set)`. For each fold, we fit a model using the training set and evaluate its accuracy on the validation set. The \"goodness\" score for a given hyperparameter combination is the accuracy of the model on each validation set, averaged over all cross-validation folds.\n", + "\n", + "Great care must be taken when defining cross-validation folds for time-series data. We are not allowed to use the future to predict the past, so the training set must precede (in time) the validation set. 
Consequently, we partition the data set in the time dimension and assign the training and validation sets using time ranges:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "423780b7-8431-4546-85c7-84e4f269d5e1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Cross-validation folds and held-out test set (in time dimension)\n", + "# The held-out test set is used for final evaluation\n", + "cv_folds = [ # (train_set, validation_set)\n", + " ([0, 1114], [1114, 1314]),\n", + " ([0, 1314], [1314, 1514]),\n", + " ([0, 1514], [1514, 1714]),\n", + " ([0, 1714], [1714, 1914])\n", + "]\n", + "n_folds = len(cv_folds)\n", + "holdout = [1914, 1942]\n", + "time_horizon = 1942" + ] + }, + { + "cell_type": "markdown", + "id": "d5ed120a-ae41-4211-a517-7791e22df037", + "metadata": {}, + "source": [ + "It is helpful to visualize the cross-validation folds using Matplotlib." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f7945265-f683-4722-8e22-df003b749eae", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAxQAAAEiCAYAAABgP5QIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA2+UlEQVR4nO3deXxNd+L/8fdNIhHZLZEgEXtC7VvplFBLtIyOztDWIEU6Ru10MLV201W1X0Nnphp0fMuYajs/2k6lCJrY0kaVSDGxR9MaEoRE5Pz+8M0dV7ab4yY35PV8PO7j4Z7zOefzOZ987nXf95xzPxbDMAwBAAAAgAkuzm4AAAAAgLsXgQIAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJjm5uwG3A3y8/N19uxZ+fj4yGKxOLs5AAAAgMMYhqFLly6pXr16cnEp+/kGAoUdzp49q5CQEGc3AwAAACg3p06dUoMGDcq8HYHCDj4+PpJudrKvr6+TWwMAAAA4TlZWlkJCQqyfecuKQGGHgsucfH19CRQAAAC4J5m9tJ+bsgEAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKa5ObsBKF1K4jTl5V5wdjMAAHZYH1dXksXZzXA6X19fzXh2trObAaACcIbiLkCYAIC7CWFCkrKyspzdBAAVhEABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMq9SBIjIyUlOmTCmxTFhYmJYsWVIh7QEAAABgq1wDRXR0tCwWS6HH0aNHy7PaQj766CO1bNlSHh4eatmypT7++OMKrR8AAAC4V5X7GYqoqCilp6fbPBo1alTe1VolJiZq2LBhGjFihPbv368RI0Zo6NCh2r17d4W1AQAAALhXlXug8PDwUFBQkM3D1dVVkhQfH68uXbrIw8NDwcHBmjVrlvLy8ordV0ZGhgYNGiRPT081atRIa9asKbX+JUuWqG/fvpo9e7bCw8M1e/ZsPfTQQ1wmBQAAADiA0+6hOHPmjB5++GF17txZ+/fv1/Lly7VixQq9+OKLxW4THR2t48ePa8uWLfrHP/6hZcuWKSMjo8R6EhMT1a9fP5tl/fv3V0JCQrHb5OTkKCsry+YBAAAAoDC38q5g48aN8vb2tj4fMGCA1q9fr2XLlikkJERLly6VxWJReHi4zp49q5kzZ2revHlycbHNOj/88IM+//xz7dq1S127dpUkrVixQhERESXWf+7cOdWtW9dmWd26dXXu3Llit1m0aJEWLlxY1kMFAAAAqpxyDxS9evXS8uXLrc+9vLwkSSkpKerWrZssFot13QMPPKDLly/r9OnTCg0NtdlPSkqK3Nzc1KlTJ+uy8PBw+fv7l9qGW+uQJMMwCi271ezZszVt2jTr86ysLIWEhJRaDwAAAFDVlHug8PLyUtOmTQstL+pDvWEYkgoHgNLWlSQoKKjQ2YiMjIxCZy1u5eHhIQ8PjzLVAwAAAFRFTruHomXLlkpISLAGBUlKSEiQj4+P6tevX6h8RESE8vLytG/fPuuy1NRUXbx4scR6unXrps2bN9ss+/LLL9W9e/c7OwAAAAAAzgsU48eP16lTpzRx4kQdPnxYn376qebPn69p06YVun9Cklq0aKGoqCjFxMRo9+7dSkpK0tixY+Xp6VliPZMnT9aXX36pV199VYcPH9arr76quLi4UifMAwAAAFA6pwWK+vXr67PPPtOePXvUtm1bjRs3TmPGjNGcOXOK3SY2NlYhISHq2bOnhgwZoqefflqBgYEl1tO9e3etXbtWsbGxatOmjVauXKl169ZZb+wGAAAAYJ7FuPWaIxQpKytLfn5+yszMlK+vb4XXfyD+qQqvEwBgzvq4IGc3odJ4/oV
Fzm4CADvc6Wddp52hAAAAAHD3I1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUBxF3BzD3B2EwAAduPX2CU55WfWATiHm7MbgNJFdFvs7CYAAOzUuqezWwAAFYszFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMc3N2A1C6lMRpysu94OxmAABgl4076uhqjoski7Ob4lQWi0ULn3/Z2c0Ayh2B4i5AmAAA3E2u5rg6uwmVgmEYzm4CUCG45AkAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhWqQNFZGSkpkyZUmKZsLAwLVmypELaAwAAAMBWuQaK6OhoWSyWQo+jR4+WZ7U2Dh48qMcee0xhYWGyWCyEDwAAAMCByv0MRVRUlNLT020ejRo1Ku9qrbKzs9W4cWO98sorCgoKqrB6AQAAgKqg3AOFh4eHgoKCbB6urq6SpPj4eHXp0kUeHh4KDg7WrFmzlJeXV+y+MjIyNGjQIHl6eqpRo0Zas2ZNqfV37txZr7/+uh5//HF5eHg47LgAAAAASG7OqvjMmTN6+OGHFR0drdWrV+vw4cOKiYlR9erVtWDBgiK3iY6O1qlTp7Rlyxa5u7tr0qRJysjIcHjbcnJylJOTY32elZXl8DoAAACAe0G5B4qNGzfK29vb+nzAgAFav369li1bppCQEC1dulQWi0Xh4eE6e/asZs6cqXnz5snFxfbkyQ8//KDPP/9cu3btUteuXSVJK1asUEREhMPbvGjRIi1cuNDh+wUAAADuNeUeKHr16qXly5dbn3t5eUmSUlJS1K1bN1ksFuu6Bx54QJcvX9bp06cVGhpqs5+UlBS5ubmpU6dO1mXh4eHy9/d3eJtnz56tadOmWZ9nZWUpJCTE4fUAAAAAd7tyDxReXl5q2rRpoeWGYdiEiYJlkgotL22do3l4eHC/BQAAAGAHp81D0bJlSyUkJFiDgiQlJCTIx8dH9evXL1Q+IiJCeXl52rdvn3VZamqqLl68WBHNBQAAAFAEpwWK8ePH69SpU5o4caIOHz6sTz/9VPPnz9e0adMK3T8hSS1atFBUVJRiYmK0e/duJSUlaezYsfL09CyxntzcXCUnJys5OVm5ubk6c+aMkpOTK3QuDAAAAOBe5bRAUb9+fX322Wfas2eP2rZtq3HjxmnMmDGaM2dOsdvExsYqJCREPXv21JAhQ/T0008rMDCwxHrOnj2r9u3bq3379kpPT9cbb7yh9u3ba+zYsY4+JAAAAKDKsRi3XnOEImVlZcnPz0+ZmZny9fWt8PoPxD9V4XUCAGDW+jgmki3w/AuLnN0EoFR3+lnXaWcoAAAAANz9CBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1DcBdzcA5zdBAAA7ObpcUMSv0pvsVic3QSgQrg5uwEoXUS3xc5uAgAAdmvd09ktAFCROEMBAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQ3ZzcApUtJnKa83AvObgYAACiD9XF1JVmc3Qyn8/X11YxnZzu7GShHnKG4CxAmAAC4GxEmJCkrK8vZTUA5I1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTKnWgiIyM1JQpU0osExYWpiVLllRIewAAAADYKtdAER0dLYvFUuhx9OjR8qzWxl//+lc9+OCDCggIUEBAgPr06aM9e/ZUWP0AAADAvazcz1BERUUpPT3d5tGoUaPyrtZq27ZteuKJJ7R
161YlJiYqNDRU/fr105kzZyqsDQAAAMC9qtwDhYeHh4KCgmwerq6ukqT4+Hh16dJFHh4eCg4O1qxZs5SXl1fsvjIyMjRo0CB5enqqUaNGWrNmTan1r1mzRuPHj1e7du0UHh6uv/71r8rPz9dXX33lsGMEAAAAqio3Z1V85swZPfzww4qOjtbq1at1+PBhxcTEqHr16lqwYEGR20RHR+vUqVPasmWL3N3dNWnSJGVkZJSp3uzsbF2/fl01a9YstkxOTo5ycnKsz7OysspUBwAAAFBVlHug2Lhxo7y9va3PBwwYoPXr12vZsmUKCQnR0qVLZbFYFB4errNnz2rmzJmaN2+eXFxsT5788MMP+vzzz7Vr1y517dpVkrRixQpFRESUqT2zZs1S/fr11adPn2LLLFq0SAsXLizTfgEAAICqqNwDRa9evbR8+XLrcy8vL0lSSkqKunXrJovFYl33wAMP6PLlyzp9+rRCQ0Nt9pOSkiI3Nzd16tTJuiw8PFz+/v52t+W1117Thx9+qG3btql69erFlps9e7amTZtmfZ6VlaWQkBC76wEAAACqinIPFF5eXmratGmh5YZh2ISJgmWSCi0vbZ093njjDb388suKi4tTmzZtSizr4eEhDw8PU/UAAAAAVYnT5qFo2bKlEhISrEFBkhISEuTj46P69esXKh8REaG8vDzt27fPuiw1NVUXL14sta7XX39dL7zwgr744gubMxwAAAAA7ozTAsX48eN16tQpTZw4UYcPH9ann36q+fPna9q0aYXun5CkFi1aKCoqSjExMdq9e7eSkpI0duxYeXp6lljPa6+9pjlz5uj9999XWFiYzp07p3Pnzuny5cvldWgAAABAleG0QFG/fn199tln2rNnj9q2batx48ZpzJgxmjNnTrHbxMbGKiQkRD179tSQIUP09NNPKzAwsMR6li1bptzcXP36179WcHCw9fHGG284+pAAAACAKsdi3HrNEYqUlZUlPz8/ZWZmytfXt8LrPxD/VIXXCQAA7sz6uCBnN6HSeP6FRc5uAkpwp591nXaGAgAAAMDdj0ABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANALFXcDNPcDZTQAAAGXGL/NLcspP7qNiuTm7AShdRLfFzm4CAAAoo9Y9nd0CoGJwhgIAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaW7ObgBKl5I4TXm5F5zdDAAAgDLZuKOOrua4SLI4uylOZbFYtPD5l53djHJDoLgLECYAAMDd6GqOq7ObUCkYhuHsJpQrLnkCAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmVepAERkZqSlTppRYJiwsTEuWLKmQ9gAAAACwVa6BIjo6WhaLpdDj6NGj5VmtjQ0bNqhTp07y9/eXl5eX2rVrpw8++KDC6gcAAADuZeU+U3ZUVJRiY2NtltWpU6e8q7WqWbOmnnvuOYWHh8vd3V0bN27UU089pcDAQPXv37/C2gEAAADci8r9kicPDw8FBQXZPFxdb07DHh8fry5dusjDw0PBwcGaNWuW8vLyit1XRkaGBg0aJE9PTzVq1Ehr1qwptf7IyEj96le/UkREhJo0aaLJkyerTZs22rlzp8OOEQAAAKiqyv0MRXHOnDmjhx9+WNHR0Vq9erUOHz6smJgYVa9eXQsWLChym+joaJ06dUpbtmyRu7u7Jk2apIyMDLvrNAxDW7ZsUWpqql599dViy+Xk5CgnJ8f6PCsry+46AAAAgKqk3APFxo0b5e3tbX0+YMAArV+/XsuWLVNISIiWLl0qi8Wi8PBwnT17VjNnztS8efPk4mJ78uSHH37Q559/rl27dqlr166SpBUrVigiIqLUNmRmZqp
+/frKycmRq6urli1bpr59+xZbftGiRVq4cKHJIwYAAACqjnIPFL169dLy5cutz728vCRJKSkp6tatmywWi3XdAw88oMuXL+v06dMKDQ212U9KSorc3NzUqVMn67Lw8HD5+/uX2gYfHx8lJyfr8uXL+uqrrzRt2jQ1btxYkZGRRZafPXu2pk2bZn2elZWlkJAQew4XAAAAqFLKPVB4eXmpadOmhZYbhmETJgqWSSq0vLR1pXFxcbG2oV27dkpJSdGiRYuKDRQeHh7y8PAocz0AAABAVeO0eShatmyphIQEa1CQpISEBPn4+Kh+/fqFykdERCgvL0/79u2zLktNTdXFixfLXLdhGDb3SAAAAAAwx2mBYvz48Tp16pQmTpyow4cP69NPP9X8+fM1bdq0QvdPSFKLFi0UFRWlmJgY7d69W0lJSRo7dqw8PT1LrGfRokXavHmz/v3vf+vw4cNavHixVq9erd/+9rfldWgAAABAleG0X3mqX7++PvvsMz377LNq27atatasqTFjxmjOnDnFbhMbG6uxY8eqZ8+eqlu3rl588UXNnTu3xHquXLmi8ePH6/Tp0/L09FR4eLj+9re/adiwYY4+JAAAAKDKsRi3XnOEImVlZcnPz0+ZmZny9fWt8PoPxD9V4XUCAADcqfVxQc5uQqXx/AuLnN2EYt3pZ12nXfIEAAAA4O5HoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAageIu4OYe4OwmAAAAlJmnxw1JzFBgsVic3YRy5bSJ7WC/iG6Lnd0EAACAMmvd09ktQEXgDAUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA09yc3QAAAADgXta692Tl5xullju47Z0KaI3jcYYCAAAAKEf2hIm7GYECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKYRKAAAAACY5pRAERYWpiVLlpRYxmKx6JNPPqmQ9gAAAAAwp0yBIjo6Wo8++mih5du2bZPFYtHFixcd1KzyYU+QAQAAAGA/N2c34F5hGIby8vJ048YNZzcFAIB7RrVq1eTq6ursZgAoQbkEio8++kjz5s3T0aNHFRwcrIkTJ2r69OnFlj9y5IjGjBmjPXv2qHHjxnr77bcLlTlw4IAmT56sxMRE1ahRQ4899pgWL14sb29vSVJkZKTatWtncwbi0Ucflb+/v1auXKnIyEidOHFCU6dO1dSpUyXdDAGOkJubq/T0dGVnZztkfwAA4CaLxaIGDRpY/78HUPk4PFAkJSVp6NChWrBggYYNG6aEhASNHz9etWrVUnR0dKHy+fn5GjJkiGrXrq1du3YpKytLU6ZMsSmTnZ2tqKgo3X///dq7d68yMjI0duxYTZgwQStXrrSrXRs2bFDbtm319NNPKyYmpsSyOTk5ysnJsT7Pysoqtmx+fr7S0tLk6uqqevXqyd3dXRaLxa42AQCA4hmGoZ9++kmnT59Ws2bNOFMBVFJlDhQbN24s9C3BrZf5LF68WA899JDmzp0rSWrevLkOHTqk119/vchAERcXp5SUFB0/flwNGjSQJL388ssaMGCAtcyaNWt09epVrV69Wl5eXpKkpUuXatCgQXr11VdVt27dUttds2ZNubq6ysfHR0FBQSWWXbRokRYuXFjqPqWbZyfy8/MVEhKiGjVq2LUNAACwT506dXT8+HFdv36dQAFUUmX+ladevXopOTnZ5vHee+9Z16ekpOiBBx6w2eaBBx7QkSNHiry/ICUlRaGhodYwIUndunUrVKZt27bWMFGwz/z8fKWmppb1EEo1e/ZsZWZmWh+nTp0qdRsXF36BFwAAR+OsP1D5lfkMhZeXl5o2bWqz7PTp09Z/G4ZR6MVf0r0KRa0ravvi3lAKlru
4uBTa1/Xr14uttyQeHh7y8PAwtS0AAABQlTj8a/WWLVtq586dNssSEhLUvHnzIk9VtmzZUidPntTZs2etyxITEwuVSU5O1pUrV6zLvv76a7m4uKh58+aSbp4STU9Pt66/ceOGvv/+e5v9uLu78ytM5SgyMrLQ/S8lOX78uCwWi5KTk8utTbh73D5+KnK+Gua9wa2YKwkAysbhN2VPnz5dnTt31gsvvKBhw4YpMTFRS5cu1bJly4os36dPH7Vo0UIjR47Um2++qaysLD333HM2ZYYPH6758+dr1KhRWrBggX766SdNnDhRI0aMsN4/0bt3b02bNk2bNm1SkyZN9NZbbxWaFyMsLEzbt2/X448/Lg8PD9WuXdvRh28j99p53bh+qVzrKOBazUfu1WvZVba008ejRo2y+2b3W23YsEHVqlWzu3xISIjS09PL/e9wp6Kjo3Xx4kWnfni4ePGisrOvlF7QQWrU8JK/v79dZQcNGqSrV68qLi6u0LrExER1795dSUlJ6tChQ5nasHfvXpvLHB1hwYIF+uSTTwqF2PT0dAUEBDi0LkcLCwvTlClTyhTaHensj//RxcyKG4P+fl6qV7em3eWLe51u27ZNvXr10oULF+we05VJUb9gWJTyGB/21n0n7va/D4CbHB4oOnTooL///e+aN2+eXnjhBQUHB+v5558v8oZs6ealSh9//LHGjBmjLl26KCwsTO+8846ioqKsZWrUqKF//etfmjx5sjp37mzzs7EFRo8erf3792vkyJFyc3PT1KlT1atXL5u6nn/+ef3ud79TkyZNlJOT47CfjS1K7rXz+mHvbBn55i67KiuLSzU177zIrlBx65mcdevWad68eTb3onh6etqUv379ul1BoWZN+//zlyRXV9dSb5DHzTDxzttvKi8vr8LqdHNz06TJ0+36D37MmDEaMmSITpw4oYYNG9qse//999WuXbsyhwnp5lnHisI4LNnZH/+jR0a8qNzcihuD7u5u2vTBnDKFCgCAc5TpkqeVK1cW+S1tZGSkDMOwfvh47LHHdPDgQeXm5urEiROaMWOGTfnjx4/bfIvSvHlz7dixQzk5OUpNTVX//v1lGIbNrNytW7fWli1bdPXqVZ0/f15/+ctfbH5tqlq1alq2bJnOnz+vH3/8UbNmzdInn3xi8037/fffr/379+vatWvlGiYk6cb1SxUWJiTJyL9u99mQoKAg68PPz08Wi8X6/Nq1a/L399ff//53RUZGqnr16vrb3/6m8+fP64knnlCDBg1Uo0YNtW7dWh9++KHNfou6ZOXll1/W6NGj5ePjo9DQUP3lL3+xrr/9kqeCGde/+uorderUSTVq1FD37t0L3Xj/4osvKjAwUD4+Pho7dqxmzZqldu3aFXu8Fy5c0PDhw1WnTh15enqqWbNmio2Nta4/c+aMhg0bpoCAANWqVUuDBw/W8ePHJd38RnvVqlX69NNPZbFYZLFYtG3bNrv62VGys69UaJiQpLy8PLvPiAwcOFCBgYGFzmplZ2dr3bp1GjNmjF3j53a3X3Zy5MgR9ejRQ9WrV1fLli21efPmQtvMnDlTzZs3V40aNdS4cWPNnTvXei/VypUrtXDhQu3fv9/6tyxo8+2Xrxw4cEC9e/eWp6enatWqpaefflqXL1+2ro+Ojtajjz6qN954Q8HBwapVq5aeeeaZEu/b2r9/v3r16iUfHx/5+vqqY8eO2rdvn3V9QkKCevToIU9PT4WEhGjSpEnWyzxvnUenoO0V6WLmlQoNE5KUm5tXbmdESurrotgz9opy8uRJDR48WN7e3vL19dXQoUP1448/WtcXjKNbTZkyRZGRkdb18fHxevvtt61/94L3pluVND5KO9Zly5apWbNmql69uurWratf//rXZaq7pH1IN++BfO2119S4cWN5enqqbdu2+sc//iHp5v8BBV/8BQQEyGKxFPvlI4DKjZ8mQpFmzpypSZMmKSUlRf3799e1a9fUsWNHbdy4Ud9
//72efvppjRgxQrt37y5xP2+++aY6deqkb7/9VuPHj9fvf/97HT58uMRtnnvuOb355pvat2+f3NzcNHr0aOu6NWvW6KWXXtKrr76qpKQkhYaGavny5SXub+7cuTp06JA+//xzpaSkaPny5dbLrLKzs9WrVy95e3tr+/bt2rlzp7y9vRUVFaXc3FzNmDFDQ4cOVVRUlNLT05Wenq7u3bvb2YtVg5ubm0aOHKmVK1faBPX169crNzdXw4cPNz1+ChTMV+Pq6qpdu3bp3Xff1cyZMwuV8/Hx0cqVK3Xo0CG9/fbb+utf/6q33npLkjRs2DBNnz5drVq1sv4thw0bVmgfBfPeBAQEaO/evVq/fr3i4uI0YcIEm3Jbt27VsWPHtHXrVq1atUorV64s8VLB4cOHq0GDBtq7d6+SkpI0a9Ys65m/AwcOqH///hoyZIi+++47rVu3Tjt37rTWuWHDBjVo0EDPP/+8te0wp7S+vp29Y+92BV+K/ec//1F8fLw2b96sY8eOFTnmivP222+rW7duiomJsf7dQ0JCCpUrbnyUdqz79u3TpEmT9Pzzzys1NVVffPGFevToUaa6S9qHJM2ZM0exsbFavny5Dh48qKlTp+q3v/2t4uPjFRISoo8++kiSlJqaqvT09CIntgVQ+ZXLTNm4+02ZMkVDhgyxWXbrmaaJEyfqiy++0Pr169W1a9di9/Pwww9r/Pjxkm6GlLfeekvbtm1TeHh4sdu89NJL6tmzpyRp1qxZeuSRR3Tt2jVVr15d//M//6MxY8boqaeekiTNmzdPX375pc23x7c7efKk2rdvr06dOkm6+c13gbVr18rFxUXvvfee9Vu92NhY+fv7a9u2berXr588PT2Vk5PDZTElGD16tF5//XXr9dDSzcudhgwZooCAAAUEBJgaPwXsma9GuvnhpUBYWJimT5+udevW6Q9/+IM8PT3l7e0tNze3Ev+W9s57ExAQoKVLl8rV1VXh4eF65JFH9NVXXxU7cebJkyf17LPPWsd+s2bNrOtef/11Pfnkk9YzfM2aNdM777yjnj17avny5WWaR6cqK22eJKn0vq5evbpNeXvH3u3i4uL03XffKS0tzfpB/IMPPlCrVq20d+9ede7cudTj8fPzk7u7u2rUqFHi37248VHasZ48eVJeXl4aOHCgfHx81LBhQ7Vv375MdZe0jytXrmjx4sXasmWL9efgGzdurJ07d+rPf/6zevbsab1UNjAwkHsogLsYZyhQpIIP3wVu3Lihl156SW3atFGtWrXk7e2tL7/8UidPnixxP23atLH+u+DSqoyMDLu3CQ4OliTrNqmpqerSpYtN+duf3+73v/+91q5dq3bt2ukPf/iDEhISrOuSkpJ09OhR+fj4yNvbW97e3qpZs6auXbumY8eOlbhf/Fd4eLi6d++u999/X5J07Ngx7dixw3p2yez4KWDPfDWS9I9//EO/+MUvFBQUJG9vb82dO9fuOm6ty555b1q1amXzy3XBwcElju1p06Zp7Nix6tOnj1555RWb8ZWUlKSVK1dax6C3t7f69++v/Px8paWllan9VVlp8yRJZe9re8begAEDrPtq1aqVdbuQkBCbb/Vbtmwpf39/paSkOPKwi1Xasfbt21cNGzZU48aNNWLECK1Zs0bZ2dllqqOkfRw6dEjXrl1T3759bdqwevVq3l+BewxnKFCk239d580339Rbb72lJUuWqHXr1vLy8tKUKVOUm5tb4n5uv5nbYrEoPz/f7m0Kzhrcuk1Z5jmRbv5nf+LECW3atElxcXF66KGH9Mwzz+iNN95Qfn6+OnbsqDVr1hTariJvCr4XjBkzRhMmTNCf/vQnxcbGqmHDhnrooYckmR8/BeyZr2bXrl16/PHHtXDhQvXv319+fn5au3at3nzzzTIdhz3z3khlH9sLFizQk08+qU2bNunzzz/X/PnztXbtWv3qV79Sfn6+fve732nSpEmFtgsNDS1T+6uy0uZJklTmvrZ
n7L333nu6evWqpP+Oi+LG0a3LHTl/UlFKO1Z3d3d988032rZtm7788kvNmzdPCxYs0N69e+0+W+Dj41PsPgpeD5s2bVL9+vVttmOuJ+DeQqCAXXbs2KHBgwfrt7/9raSb/1EdOXJEERERFdqOFi1aaM+ePRoxYoR12a03thanTp06io6OVnR0tB588EE9++yzeuONN9ShQwetW7dOgYGB8vX1LXJb5i+xz9ChQzV58mT97//+r1atWqWYmBjrB6c7HT+3zldTr149SYXnq/n666/VsGFDm5+dPnHihE0Ze/6WLVu21KpVq3TlyhVrsL593huzmjdvrubNm2vq1Kl64oknFBsbq1/96lfq0KGDDh48WOjDcFnbjtLZ09e3smfs3f5h+dbtTp06ZT1LcejQIWVmZlrHfZ06dQrNl5ScnGwTVu39uxdVzp5jdXNzU58+fdSnTx/Nnz9f/v7+2rJli4YMGWJ33cXto2/fvvLw8NDJkyetl7EW1W6p8KVpAO4uXPIEuzRt2lSbN29WQkKCUlJS9Lvf/U7nzp2r8HZMnDhRK1as0KpVq3TkyBG9+OKL+u6770r81Zt58+bp008/1dGjR3Xw4EFt3LjR+h/68OHDVbt2bQ0ePFg7duxQWlqa4uPjNXnyZOs3m2FhYfruu++Umpqqn3/+2aHfIN5LvL29NWzYMP3xj3/U2bNnbX6t5U7Hz63z1ezfv187duwoNF9N06ZNdfLkSa1du1bHjh3TO++8o48//timTFhYmNLS0pScnKyff/5ZOTk5heoaPny4qlevrlGjRun777/X1q1bC817U1ZXr17VhAkTtG3bNp04cUJff/219u7dax2HM2fOVGJiop555hklJyfryJEj+uc//6mJEyfatH379u06c+aMfv75Z1PtgH19fSt7xl5x27Vp00bDhw/XN998oz179mjkyJHq2bOn9ZLS3r17a9++fVq9erWOHDmi+fPnFwoYYWFh2r17t44fP66ff/652LNgRY2P0o5148aNeuedd5ScnKwTJ05o9erVys/PV4sWLeyuu6R9+Pj4aMaMGZo6dapWrVqlY8eO6dtvv9Wf/vQnrVq1SpLUsGFDWSwWbdy4UT/99FOJ98MBqLwIFLDL3Llz1aFDB/Xv31+RkZEKCgoq9HOHFWH48OGaPXu2ZsyYoQ4dOigtLU3R0dGFbqS8lbu7u2bPnq02bdqoR48ecnV11dq1ayXdnONk+/btCg0N1ZAhQxQREaHRo0fr6tWr1jMWMTExatGihTp16qQ6dero66+/rpBjvRuNGTNGFy5cUJ8+fWwuH7nT8VMwX01OTo66dOmisWPH6qWXXrIpM3jwYE2dOlUTJkxQu3btlJCQoLlz59qUeeyxxxQVFaVevXqpTp06Rf50bcG8N//5z3/UuXNn/frXv9ZDDz2kpUuXlq0zbuHq6qrz589r5MiRat68uYYOHaoBAwZo4cKFkm7eNxQfH68jR47owQcfVPv27TV37lzrPUTSzXl0jh8/riZNmnA53h2wp69vZc/YK0rBTxEHBASoR48e6tOnjxo3bqx169ZZy/Tv319z587VH/7wB3Xu3FmXLl3SyJEjbfYzY8YMubq6qmXLlqpTp06x9wQVNT5KO1Z/f39t2LBBvXv3VkREhN599119+OGH1vtA7Km7tH288MILmjdvnhYtWqSIiAj1799f/+///T81atRI0s2zOwsXLtSsWbNUt27dYn9tC7jbubhU7M99VzSLUd4TMtwDsrKy5Ofnp8zMzEKXxVy7dk1paWlq1KiRzYfayjyx3b2mb9++CgoK0gcffODsppSLyj6xHe59TGwHZyru/1kAjlPSZ117cA9FOXGvXkvNOy+ye7K5O+VazadKhIns7Gy9++676t+/v1xdXfXhhx8qLi7O7omm7kb+/v6aNHm63RPNOUKNGl6ECVjVq1tTmz6YU24TzRXF38+LMAEAdwkCRTlyr15LqgIf8iuSxWLRZ599phd
ffFE5OTlq0aKFPvroI/Xp08fZTStX/v7+fMCHU9WrW5MP+ACAIhEocFfx9PRUXFycs5sBAACA/8NN2QAAAABMI1AAAAAAMI1A4SD8WBYAAI7H/69A5UeguEMFM5pmZ2c7uSUAANx7cnNzJd2czwVA5cRN2XfI1dVV/v7+ysjIkHRzUqySZm0GAAD2yc/P108//aQaNWrIzY2PLEBlxavTAYKCgiTJGioAAIBjuLi4KDQ0lC/rgEqMQOEAFotFwcHBCgwM1PXrFTMzNgAAVYG7u7tcXLhCG6jMCBQO5OrqyjWeAAAAqFKI/AAAAABMI1AAAAAAMI1AAQAAAMA07qGwQ8GkOllZWU5uCQAAAOBYBZ9xzU4kSaCww6VLlyRJISEhTm4JAAAAUD4uXbokPz+/Mm9nMZjTvlT5+fk6e/asfHx8nPI72FlZWQoJCdGpU6fk6+tb4fXfK+hHx6Af7xx96Bj0o2PQj45BPzoG/egYZe1HwzB06dIl1atXz9TPNHOGwg4uLi5q0KCBs5shX19fXlwOQD86Bv145+hDx6AfHYN+dAz60THoR8coSz+aOTNRgJuyAQAAAJhGoAAAAABgGoHiLuDh4aH58+fLw8PD2U25q9GPjkE/3jn60DHoR8egHx2DfnQM+tExKrofuSkbAAAAgGmcoQAAAABgGoECAAAAgGkECgAAAACmESgquWXLlqlRo0aqXr26OnbsqB07dji7SZXGokWL1LlzZ/n4+CgwMFCPPvqoUlNTbcpER0fLYrHYPO6//36bMjk5OZo4caJq164tLy8v/fKXv9Tp06cr8lCcasGCBYX6KCgoyLreMAwtWLBA9erVk6enpyIjI3Xw4EGbfVT1PpSksLCwQv1osVj0zDPPSGIsFmf79u0aNGiQ6tWrJ4vFok8++cRmvaPG34ULFzRixAj5+fnJz89PI0aM0MWLF8v56CpOSf14/fp1zZw5U61bt5aXl5fq1aunkSNH6uzZszb7iIyMLDRGH3/8cZsyVbkfJce9jqt6Pxb1XmmxWPT6669by1T18WjPZ5zK9P5IoKjE1q1bpylTpui5557Tt99+qwcffFADBgzQyZMnnd20SiE+Pl7PPPOMdu3apc2bNysvL0/9+vXTlStXbMpFRUUpPT3d+vjss89s1k+ZMkUff/yx1q5dq507d+ry5csaOHCgbty4UZGH41StWrWy6aMDBw5Y17322mtavHixli5dqr179yooKEh9+/bVpUuXrGXoQ2nv3r02fbh582ZJ0m9+8xtrGcZiYVeuXFHbtm21dOnSItc7avw9+eSTSk5O1hdffKEvvvhCycnJGjFiRLkfX0UpqR+zs7P1zTffaO7cufrmm2+0YcMG/fDDD/rlL39ZqGxMTIzNGP3zn/9ss74q92MBR7yOq3o/3tp/6enpev/992WxWPTYY4/ZlKvK49GezziV6v3RQKXVpUsXY9y4cTbLwsPDjVmzZjmpRZVbRkaGIcmIj4+3Lhs1apQxePDgYre5ePGiUa1aNWPt2rXWZWfOnDFcXFyML774ojybW2nMnz/faNu2bZHr8vPzjaCgIOOVV16xLrt27Zrh5+dnvPvuu4Zh0IfFmTx5stGkSRMjPz/fMAzGoj0kGR9//LH1uaPG36FDhwxJxq5du6xlEhMTDUnG4cOHy/moKt7t/ViUPXv2GJKMEydOWJf17NnTmDx5crHb0I+OeR3Tj4UNHjzY6N27t80yxqOt2z/jVLb3R85QVFK5ublKSkpSv379bJb369dPCQkJTmpV5ZaZmSlJqlmzps3ybdu2KTAwUM2bN1dMTIwyMjKs65KSknT9+nWbfq5Xr57uu+++KtXPR44cUb169dSoUSM9/vjj+ve//y1JSktL07lz52z6x8PDQz179rT2D31YWG5urv72t79p9OjRslgs1uWMxbJx1PhLTEyUn5+funbtai1z//33y8/Pr8r2bWZmpiwWi/z9/W2Wr1mzRrVr11arVq00Y8Y
Mm2866ceb7vR1TD/a+vHHH7Vp0yaNGTOm0DrG43/d/hmnsr0/upk/NJSnn3/+WTdu3FDdunVtltetW1fnzp1zUqsqL8MwNG3aNP3iF7/QfffdZ10+YMAA/eY3v1HDhg2VlpamuXPnqnfv3kpKSpKHh4fOnTsnd3d3BQQE2OyvKvVz165dtXr1ajVv3lw//vijXnzxRXXv3l0HDx609kFR4/DEiROSRB8W4ZNPPtHFixcVHR1tXcZYLDtHjb9z584pMDCw0P4DAwOrZN9eu3ZNs2bN0pNPPilfX1/r8uHDh6tRo0YKCgrS999/r9mzZ2v//v3Wy/foR8e8julHW6tWrZKPj4+GDBlis5zx+F9FfcapbO+PBIpK7tZvN6Wbg+r2ZZAmTJig7777Tjt37rRZPmzYMOu/77vvPnXq1EkNGzbUpk2bCr153aoq9fOAAQOs/27durW6deumJk2aaNWqVdabDc2Mw6rUh7dbsWKFBgwYoHr16lmXMRbNc8T4K6p8Vezb69ev6/HHH1d+fr6WLVtmsy4mJsb67/vuu0/NmjVTp06d9M0336hDhw6S6EdHvY6rej/e6v3339fw4cNVvXp1m+WMx/8q7jOOVHneH7nkqZKqXbu2XF1dC6XDjIyMQmm0qps4caL++c9/auvWrWrQoEGJZYODg9WwYUMdOXJEkhQUFKTc3FxduHDBplxV7mcvLy+1bt1aR44csf7aU0njkD60deLECcXFxWns2LEllmMsls5R4y8oKEg//vhjof3/9NNPVapvr1+/rqFDhyotLU2bN2+2OTtRlA4dOqhatWo2Y5R+tGXmdUw//teOHTuUmppa6vulVHXHY3GfcSrb+yOBopJyd3dXx44draf2CmzevFndu3d3UqsqF8MwNGHCBG3YsEFbtmxRo0aNSt3m/PnzOnXqlIKDgyVJHTt2VLVq1Wz6OT09Xd9//32V7eecnBylpKQoODjYerr51v7Jzc1VfHy8tX/oQ1uxsbEKDAzUI488UmI5xmLpHDX+unXrpszMTO3Zs8daZvfu3crMzKwyfVsQJo4cOaK4uDjVqlWr1G0OHjyo69evW8co/ViYmdcx/fhfK1asUMeOHdW2bdtSy1a18VjaZ5xK9/5o//3lqGhr1641qlWrZqxYscI4dOiQMWXKFMPLy8s4fvy4s5tWKfz+9783/Pz8jG3bthnp6enWR3Z2tmEYhnHp0iVj+vTpRkJCgpGWlmZs3brV6Natm1G/fn0jKyvLup9x48YZDRo0MOLi4oxvvvnG6N27t9G2bVsjLy/PWYdWoaZPn25s27bN+Pe//23s2rXLGDhwoOHj42MdZ6+88orh5+dnbNiwwThw4IDxxBNPGMHBwfRhEW7cuGGEhoYaM2fOtFnOWCzepUuXjG+//db49ttvDUnG4sWLjW+//db660OOGn9RUVFGmzZtjMTERCMxMdFo3bq1MXDgwAo/3vJSUj9ev37d+OUvf2k0aNDASE5Otnm/zMnJMQzDMI4ePWosXLjQ2Lt3r5GWlmZs2rTJCA8PN9q3b08//l8/OvJ1XJX7sUBmZqZRo0YNY/ny5YW2ZzyW/hnHMCrX+yOBopL705/+ZDRs2NBwd3c3OnToYPOTqFWdpCIfsbGxhmEYRnZ2ttGvXz+jTp06RrVq1YzQ0FBj1KhRxsmTJ232c/XqVWPChAlGzZo1DU9PT2PgwIGFytzLhg0bZgQHBxvVqlUz6tWrZwwZMsQ4ePCgdX1+fr4xf/58IygoyPDw8DB69OhhHDhwwGYfVb0PC/zrX/8yJBmpqak2yxmLxdu6dWuRr+NRo0YZhuG48Xf+/Hlj+PDhho+Pj+Hj42MMHz7cuHDhQgUdZfkrqR/T0tKKfb/cunWrYRiGcfLkSaNHjx5GzZo1DXd3d6NJkybGpEmTjPPnz9vUU5X70ZGv46rcjwX+/Oc/G56ensbFixcLbc94LP0zjmFUrvdHy/81GgAAAADKjHsoAAAAAJhGoAA
AAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKYRKAAATrFgwQK1a9fO2c0AANwhZsoGADicxWIpcf2oUaO0dOlS5eTkqFatWhXUKgBAeSBQAAAc7ty5c9Z/r1u3TvPmzVNqaqp1maenp/z8/JzRNACAg3HJEwDA4YKCgqwPPz8/WSyWQstuv+QpOjpajz76qF5++WXVrVtX/v7+WrhwofLy8vTss8+qZs2aatCggd5//32bus6cOaNhw4YpICBAtWrV0uDBg3X8+PGKPWAAqMIIFACASmPLli06e/astm/frsWLF2vBggUaOHCgAgICtHv3bo0bN07jxo3TqVOnJEnZ2dnq1auXvL29tX37du3cuVPe3t6KiopSbm6uk48GAKoGAgUAoNKoWbOm3nnnHbVo0UKjR49WixYtlJ2drT/+8Y9q1qyZZs+eLXd3d3399deSpLVr18rFxUXvvfeeWrdurYiICMXGxurkyZPatm2bcw8GAKoIN2c3AACAAq1atZKLy3+/66pbt67uu+8+63NXV1fVqlVLGRkZkqSkpCQdPXpUPj4+Nvu5du2ajh07VjGNBoAqjkABAKg0qlWrZvPcYrEUuSw/P1+SlJ+fr44dO2rNmjWF9lWnTp3yaygAwIpAAQC4a3Xo0EHr1q1TYGCgfH19nd0cAKiSuIcCAHDXGj58uGrXrq3Bgwdrx44dSktLU3x8vCZPnqzTp087u3kAUCUQKAAAd60aNWpo+/btCg0N1ZAhQxQREaHRo0fr6tWrnLEAgArCxHYAAAAATOMMBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwLT/D9blFiP5mbuxAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cv_cmap = matplotlib.colormaps[\"cividis\"]\n", + "plt.figure(figsize=(8, 3))\n", + "\n", + "for i, (train_mask, valid_mask) in enumerate(cv_folds):\n", + " idx = np.array([np.nan] * time_horizon)\n", + " idx[np.arange(*train_mask)] = 1\n", + " idx[np.arange(*valid_mask)] = 0\n", + " plt.scatter(\n", + " range(time_horizon),\n", + " [i + 0.5] * time_horizon,\n", + " c=idx,\n", + " marker=\"_\",\n", + " capstyle=\"butt\",\n", + " s=1,\n", + " lw=20,\n", + " cmap=cv_cmap,\n", + " vmin=-1.5,\n", + " vmax=1.5\n", + " )\n", + "\n", + "idx = np.array([np.nan] * time_horizon)\n", + "idx[np.arange(*holdout)] = -1\n", + "plt.scatter(\n", + " range(time_horizon),\n", + " [n_folds + 0.5] * time_horizon,\n", + " c=idx,\n", + " marker=\"_\",\n", + " capstyle=\"butt\",\n", + " s=1,\n", + " lw=20,\n", + " cmap=cv_cmap,\n", + " vmin=-1.5,\n", + " vmax=1.5\n", + ")\n", + "\n", + "plt.xlabel(\"Time\")\n", + "plt.yticks(\n", + " ticks=np.arange(n_folds + 1) + 0.5,\n", + " labels=[f\"Fold {i}\" for i in range(n_folds)] + [\"Holdout\"]\n", + ")\n", + "plt.ylim([len(cv_folds) + 1.2, -0.2])\n", + "\n", + "norm = matplotlib.colors.Normalize(vmin=-1.5, vmax=1.5)\n", + "plt.legend(\n", + " [Patch(color=cv_cmap(norm(1))),\n", + " Patch(color=cv_cmap(norm(0))),\n", + " Patch(color=cv_cmap(norm(-1)))],\n", + " [\"Training set\", \"Validation set\", \"Held-out test set\"],\n", + " ncol=3,\n", + " loc=\"best\"\n", + ")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "17b4c300-971d-419e-b167-19ff533860dd", + "metadata": {}, + "source": [ + "## Launch a Dask client on Kubernetes" + ] + }, + { + "cell_type": "markdown", + "id": "bd94ac3f-2843-4596-9f27-14864ef5ec8e", + "metadata": {}, + "source": [ + "Let us set up a Dask cluster using the `KubeCluster` class." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "deac0215-66fb-4c2b-a606-bbb6b6ec2260", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8c26e7b9e0a54dc8863e6733a70ebdbf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "cluster = KubeCluster(name=\"rapids-dask\",\n",
+    "                      image=rapids_image,\n",
+    "                      worker_command=\"dask-cuda-worker\",\n",
+    "                      n_workers=n_workers,\n",
+    "                      resources={\"limits\": {\"nvidia.com/gpu\": \"1\"}},\n",
+    "                      env={\"DISABLE_JUPYTER\": \"true\", \"EXTRA_PIP_PACKAGES\": \"optuna gcsfs\"})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "b2d5c010-4cd3-4d58-9c88-84d4aa404234",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "76e37407208149dcb0b97123d8a2a135",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/html": [
+       "
\n", + "
\n", + "
\n", + "
\n", + "

KubeCluster

\n", + "

rapids-dask

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", + " \n", + " Workers: 1\n", + "
\n", + " Total threads: 1\n", + " \n", + " Total memory: 117.93 GiB\n", + "
\n", + "\n", + "
\n", + " \n", + "

Scheduler Info

\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

Scheduler

\n", + "

Scheduler-8df8661b-8c41-4c39-ba31-07f76461af5e

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Comm: tcp://10.36.3.19:8786\n", + " \n", + " Workers: 1\n", + "
\n", + " Dashboard: http://10.36.3.19:8787/status\n", + " \n", + " Total threads: 1\n", + "
\n", + " Started: Just now\n", + " \n", + " Total memory: 117.93 GiB\n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "

Workers

\n", + "
\n", + "\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: rapids-dask-default-worker-9fc2234a8d

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://10.36.4.20:35995\n", + " \n", + " Total threads: 1\n", + "
\n", + " Dashboard: http://10.36.4.20:8788/status\n", + " \n", + " Memory: 117.93 GiB\n", + "
\n", + " Nanny: tcp://10.36.4.20:40817\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-mwjhcmjv\n", + "
\n", + " GPU: Tesla T4\n", + " \n", + " GPU memory: 15.00 GiB\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
\n", + "\n", + "
\n", + "
\n", + "
" + ], + "text/plain": [ + "KubeCluster(rapids-dask, 'tcp://rapids-dask-scheduler.kubeflow-user-example-com:8786', workers=1, threads=1, memory=117.93 GiB)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cluster" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "727c1182-4c74-4d2d-a059-c359193d76ff", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "client = Client(cluster)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "633a2f56-9ad2-4e38-8e6e-be3bf52ba795", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
\n", + "

Client

\n", + "

Client-54cb8819-2c1e-11ee-8c1e-9a09a5b5e674

\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "
Connection method: Cluster objectCluster type: dask_kubernetes.KubeCluster
\n", + " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", + "
\n", + "\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "

Cluster Info

\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

KubeCluster

\n", + "

rapids-dask

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", + " \n", + " Workers: 3\n", + "
\n", + " Total threads: 3\n", + " \n", + " Total memory: 353.79 GiB\n", + "
\n", + "\n", + "
\n", + " \n", + "

Scheduler Info

\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

Scheduler

\n", + "

Scheduler-8df8661b-8c41-4c39-ba31-07f76461af5e

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Comm: tcp://10.36.3.19:8786\n", + " \n", + " Workers: 3\n", + "
\n", + " Dashboard: http://10.36.3.19:8787/status\n", + " \n", + " Total threads: 3\n", + "
\n", + " Started: Just now\n", + " \n", + " Total memory: 353.79 GiB\n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "

Workers

\n", + "
\n", + "\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: rapids-dask-default-worker-3588dacd46

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://10.36.1.16:41549\n", + " \n", + " Total threads: 1\n", + "
\n", + " Dashboard: http://10.36.1.16:8788/status\n", + " \n", + " Memory: 117.93 GiB\n", + "
\n", + " Nanny: tcp://10.36.1.16:43755\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-j2wa7czf\n", + "
\n", + " GPU: Tesla T4\n", + " \n", + " GPU memory: 15.00 GiB\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: rapids-dask-default-worker-5d46c38fcf

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://10.36.2.23:39471\n", + " \n", + " Total threads: 1\n", + "
\n", + " Dashboard: http://10.36.2.23:8788/status\n", + " \n", + " Memory: 117.93 GiB\n", + "
\n", + " Nanny: tcp://10.36.2.23:44423\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-dsfuecsa\n", + "
\n", + " GPU: Tesla T4\n", + " \n", + " GPU memory: 15.00 GiB\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: rapids-dask-default-worker-9fc2234a8d

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://10.36.4.20:35995\n", + " \n", + " Total threads: 1\n", + "
\n", + " Dashboard: http://10.36.4.20:8788/status\n", + " \n", + " Memory: 117.93 GiB\n", + "
\n", + " Nanny: tcp://10.36.4.20:40817\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-mwjhcmjv\n", + "
\n", + " GPU: Tesla T4\n", + " \n", + " GPU memory: 15.00 GiB\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client" + ] + }, + { + "cell_type": "markdown", + "id": "5e0aa64d-1c5a-4a9b-886b-07e846111b55", + "metadata": {}, + "source": [ + "## Define the custom evaluation metric" + ] + }, + { + "cell_type": "markdown", + "id": "32d8c67d-3187-4483-bdc9-27b5b7b0cda0", + "metadata": {}, + "source": [ + "The M5 forecasting competition defines a custom metric called WRMSSE as follows:\n", + "$$\n", + "WRMSSE = \\sum w_i \\cdot RMSSE_i\n", + "$$\n", + "i.e. WRMSEE is a weighted sum of RMSSE for all product items $i$. RMSSE is in turn defined to be\n", + "$$\n", + "RMSSE = \\sqrt{\\frac{1/h \\cdot \\sum_t{\\left(Y_t - \\hat{Y}_t\\right)}^2}{1/(n-1)\\sum_t{(Y_t - Y_{t-1})}^2}}\n", + "$$\n", + "where the squared error of the prediction (forecast) is normalized by the speed at which the sales amount changes per unit in the training data.\n", + "\n", + "Here is the implementation of the WRMSSE using cuDF. We use the product weights $w_i$ as computed in the first preprocessing notebook." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "43f0e74d-734b-4135-8313-b9ed277a3341", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def wrmsse(product_weights, df, pred_sales, train_mask, valid_mask):\n", + " \"\"\"Compute WRMSSE metric\"\"\"\n", + " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", + " df_valid = df[(df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])]\n", + "\n", + " # Compute denominator: 1/(n-1) * sum( (y(t) - y(t-1))**2 )\n", + " diff = df_train.sort_values([\"item_id\", \"day_id\"]).groupby([\"item_id\"])[[\"sales\"]].diff(1)\n", + " x = df_train[[\"item_id\", \"day_id\"]].join(diff, how=\"left\").rename(columns={\"sales\": \"diff\"}).sort_values([\"item_id\", \"day_id\"])\n", + " x[\"diff\"] = x[\"diff\"] ** 2\n", + " xx = x.groupby([\"item_id\"])[[\"diff\"]].agg([\"sum\", \"count\"]).sort_index()\n", + " xx.columns = xx.columns.map(\"_\".join)\n", + " xx[\"denominator\"] = xx[\"diff_sum\"] / xx[\"diff_count\"]\n", + " t = xx.reset_index()\n", + " \n", + " # Compute numerator: 1/h * sum( (y(t) - y_pred(t))**2 )\n", + " X_valid = df_valid.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " if \"dept_id\" in X_valid.columns:\n", + " X_valid = X_valid.drop(columns=[\"dept_id\"])\n", + " df_pred = cudf.DataFrame({\n", + " \"item_id\": df_valid[\"item_id\"].copy(),\n", + " \"pred_sales\": pred_sales,\n", + " \"sales\": df_valid[\"sales\"].copy()\n", + " })\n", + " df_pred[\"diff\"] = (df_pred[\"sales\"] - df_pred[\"pred_sales\"])**2\n", + " yy = df_pred.groupby([\"item_id\"])[[\"diff\"]].agg([\"sum\", \"count\"]).sort_index()\n", + " yy.columns = yy.columns.map(\"_\".join)\n", + " yy[\"numerator\"] = yy[\"diff_sum\"] / yy[\"diff_count\"]\n", + " \n", + " zz = yy[[\"numerator\"]].join(xx[[\"denominator\"]], how=\"left\")\n", + " zz = zz.join(product_weights, how=\"left\").sort_index()\n", + " # Filter out zero denominator.\n", + " 
# This can occur if the product was never on sale during the period in the training set\n", + " zz = zz[zz[\"denominator\"] != 0]\n", + " zz[\"rmsse\"] = np.sqrt(zz[\"numerator\"] / zz[\"denominator\"])\n", + " t = zz[\"rmsse\"].multiply(zz[\"weights\"])\n", + " return zz[\"rmsse\"].multiply(zz[\"weights\"]).sum()" + ] + }, + { + "cell_type": "markdown", + "id": "b5e83747-b181-49f3-b302-bdffd1e2d847", + "metadata": {}, + "source": [ + "## Define the training and hyperparameter search pipeline using Optuna" + ] + }, + { + "cell_type": "markdown", + "id": "6145f299-9c1d-4391-af2a-63bf9afc8fb8", + "metadata": {}, + "source": [ + "Optuna lets us define the training procedure iteratively, i.e. as if we were to write an ordinary function to train a single model. Instead of a fixed hyperparameter combination, the function now takes in a `trial` object which yields different hyperparameter combinations.\n", + "\n", + "In this example, we partition the training data according to the store and then fit a separate XGBoost model per data segment." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "38c7c1ae-1888-4ef9-8a57-0eabd845fcb4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def objective(trial):\n", + " fs = gcsfs.GCSFileSystem()\n", + " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " \"lambda\": trial.suggest_float(\"lambda\", 1e-8, 100.0, log=True),\n", + " \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 100.0, log=True),\n", + " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n", + " \"max_depth\": trial.suggest_int(\"max_depth\", 2, 6, step=1),\n", + " \"min_child_weight\": trial.suggest_float(\"min_child_weight\", 1e-8, 100, log=True),\n", + " \"gamma\": trial.suggest_float(\"gamma\", 1e-8, 1.0, log=True),\n", + " \"tweedie_variance_power\": trial.suggest_float(\"tweedie_variance_power\", 1, 2),\n", + " }\n", + " scores = [[] for store in STORES]\n", + " \n", + " for store_id, store in enumerate(STORES):\n", + " print(f\"Processing store {store}...\")\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " for train_mask, valid_mask in cv_folds:\n", + " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", + " df_valid = df[(df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])]\n", + "\n", + " X_train, y_train = df_train.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]), df_train[\"sales\"]\n", + " X_valid = df_valid.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", 
\"sales\"])\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " pred_sales = clf.predict(X_valid)\n", + " scores[store_id].append(wrmsse(product_weights, df, pred_sales, train_mask, valid_mask))\n", + " del df_train, df_valid, X_train, y_train, clf\n", + " gc.collect()\n", + " del df\n", + " gc.collect()\n", + " \n", + " # We can sum WRMSSE scores over data segments because data segments contain disjoint sets of time series\n", + " return np.array(scores).sum(axis=0).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "78b96b48-f94b-47c1-ac27-cfab20b8561c", + "metadata": {}, + "source": [ + "Using the Dask cluster client, we execute multiple training jobs in parallel. Optuna keeps track of the progress in the hyperparameter search using in-memory Dask storage." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "09af0be0-4a72-4dc2-bd60-ff1d8152385b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3102/456600745.py:7: ExperimentalWarning: DaskStorage is experimental (supported from v3.1.0). 
The interface can change in the future.\n", + " dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing hyperparameter combinations 0..3\n", + "Best cross-validation metric: 9.689218658589553, Time elapsed = 491.4758385729983\n", + "Testing hyperparameter combinations 3..6\n", + "Best cross-validation metric: 9.689218658589553, Time elapsed = 1047.8801612580028\n", + "Testing hyperparameter combinations 6..9\n", + "Best cross-validation metric: 9.689218658589553, Time elapsed = 1650.7563961980013\n", + "Total time elapsed = 1650.7610972189977\n" + ] + } + ], + "source": [ + "##### Number of hyperparameter combinations to try in parallel\n", + "n_trials = 9 # Using a small n_trials so that the demo can finish quickly\n", + "#n_trials = 100\n", + "\n", + "# Optimize in parallel on your Dask cluster\n", + "backend_storage = optuna.storages.InMemoryStorage()\n", + "dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n", + "study = optuna.create_study(direction=\"minimize\",\n", + " sampler=optuna.samplers.RandomSampler(seed=0),\n", + " storage=dask_storage)\n", + "futures = []\n", + "for i in range(0, n_trials, n_workers):\n", + " iter_range = (i, min([i + n_workers, n_trials]))\n", + " futures.append(\n", + " {\n", + " \"range\": iter_range,\n", + " \"futures\": [\n", + " client.submit(study.optimize, objective, n_trials=1, pure=False)\n", + " for _ in range(*iter_range)\n", + " ]\n", + " }\n", + " )\n", + "\n", + "tstart = time.perf_counter()\n", + "for partition in futures:\n", + " iter_range = partition[\"range\"]\n", + " print(f\"Testing hyperparameter combinations {iter_range[0]}..{iter_range[1]}\")\n", + " _ = wait(partition[\"futures\"])\n", + " for fut in partition[\"futures\"]:\n", + " _ = fut.result() # Ensure that the training job was successful\n", + " tnow = time.perf_counter()\n", + " print(f\"Best 
cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\")\n", + "tend = time.perf_counter()\n", + "print(f\"Total time elapsed = {tend - tstart}\")" + ] + }, + { + "cell_type": "markdown", + "id": "01e6da28-42e6-4e7e-9fa9-340539254e5e", + "metadata": {}, + "source": [ + "Once the hyperparameter search is complete, we fetch the optimal hyperparameter combination using the attributes of the `study` object." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "3d11f3ac-76f5-48d2-bbb9-109d38b1f261", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'lambda': 0.003077053443211648,\n", + " 'alpha': 0.14187101103672142,\n", + " 'colsample_bytree': 0.682210700857315,\n", + " 'max_depth': 4,\n", + " 'min_child_weight': 0.00017240426024865184,\n", + " 'gamma': 0.0014694435419424668,\n", + " 'tweedie_variance_power': 1.4375872112626924}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "study.best_params" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ac74f348-c1cd-4dbf-918b-52292d874a16", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "FrozenTrial(number=0, state=TrialState.COMPLETE, values=[9.689218658589553], datetime_start=datetime.datetime(2023, 7, 27, 1, 39, 4, 604443), datetime_complete=datetime.datetime(2023, 7, 27, 1, 47, 9, 804887), params={'lambda': 0.003077053443211648, 'alpha': 0.14187101103672142, 'colsample_bytree': 0.682210700857315, 'max_depth': 4, 'min_child_weight': 0.00017240426024865184, 'gamma': 0.0014694435419424668, 'tweedie_variance_power': 1.4375872112626924}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'lambda': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'alpha': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'colsample_bytree': FloatDistribution(high=1.0, log=False, low=0.2, step=None), 
'max_depth': IntDistribution(high=6, log=False, low=2, step=1), 'min_child_weight': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'gamma': FloatDistribution(high=1.0, log=True, low=1e-08, step=None), 'tweedie_variance_power': FloatDistribution(high=2.0, log=False, low=1.0, step=None)}, trial_id=0, value=None)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "study.best_trial" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "dbc8d725-5aff-4339-bc45-d43f80aab454", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'lambda': 0.003077053443211648,\n", + " 'alpha': 0.14187101103672142,\n", + " 'colsample_bytree': 0.682210700857315,\n", + " 'max_depth': 4,\n", + " 'min_child_weight': 0.00017240426024865184,\n", + " 'gamma': 0.0014694435419424668,\n", + " 'tweedie_variance_power': 1.4375872112626924}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Make a deep copy to preserve the dictionary after deleting the Dask cluster\n", + "best_params = copy.deepcopy(study.best_params)\n", + "best_params" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "aa630ad6-e011-4652-b51c-1f2b8ce6564f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with fs.open(f\"{bucket_name}/params.json\", \"w\") as f:\n", + " json.dump(best_params, f)" + ] + }, + { + "cell_type": "markdown", + "id": "b5d02e80-c4e8-4445-bd0e-d05b9009c07f", + "metadata": {}, + "source": [ + "## Train the final XGBoost model and evaluate" + ] + }, + { + "cell_type": "markdown", + "id": "ca9a2b78-8e08-47d9-9136-76faa8d21761", + "metadata": {}, + "source": [ + "Using the optimal hyperparameters found in the search, fit a new model using the whole training data. As in the previous section, we fit a separate XGBoost model per data segment." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "3b665d27-8148-4959-ba6d-bdd54b18a311", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with fs.open(f\"{bucket_name}/params.json\", \"r\") as f:\n", + " best_params = json.load(f)\n", + "with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "17c25588-b1e1-4a53-afe1-478bab418793", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def final_train(best_params):\n", + " fs = gcsfs.GCSFileSystem()\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " }\n", + " params.update(best_params)\n", + " model = {}\n", + " train_mask = [0, 1914]\n", + "\n", + " for store in STORES:\n", + " print(f\"Processing store {store}...\")\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + "\n", + " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", + " X_train, y_train = df_train.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]), df_train[\"sales\"]\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " model[store] = clf\n", + " del df\n", + " gc.collect()\n", + " \n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "3170cd9d-cfb3-4de4-9eb8-80f0e96438c3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1...\n", + "Processing store CA_2...\n", + 
"Processing store CA_3...\n", + "Processing store CA_4...\n", + "Processing store TX_1...\n", + "Processing store TX_2...\n", + "Processing store TX_3...\n", + "Processing store WI_1...\n", + "Processing store WI_2...\n", + "Processing store WI_3...\n" + ] + } + ], + "source": [ + "model = final_train(best_params)" + ] + }, + { + "cell_type": "markdown", + "id": "fa83a711-879a-44be-bfc6-70eaf0691934", + "metadata": {}, + "source": [ + "Let's now evaluate the final model using the held-out test set:" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "7426afcd-5e22-4901-84fe-6155b172f8c4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WRMSSE metric on the held-out test set: 10.495262182826213\n" + ] + } + ], + "source": [ + "test_wrmsse = 0\n", + "for store in STORES:\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " df_test = df[(df[\"day_id\"] >= holdout[0]) & (df[\"day_id\"] < holdout[1])]\n", + " X_test = df_test.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " pred_sales = model[store].predict(X_test)\n", + " test_wrmsse += wrmsse(product_weights, df, pred_sales, train_mask=[0, 1914], valid_mask=holdout)\n", + "print(f\"WRMSSE metric on the held-out test set: {test_wrmsse}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "e6c804ce-3ac2-475e-a86a-3e6acf2f49ce", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Save the model to the Cloud Storage\n", + "with fs.open(f\"{bucket_name}/final_model.pkl\", \"wb\") as f:\n", + " pickle.dump(model, f)" + ] + }, + { + "cell_type": "markdown", + "id": "bb62e220-d69a-42f4-90ee-acee6ef6144a", + "metadata": {}, + "source": [ + "## Create an ensemble model using a different strategy for segmenting sales data" + ] + }, + { + "cell_type": "markdown", + "id": 
"d5714912-df6b-4686-83a0-2d6424e4d522", + "metadata": {}, + "source": [ + "It is common to create an **ensemble model** where multiple machine learning methods are used to obtain better predictive performance. Prediction is made from an ensemble model by averaging the prediction output of the constituent models.\n", + "\n", + "In this example, we will create a second model by segmenting the sales data in a different way. Instead of splitting by stores, we will split the data by both stores and product categories." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "8461f1ca-4521-44f9-8a2b-0666b684b72d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def objective_alt(trial):\n", + " fs = gcsfs.GCSFileSystem()\n", + " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " \"lambda\": trial.suggest_float(\"lambda\", 1e-8, 100.0, log=True),\n", + " \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 100.0, log=True),\n", + " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n", + " \"max_depth\": trial.suggest_int(\"max_depth\", 2, 6, step=1),\n", + " \"min_child_weight\": trial.suggest_float(\"min_child_weight\", 1e-8, 100, log=True),\n", + " \"gamma\": trial.suggest_float(\"gamma\", 1e-8, 1.0, log=True),\n", + " \"tweedie_variance_power\": trial.suggest_float(\"tweedie_variance_power\", 1, 2),\n", + " }\n", + " scores = [[] for i in range(len(STORES) * len(DEPTS))]\n", + " \n", + " for store_id, store in enumerate(STORES):\n", + " for dept_id, dept in enumerate(DEPTS):\n", + " print(f\"Processing store {store}, department 
{dept}...\")\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " for train_mask, valid_mask in cv_folds:\n", + " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", + " df_valid = df[(df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])]\n", + "\n", + " X_train, y_train = df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]), df_train[\"sales\"]\n", + " X_valid = df_valid.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " sales_pred = clf.predict(X_valid)\n", + " scores[store_id * len(DEPTS) + dept_id].append(wrmsse(product_weights, df, sales_pred, train_mask, valid_mask))\n", + " del df_train, df_valid, X_train, y_train, clf\n", + " gc.collect()\n", + " del df\n", + " gc.collect()\n", + " \n", + " # We can sum WRMSSE scores over data segments because data segments contain disjoint sets of time series\n", + " return np.array(scores).sum(axis=0).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "3e5c34ba-9709-4c71-bfcc-dca214891a2f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3102/383703293.py:7: ExperimentalWarning: DaskStorage is experimental (supported from v3.1.0). 
The interface can change in the future.\n", + " dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing hyperparameter combinations 0..3\n", + "Best cross-validation metric: 9.657402162051978, Time elapsed = 663.513354638002\n", + "Testing hyperparameter combinations 3..6\n", + "Best cross-validation metric: 9.657402162051978, Time elapsed = 1379.8620550880005\n", + "Testing hyperparameter combinations 6..9\n", + "Best cross-validation metric: 9.657402162051978, Time elapsed = 2183.6284268570016\n", + "Total time elapsed = 2183.632464492999\n" + ] + } + ], + "source": [ + "##### Number of hyperparameter combinations to try in parallel\n", + "n_trials = 9 # Using a small n_trials so that the demo can finish quickly\n", + "#n_trials = 100\n", + "\n", + "# Optimize in parallel on your Dask cluster\n", + "backend_storage = optuna.storages.InMemoryStorage()\n", + "dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n", + "study = optuna.create_study(direction=\"minimize\",\n", + " sampler=optuna.samplers.RandomSampler(seed=0),\n", + " storage=dask_storage)\n", + "futures = []\n", + "for i in range(0, n_trials, n_workers):\n", + " iter_range = (i, min([i + n_workers, n_trials]))\n", + " futures.append(\n", + " {\n", + " \"range\": iter_range,\n", + " \"futures\": [\n", + " client.submit(study.optimize, objective_alt, n_trials=1, pure=False)\n", + " for _ in range(*iter_range)\n", + " ]\n", + " }\n", + " )\n", + "\n", + "tstart = time.perf_counter()\n", + "for partition in futures:\n", + " iter_range = partition[\"range\"]\n", + " print(f\"Testing hyperparameter combinations {iter_range[0]}..{iter_range[1]}\")\n", + " _ = wait(partition[\"futures\"])\n", + " for fut in partition[\"futures\"]:\n", + " _ = fut.result() # Ensure that the training job was successful\n", + " tnow = time.perf_counter()\n", + " print(f\"Best 
cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\")\n", + "tend = time.perf_counter()\n", + "print(f\"Total time elapsed = {tend - tstart}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "0ce96895-8ef8-4897-bf52-b6bac69188c2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'lambda': 0.003077053443211648,\n", + " 'alpha': 0.14187101103672142,\n", + " 'colsample_bytree': 0.682210700857315,\n", + " 'max_depth': 4,\n", + " 'min_child_weight': 0.00017240426024865184,\n", + " 'gamma': 0.0014694435419424668,\n", + " 'tweedie_variance_power': 1.4375872112626924}" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Make a deep copy to preserve the dictionary after deleting the Dask cluster\n", + "best_params_alt = copy.deepcopy(study.best_params)\n", + "best_params_alt" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "9841bf88-24ff-4fda-991f-108291866a43", + "metadata": {}, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with fs.open(f\"{bucket_name}/params_alt.json\", \"w\") as f:\n", + " json.dump(best_params_alt, f)" + ] + }, + { + "cell_type": "markdown", + "id": "e04f712f-b7d6-4063-979d-22358da9fbc0", + "metadata": {}, + "source": [ + "Using the optimal hyperparameters found in the search, fit a new model using the whole training data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "44bbde76-93e6-4c69-bb12-d71e93e775db", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def final_train_alt(best_params):\n", + " fs = gcsfs.GCSFileSystem()\n", + " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " }\n", + " params.update(best_params)\n", + " model = {}\n", + " train_mask = [0, 1914]\n", + " \n", + " for store_id, store in enumerate(STORES):\n", + " for dept_id, dept in enumerate(DEPTS):\n", + " print(f\"Processing store {store}, department {dept}...\")\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " for train_mask, valid_mask in cv_folds:\n", + " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", + " X_train, y_train = df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]), df_train[\"sales\"]\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " model[(store, dept)] = clf\n", + " del df\n", + " gc.collect()\n", + " \n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "0af52a5c-4d7c-454c-90a9-c47f570d6457", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with fs.open(f\"{bucket_name}/params_alt.json\", \"r\") as f:\n", + " best_params_alt = json.load(f)\n", + "with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))" + ] + 
}, + { + "cell_type": "code", + "execution_count": 35, + "id": "fdd3b3a4-9e5a-4ebe-b138-57774c8cacf9", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1, department HOBBIES_1...\n", + "Processing store CA_1, department HOBBIES_2...\n", + "Processing store CA_1, department HOUSEHOLD_1...\n", + "Processing store CA_1, department HOUSEHOLD_2...\n", + "Processing store CA_1, department FOODS_1...\n", + "Processing store CA_1, department FOODS_2...\n", + "Processing store CA_1, department FOODS_3...\n", + "Processing store CA_2, department HOBBIES_1...\n", + "Processing store CA_2, department HOBBIES_2...\n", + "Processing store CA_2, department HOUSEHOLD_1...\n", + "Processing store CA_2, department HOUSEHOLD_2...\n", + "Processing store CA_2, department FOODS_1...\n", + "Processing store CA_2, department FOODS_2...\n", + "Processing store CA_2, department FOODS_3...\n", + "Processing store CA_3, department HOBBIES_1...\n", + "Processing store CA_3, department HOBBIES_2...\n", + "Processing store CA_3, department HOUSEHOLD_1...\n", + "Processing store CA_3, department HOUSEHOLD_2...\n", + "Processing store CA_3, department FOODS_1...\n", + "Processing store CA_3, department FOODS_2...\n", + "Processing store CA_3, department FOODS_3...\n", + "Processing store CA_4, department HOBBIES_1...\n", + "Processing store CA_4, department HOBBIES_2...\n", + "Processing store CA_4, department HOUSEHOLD_1...\n", + "Processing store CA_4, department HOUSEHOLD_2...\n", + "Processing store CA_4, department FOODS_1...\n", + "Processing store CA_4, department FOODS_2...\n", + "Processing store CA_4, department FOODS_3...\n", + "Processing store TX_1, department HOBBIES_1...\n", + "Processing store TX_1, department HOBBIES_2...\n", + "Processing store TX_1, department HOUSEHOLD_1...\n", + "Processing store TX_1, department HOUSEHOLD_2...\n", + "Processing store TX_1, department 
FOODS_1...\n", + "Processing store TX_1, department FOODS_2...\n", + "Processing store TX_1, department FOODS_3...\n", + "Processing store TX_2, department HOBBIES_1...\n", + "Processing store TX_2, department HOBBIES_2...\n", + "Processing store TX_2, department HOUSEHOLD_1...\n", + "Processing store TX_2, department HOUSEHOLD_2...\n", + "Processing store TX_2, department FOODS_1...\n", + "Processing store TX_2, department FOODS_2...\n", + "Processing store TX_2, department FOODS_3...\n", + "Processing store TX_3, department HOBBIES_1...\n", + "Processing store TX_3, department HOBBIES_2...\n", + "Processing store TX_3, department HOUSEHOLD_1...\n", + "Processing store TX_3, department HOUSEHOLD_2...\n", + "Processing store TX_3, department FOODS_1...\n", + "Processing store TX_3, department FOODS_2...\n", + "Processing store TX_3, department FOODS_3...\n", + "Processing store WI_1, department HOBBIES_1...\n", + "Processing store WI_1, department HOBBIES_2...\n", + "Processing store WI_1, department HOUSEHOLD_1...\n", + "Processing store WI_1, department HOUSEHOLD_2...\n", + "Processing store WI_1, department FOODS_1...\n", + "Processing store WI_1, department FOODS_2...\n", + "Processing store WI_1, department FOODS_3...\n", + "Processing store WI_2, department HOBBIES_1...\n", + "Processing store WI_2, department HOBBIES_2...\n", + "Processing store WI_2, department HOUSEHOLD_1...\n", + "Processing store WI_2, department HOUSEHOLD_2...\n", + "Processing store WI_2, department FOODS_1...\n", + "Processing store WI_2, department FOODS_2...\n", + "Processing store WI_2, department FOODS_3...\n", + "Processing store WI_3, department HOBBIES_1...\n", + "Processing store WI_3, department HOBBIES_2...\n", + "Processing store WI_3, department HOUSEHOLD_1...\n", + "Processing store WI_3, department HOUSEHOLD_2...\n", + "Processing store WI_3, department FOODS_1...\n", + "Processing store WI_3, department FOODS_2...\n", + "Processing store WI_3, department FOODS_3...\n" + 
] + } + ], + "source": [ + "model_alt = final_train_alt(best_params_alt)" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "b00cf96a-8589-4ee6-8c29-51a5aa81af2c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Save the model to the Cloud Storage\n", + "with fs.open(f\"{bucket_name}/final_model_alt.pkl\", \"wb\") as f:\n", + " pickle.dump(model_alt, f)" + ] + }, + { + "cell_type": "markdown", + "id": "49fe134a-ebd8-4322-887f-ffcb97a380a2", + "metadata": {}, + "source": [ + "Now consider an ensemble consisting of the two models `model` and `model_alt`. We evaluate the ensemble by computing the WRMSSE metric for the average of the predictions of the two models." + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "99495acc-0355-40cf-935c-91658a6909ea", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1...\n", + "Processing store CA_2...\n", + "Processing store CA_3...\n", + "Processing store CA_4...\n", + "Processing store TX_1...\n", + "Processing store TX_2...\n", + "Processing store TX_3...\n", + "Processing store WI_1...\n", + "Processing store WI_2...\n", + "Processing store WI_3...\n", + "WRMSSE metric on the held-out test set: 11.055364531163706\n" + ] + } + ], + "source": [ + "test_wrmsse = 0\n", + "for store in STORES:\n", + " print(f\"Processing store {store}...\")\n", + " # Prediction from Model 1\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " df_test = df[(df[\"day_id\"] >= holdout[0]) & (df[\"day_id\"] < holdout[1])]\n", + " X_test = df_test.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " df_test[\"pred1\"] = model[store].predict(X_test)\n", + " \n", + " # Prediction from Model 2\n", + " df_test[\"pred2\"] = [np.nan] * len(df_test)\n", + " df_test[\"pred2\"] = 
df_test[\"pred2\"].astype(\"float32\")\n", + " for dept in DEPTS:\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\") as f:\n", + " df2 = cudf.DataFrame(pd.read_pickle(f))\n", + " df2_test = df2[(df2[\"day_id\"] >= holdout[0]) & (df2[\"day_id\"] < holdout[1])]\n", + " X_test = df2_test.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " assert np.sum(df_test[\"dept_id\"] == dept) == len(X_test)\n", + " df_test[\"pred2\"][df_test[\"dept_id\"] == dept] = model_alt[(store, dept)].predict(X_test)\n", + " \n", + " # Average prediction\n", + " df_test[\"avg_pred\"] = (df_test[\"pred1\"] + df_test[\"pred2\"]) / 2.0\n", + "\n", + " test_wrmsse += wrmsse(product_weights, df, df_test[\"avg_pred\"],\n", + " train_mask=[0, 1914], valid_mask=holdout)\n", + "print(f\"WRMSSE metric on the held-out test set: {test_wrmsse}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fda530fd-9d8e-4200-bc08-20432f4ffa53", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Close the Dask cluster to clean up\n", + "cluster.close()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae38a58b-5c5e-41b7-adcb-06cbfd81bbfc", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 0d121b714d5465836ab78d61be65fe29769229f7 Mon Sep 17 00:00:00 2001 From: Hyunsu Philip Cho Date: Wed, 2 Aug 2023 13:40:16 -0700 Subject: [PATCH 02/11] Fix formatting --- .../preprocessing_part1.ipynb | 35 ++- .../preprocessing_part2.ipynb | 59 +++-- 
.../preprocessing_part3.ipynb | 27 +- .../preprocessing_part4.ipynb | 17 +- .../preprocessing_part5.ipynb | 13 +- .../preprocessing_part6.ipynb | 134 ++++++++-- .../training_and_evaluation.ipynb | 237 ++++++++++++------ 7 files changed, 386 insertions(+), 136 deletions(-) diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb index 985f1c98..a4347733 100644 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb @@ -49,8 +49,13 @@ " num /= 1024.0\n", " return \"%.1f%s%s\" % (num, \"Yi\", suffix)\n", "\n", + "\n", "def report_dataframe_size(df, name):\n", - " print(\"{} takes up {} memory on GPU\".format(name, sizeof_fmt(grid_df.memory_usage(index=True).sum())))" + " print(\n", + " \"{} takes up {} memory on GPU\".format(\n", + " name, sizeof_fmt(grid_df.memory_usage(index=True).sum())\n", + " )\n", + " )" ] }, { @@ -98,7 +103,9 @@ "source": [ "train_df = cudf.read_csv(raw_data_dir + \"sales_train_evaluation.csv\")\n", "prices_df = cudf.read_csv(raw_data_dir + \"sell_prices.csv\")\n", - "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(columns={\"d\": \"day_id\"})" + "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(\n", + " columns={\"d\": \"day_id\"}\n", + ")" ] }, { @@ -1136,7 +1143,9 @@ ], "source": [ "index_columns = [\"id\", \"item_id\", \"dept_id\", \"cat_id\", \"store_id\", \"state_id\"]\n", - "grid_df = cudf.melt(train_df, id_vars=index_columns, var_name=\"day_id\", value_name=TARGET)\n", + "grid_df = cudf.melt(\n", + " train_df, id_vars=index_columns, var_name=\"day_id\", value_name=TARGET\n", + ")\n", "grid_df" ] }, @@ -1357,11 +1366,15 @@ " temp_df[\"day_id\"] = \"d_\" + str(END_TRAIN + i)\n", " temp_df[TARGET] = np.nan # Sales amount at time (n + i) is unknown\n", " add_grid = cudf.concat([add_grid, 
temp_df])\n", - "add_grid[\"day_id\"] = add_grid[\"day_id\"].astype(\"category\") # The day_id column is categorical, after cudf.melt\n", + "add_grid[\"day_id\"] = add_grid[\"day_id\"].astype(\n", + " \"category\"\n", + ") # The day_id column is categorical, after cudf.melt\n", "\n", "grid_df = cudf.concat([grid_df, add_grid])\n", "grid_df = grid_df.reset_index(drop=True)\n", - "grid_df[\"sales\"] = grid_df[\"sales\"].astype(np.float32) # Use float32 type for sales column, to conserve memory\n", + "grid_df[\"sales\"] = grid_df[\"sales\"].astype(\n", + " np.float32\n", + ") # Use float32 type for sales column, to conserve memory\n", "grid_df" ] }, @@ -1820,7 +1833,9 @@ } ], "source": [ - "release_df = prices_df.groupby([\"store_id\", \"item_id\"])[\"wm_yr_wk\"].agg(\"min\").reset_index()\n", + "release_df = (\n", + " prices_df.groupby([\"store_id\", \"item_id\"])[\"wm_yr_wk\"].agg(\"min\").reset_index()\n", + ")\n", "release_df.columns = [\"store_id\", \"item_id\", \"release_week\"]\n", "release_df" ] @@ -2858,7 +2873,9 @@ ], "source": [ "grid_df = grid_df[grid_df[\"wm_yr_wk\"] >= grid_df[\"release_week\"]].reset_index(drop=True)\n", - "grid_df[\"wm_yr_wk\"] = grid_df[\"wm_yr_wk\"].astype(np.int32) # Convert wm_yr_wk column to int32, to conserve memory\n", + "grid_df[\"wm_yr_wk\"] = grid_df[\"wm_yr_wk\"].astype(\n", + " np.int32\n", + ") # Convert wm_yr_wk column to int32, to conserve memory\n", "grid_df" ] }, @@ -3013,7 +3030,9 @@ "\n", "# Compute the total sales over the latest 28 days, per product item\n", "last28 = grid_df[(grid_df[\"day_id_int\"] >= 1914) & (grid_df[\"day_id_int\"] < 1942)]\n", - "last28 = last28[[\"item_id\", \"wm_yr_wk\", \"sales\"]].merge(prices_df[[\"item_id\", \"wm_yr_wk\", \"sell_price\"]], on=[\"item_id\", \"wm_yr_wk\"])\n", + "last28 = last28[[\"item_id\", \"wm_yr_wk\", \"sales\"]].merge(\n", + " prices_df[[\"item_id\", \"wm_yr_wk\", \"sell_price\"]], on=[\"item_id\", \"wm_yr_wk\"]\n", + ")\n", "last28[\"sales_usd\"] = last28[\"sales\"] 
* last28[\"sell_price\"]\n", "total_sales_usd = last28.groupby(\"item_id\")[[\"sales_usd\"]].agg([\"sum\"]).sort_index()\n", "total_sales_usd.columns = total_sales_usd.columns.map(\"_\".join)\n", diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb index 2142a6bf..2c3ba598 100644 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb @@ -56,7 +56,9 @@ "outputs": [], "source": [ "prices_df = cudf.read_csv(raw_data_dir + \"sell_prices.csv\")\n", - "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(columns={\"d\": \"day_id\"})" + "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(\n", + " columns={\"d\": \"day_id\"}\n", + ")" ] }, { @@ -513,13 +515,21 @@ "outputs": [], "source": [ "# Highest price over all weeks\n", - "prices_df[\"price_max\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].transform(\"max\")\n", + "prices_df[\"price_max\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"max\")\n", "# Lowest price over all weeks\n", - "prices_df[\"price_min\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].transform(\"min\")\n", + "prices_df[\"price_min\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"min\")\n", "# Standard deviation of the price\n", - "prices_df[\"price_std\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].transform(\"std\")\n", + "prices_df[\"price_std\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"std\")\n", "# Mean (average) price over all weeks\n", - "prices_df[\"price_mean\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].transform(\"mean\")" + "prices_df[\"price_mean\"] = 
prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"mean\")" ] }, { @@ -555,7 +565,9 @@ "metadata": {}, "outputs": [], "source": [ - "prices_df[\"price_nunique\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].transform(\"nunique\")" + "prices_df[\"price_nunique\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"nunique\")" ] }, { @@ -573,7 +585,9 @@ "metadata": {}, "outputs": [], "source": [ - "prices_df[\"item_nunique\"] = prices_df.groupby([\"store_id\", \"sell_price\"])[\"item_id\"].transform(\"nunique\")" + "prices_df[\"item_nunique\"] = prices_df.groupby([\"store_id\", \"sell_price\"])[\n", + " \"item_id\"\n", + "].transform(\"nunique\")" ] }, { @@ -831,7 +845,9 @@ "outputs": [], "source": [ "# Add \"month\" and \"year\" columns to prices_df\n", - "week_to_month_map = calendar_df[[\"wm_yr_wk\", \"month\", \"year\"]].drop_duplicates(subset=[\"wm_yr_wk\"])\n", + "week_to_month_map = calendar_df[[\"wm_yr_wk\", \"month\", \"year\"]].drop_duplicates(\n", + " subset=[\"wm_yr_wk\"]\n", + ")\n", "prices_df = prices_df.merge(week_to_month_map, on=[\"wm_yr_wk\"], how=\"left\")\n", "\n", "# Sort by wm_yr_wk. 
The rows will also be sorted in ascending months and years.\n", @@ -846,11 +862,17 @@ "outputs": [], "source": [ "# Compare with the average price in the previous week\n", - "prices_df[\"price_momentum\"] = prices_df[\"sell_price\"] / prices_df.groupby([\"store_id\", \"item_id\"])[\"sell_price\"].shift(1)\n", + "prices_df[\"price_momentum\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", + " [\"store_id\", \"item_id\"]\n", + ")[\"sell_price\"].shift(1)\n", "# Compare with the average price in the previous month\n", - "prices_df[\"price_momentum_m\"] = prices_df[\"sell_price\"] / prices_df.groupby([\"store_id\", \"item_id\", \"month\"])[\"sell_price\"].transform(\"mean\")\n", + "prices_df[\"price_momentum_m\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", + " [\"store_id\", \"item_id\", \"month\"]\n", + ")[\"sell_price\"].transform(\"mean\")\n", "# Compare with the average price in the previous year\n", - "prices_df[\"price_momentum_y\"] = prices_df[\"sell_price\"] / prices_df.groupby([\"store_id\", \"item_id\", \"year\"])[\"sell_price\"].transform(\"mean\")" + "prices_df[\"price_momentum_y\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", + " [\"store_id\", \"item_id\", \"year\"]\n", + ")[\"sell_price\"].transform(\"mean\")" ] }, { @@ -865,8 +887,15 @@ "\n", "# Convert float64 columns into float32 type to save memory\n", "columns = [\n", - " \"sell_price\", \"price_max\", \"price_min\", \"price_std\", \"price_mean\",\n", - " \"price_norm\", \"price_momentum\", \"price_momentum_m\", \"price_momentum_y\"\n", + " \"sell_price\",\n", + " \"price_max\",\n", + " \"price_min\",\n", + " \"price_std\",\n", + " \"price_mean\",\n", + " \"price_norm\",\n", + " \"price_momentum\",\n", + " \"price_momentum_m\",\n", + " \"price_momentum_y\",\n", "]\n", "for col in columns:\n", " prices_df[col] = prices_df[col].astype(np.float32)" @@ -1209,7 +1238,9 @@ "# After merging price_df, keep columns id and day_id from grid_df and drop all other columns from grid_df\n", 
"original_columns = list(grid_df)\n", "grid_df = grid_df.merge(prices_df, on=[\"store_id\", \"item_id\", \"wm_yr_wk\"], how=\"left\")\n", - "columns_to_keep = [\"id\", \"day_id\"] + [col for col in list(grid_df) if col not in original_columns]\n", + "columns_to_keep = [\"id\", \"day_id\"] + [\n", + " col for col in list(grid_df) if col not in original_columns\n", + "]\n", "grid_df = grid_df[[\"id\", \"day_id\"] + columns_to_keep]\n", "grid_df" ] diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb index 72d168b4..e9f43755 100644 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb @@ -56,7 +56,9 @@ "metadata": {}, "outputs": [], "source": [ - "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(columns={\"d\": \"day_id\"})" + "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(\n", + " columns={\"d\": \"day_id\"}\n", + ")" ] }, { @@ -443,8 +445,15 @@ "outputs": [], "source": [ "# Convert columns into categorical type to save memory\n", - "for col in [\"event_name_1\", \"event_type_1\", \"event_name_2\", \"event_type_2\",\n", - " \"snap_CA\", \"snap_TX\", \"snap_WI\"]:\n", + "for col in [\n", + " \"event_name_1\",\n", + " \"event_type_1\",\n", + " \"event_name_2\",\n", + " \"event_type_2\",\n", + " \"snap_CA\",\n", + " \"snap_TX\",\n", + " \"snap_WI\",\n", + "]:\n", " grid_df[col] = grid_df[col].astype(\"category\")\n", "# Convert \"date\" column into timestamp type\n", "grid_df[\"date\"] = cudf.to_datetime(grid_df[\"date\"])" @@ -470,9 +479,15 @@ "grid_df[\"tm_m\"] = grid_df[\"date\"].dt.month.astype(np.int8)\n", "grid_df[\"tm_y\"] = grid_df[\"date\"].dt.year\n", "grid_df[\"tm_y\"] = (grid_df[\"tm_y\"] - grid_df[\"tm_y\"].min()).astype(np.int8)\n", - "grid_df[\"tm_wm\"] = cp.ceil(grid_df[\"tm_d\"].to_cupy() / 
7).astype(np.int8) # which week in tje month?\n", - "grid_df[\"tm_dw\"] = grid_df[\"date\"].dt.dayofweek.astype(np.int8) # which day in the week?\n", - "grid_df[\"tm_w_end\"] = (grid_df[\"tm_dw\"] >= 5).astype(np.int8) # whether today is in the weekend\n", + "grid_df[\"tm_wm\"] = cp.ceil(grid_df[\"tm_d\"].to_cupy() / 7).astype(\n", + " np.int8\n", + ") # which week in tje month?\n", + "grid_df[\"tm_dw\"] = grid_df[\"date\"].dt.dayofweek.astype(\n", + " np.int8\n", + ") # which day in the week?\n", + "grid_df[\"tm_w_end\"] = (grid_df[\"tm_dw\"] >= 5).astype(\n", + " np.int8\n", + ") # whether today is in the weekend\n", "del grid_df[\"date\"] # no longer needed" ] }, diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb index c4f92f32..5f101042 100644 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb @@ -89,10 +89,7 @@ "grid_df = grid_df.sort_values([\"id\", \"day_id\"])\n", "\n", "grid_df = grid_df.assign(\n", - " **{\n", - " f\"sales_lag_{l}\": grid_df.groupby([\"id\"])[\"sales\"].shift(l)\n", - " for l in LAG_DAYS\n", - " }\n", + " **{f\"sales_lag_{l}\": grid_df.groupby([\"id\"])[\"sales\"].shift(l) for l in LAG_DAYS}\n", ")" ] }, @@ -488,10 +485,18 @@ "for i in [7, 14, 30, 60, 180]:\n", " print(f\" Window size: {i}\")\n", " grid_df[f\"rolling_mean_{i}\"] = (\n", - " grid_df.groupby([\"id\"])[\"sales\"].shift(SHIFT_DAY).rolling(i).mean().astype(np.float32)\n", + " grid_df.groupby([\"id\"])[\"sales\"]\n", + " .shift(SHIFT_DAY)\n", + " .rolling(i)\n", + " .mean()\n", + " .astype(np.float32)\n", " )\n", " grid_df[f\"rolling_std_{i}\"] = (\n", - " grid_df.groupby([\"id\"])[\"sales\"].shift(SHIFT_DAY).rolling(i).std().astype(np.float32)\n", + " grid_df.groupby([\"id\"])[\"sales\"]\n", + " .shift(SHIFT_DAY)\n", + " .rolling(i)\n", + " .std()\n", + " 
.astype(np.float32)\n", " )" ] }, diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb index a17b3d95..4cdb522e 100644 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb @@ -310,17 +310,18 @@ } ], "source": [ - "icols = [\n", - " [\"store_id\", \"dept_id\"],\n", - " [\"item_id\", \"state_id\"]\n", - "]\n", + "icols = [[\"store_id\", \"dept_id\"], [\"item_id\", \"state_id\"]]\n", "new_columns = []\n", "\n", "for col in icols:\n", " print(f\"Encoding columns {col}\")\n", " col_name = \"_\" + \"_\".join(col) + \"_\"\n", - " grid_df[\"enc\" + col_name + \"mean\"] = grid_df.groupby(col)[\"sales\"].transform(\"mean\").astype(np.float32)\n", - " grid_df[\"enc\" + col_name + \"std\"] = grid_df.groupby(col)[\"sales\"].transform(\"std\").astype(np.float32)\n", + " grid_df[\"enc\" + col_name + \"mean\"] = (\n", + " grid_df.groupby(col)[\"sales\"].transform(\"mean\").astype(np.float32)\n", + " )\n", + " grid_df[\"enc\" + col_name + \"std\"] = (\n", + " grid_df.groupby(col)[\"sales\"].transform(\"std\").astype(np.float32)\n", + " )\n", " new_columns.extend([\"enc\" + col_name + \"mean\", \"enc\" + col_name + \"std\"])" ] }, diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb index 028d28e3..149c77ca 100644 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb @@ -83,28 +83,92 @@ "segmented_data_dir = \"./segmented_data/\"\n", "pathlib.Path(segmented_data_dir).mkdir(exist_ok=True)\n", "\n", - "STORES = [\"CA_1\", \"CA_2\", \"CA_3\", \"CA_4\", \"TX_1\", \"TX_2\", \"TX_3\", \"WI_1\", \"WI_2\", \"WI_3\"]\n", - "DEPTS = [\"HOBBIES_1\", 
\"HOBBIES_2\", \"HOUSEHOLD_1\", \"HOUSEHOLD_2\", \"FOODS_1\", \"FOODS_2\", \"FOODS_3\"]\n", + "STORES = [\n", + " \"CA_1\",\n", + " \"CA_2\",\n", + " \"CA_3\",\n", + " \"CA_4\",\n", + " \"TX_1\",\n", + " \"TX_2\",\n", + " \"TX_3\",\n", + " \"WI_1\",\n", + " \"WI_2\",\n", + " \"WI_3\",\n", + "]\n", + "DEPTS = [\n", + " \"HOBBIES_1\",\n", + " \"HOBBIES_2\",\n", + " \"HOUSEHOLD_1\",\n", + " \"HOUSEHOLD_2\",\n", + " \"FOODS_1\",\n", + " \"FOODS_2\",\n", + " \"FOODS_3\",\n", + "]\n", "\n", - "grid2_colnm = [\"sell_price\", \"price_max\", \"price_min\", \"price_std\",\n", - " \"price_mean\", \"price_norm\", \"price_nunique\", \"item_nunique\",\n", - " \"price_momentum\", \"price_momentum_m\", \"price_momentum_y\"]\n", + "grid2_colnm = [\n", + " \"sell_price\",\n", + " \"price_max\",\n", + " \"price_min\",\n", + " \"price_std\",\n", + " \"price_mean\",\n", + " \"price_norm\",\n", + " \"price_nunique\",\n", + " \"item_nunique\",\n", + " \"price_momentum\",\n", + " \"price_momentum_m\",\n", + " \"price_momentum_y\",\n", + "]\n", "\n", - "grid3_colnm = [\"event_name_1\", \"event_type_1\", \"event_name_2\",\n", - " \"event_type_2\", \"snap_CA\", \"snap_TX\", \"snap_WI\", \"tm_d\", \"tm_w\", \"tm_m\",\n", - " \"tm_y\", \"tm_wm\", \"tm_dw\", \"tm_w_end\"]\n", + "grid3_colnm = [\n", + " \"event_name_1\",\n", + " \"event_type_1\",\n", + " \"event_name_2\",\n", + " \"event_type_2\",\n", + " \"snap_CA\",\n", + " \"snap_TX\",\n", + " \"snap_WI\",\n", + " \"tm_d\",\n", + " \"tm_w\",\n", + " \"tm_m\",\n", + " \"tm_y\",\n", + " \"tm_wm\",\n", + " \"tm_dw\",\n", + " \"tm_w_end\",\n", + "]\n", "\n", - "lag_colnm = [\"sales_lag_28\", \"sales_lag_29\", \"sales_lag_30\",\n", - " \"sales_lag_31\", \"sales_lag_32\", \"sales_lag_33\", \"sales_lag_34\",\n", - " \"sales_lag_35\", \"sales_lag_36\", \"sales_lag_37\", \"sales_lag_38\",\n", - " \"sales_lag_39\", \"sales_lag_40\", \"sales_lag_41\", \"sales_lag_42\",\n", - " \"rolling_mean_7\", \"rolling_std_7\", \"rolling_mean_14\", 
\"rolling_std_14\",\n", - " \"rolling_mean_30\", \"rolling_std_30\", \"rolling_mean_60\",\n", - " \"rolling_std_60\", \"rolling_mean_180\", \"rolling_std_180\"]\n", + "lag_colnm = [\n", + " \"sales_lag_28\",\n", + " \"sales_lag_29\",\n", + " \"sales_lag_30\",\n", + " \"sales_lag_31\",\n", + " \"sales_lag_32\",\n", + " \"sales_lag_33\",\n", + " \"sales_lag_34\",\n", + " \"sales_lag_35\",\n", + " \"sales_lag_36\",\n", + " \"sales_lag_37\",\n", + " \"sales_lag_38\",\n", + " \"sales_lag_39\",\n", + " \"sales_lag_40\",\n", + " \"sales_lag_41\",\n", + " \"sales_lag_42\",\n", + " \"rolling_mean_7\",\n", + " \"rolling_std_7\",\n", + " \"rolling_mean_14\",\n", + " \"rolling_std_14\",\n", + " \"rolling_mean_30\",\n", + " \"rolling_std_30\",\n", + " \"rolling_mean_60\",\n", + " \"rolling_std_60\",\n", + " \"rolling_mean_180\",\n", + " \"rolling_std_180\",\n", + "]\n", "\n", "target_enc_colnm = [\n", - " \"enc_store_id_dept_id_mean\", \"enc_store_id_dept_id_std\",\n", - " \"enc_item_id_state_id_mean\", \"enc_item_id_state_id_std\",\n", + " \"enc_store_id_dept_id_mean\",\n", + " \"enc_store_id_dept_id_std\",\n", + " \"enc_item_id_state_id_mean\",\n", + " \"enc_item_id_state_id_std\",\n", "]" ] }, @@ -133,30 +197,46 @@ " if dept is None:\n", " grid1 = grid1[grid1[\"store_id\"] == store]\n", " else:\n", - " grid1 = grid1[(grid1[\"store_id\"] == store) & (grid1[\"dept_id\"] == dept)].drop(columns=[\"dept_id\"])\n", + " grid1 = grid1[(grid1[\"store_id\"] == store) & (grid1[\"dept_id\"] == dept)].drop(\n", + " columns=[\"dept_id\"]\n", + " )\n", " grid1 = grid1.drop(columns=[\"release_week\", \"wm_yr_wk\", \"store_id\", \"state_id\"])\n", "\n", - " grid2 = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part2.pkl\"))[[\"id\", \"day_id\"] + grid2_colnm]\n", + " grid2 = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part2.pkl\"))[\n", + " [\"id\", \"day_id\"] + grid2_colnm\n", + " ]\n", " grid_df = grid1.merge(grid2, on=[\"id\", \"day_id\"], 
how=\"left\")\n", " del grid1, grid2\n", "\n", - " grid3 = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part3.pkl\"))[[\"id\", \"day_id\"] + grid3_colnm]\n", + " grid3 = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part3.pkl\"))[\n", + " [\"id\", \"day_id\"] + grid3_colnm\n", + " ]\n", " grid_df = grid_df.merge(grid3, on=[\"id\", \"day_id\"], how=\"left\")\n", " del grid3\n", "\n", - " lag_df = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"lags_df_28.pkl\"))[[\"id\", \"day_id\"] + lag_colnm]\n", + " lag_df = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"lags_df_28.pkl\"))[\n", + " [\"id\", \"day_id\"] + lag_colnm\n", + " ]\n", "\n", " grid_df = grid_df.merge(lag_df, on=[\"id\", \"day_id\"], how=\"left\")\n", " del lag_df\n", "\n", - " target_enc_df = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"target_encoding_df.pkl\"))[[\"id\", \"day_id\"] + target_enc_colnm]\n", + " target_enc_df = cudf.DataFrame(\n", + " pd.read_pickle(processed_data_dir + \"target_encoding_df.pkl\")\n", + " )[[\"id\", \"day_id\"] + target_enc_colnm]\n", "\n", " grid_df = grid_df.merge(target_enc_df, on=[\"id\", \"day_id\"], how=\"left\")\n", " del target_enc_df\n", " gc.collect()\n", "\n", " grid_df = grid_df.drop(columns=[\"id\"])\n", - " grid_df[\"day_id\"] = grid_df[\"day_id\"].to_pandas().astype(\"str\").apply(lambda x: x[2:]).astype(np.int16)\n", + " grid_df[\"day_id\"] = (\n", + " grid_df[\"day_id\"]\n", + " .to_pandas()\n", + " .astype(\"str\")\n", + " .apply(lambda x: x[2:])\n", + " .astype(np.int16)\n", + " )\n", "\n", " return grid_df" ] @@ -270,7 +350,9 @@ " for dept in DEPTS:\n", " print(f\"Processing (store {store}, department {dept})...\")\n", " grid_df = prepare_data(store=store, dept=dept)\n", - " grid_df.to_pandas().to_pickle(segmented_data_dir + f\"combined_df_store_{store}_dept_{dept}.pkl\")\n", + " grid_df.to_pandas().to_pickle(\n", + " segmented_data_dir + f\"combined_df_store_{store}_dept_{dept}.pkl\"\n", + " )\n", " 
del grid_df\n", " gc.collect()" ] @@ -391,7 +473,9 @@ "source": [ "# Also upload the product weights\n", "fs = gcsfs.GCSFileSystem()\n", - "fs.put_file(processed_data_dir + \"product_weights.pkl\", f\"{bucket_name}/product_weights.pkl\")" + "fs.put_file(\n", + " processed_data_dir + \"product_weights.pkl\", f\"{bucket_name}/product_weights.pkl\"\n", + ")" ] }, { diff --git a/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb b/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb index db16e6d6..ffdfb494 100644 --- a/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb @@ -41,7 +41,7 @@ "import optuna\n", "import gc\n", "import time\n", - "import pickle \n", + "import pickle\n", "import copy\n", "import json\n", "\n", @@ -64,15 +64,36 @@ "outputs": [], "source": [ "# Choose the same RAPIDS image you used for launching the notebook session\n", - "rapids_image = \"rapidsai/rapidsai-core-nightly:23.08-cuda11.8-runtime-ubuntu22.04-py3.10\"\n", + "rapids_image = (\n", + " \"rapidsai/rapidsai-core-nightly:23.08-cuda11.8-runtime-ubuntu22.04-py3.10\"\n", + ")\n", "# Use the number of worker nodes in your Kubernetes cluster.\n", "n_workers = 3\n", "# Bucket that contains the processed data pickles, refer to start_here.ipynb\n", "bucket_name = \"\"\n", "\n", "# List of stores and product departments\n", - "STORES = [\"CA_1\", \"CA_2\", \"CA_3\", \"CA_4\", \"TX_1\", \"TX_2\", \"TX_3\", \"WI_1\", \"WI_2\", \"WI_3\"]\n", - "DEPTS = [\"HOBBIES_1\", \"HOBBIES_2\", \"HOUSEHOLD_1\", \"HOUSEHOLD_2\", \"FOODS_1\", \"FOODS_2\", \"FOODS_3\"]" + "STORES = [\n", + " \"CA_1\",\n", + " \"CA_2\",\n", + " \"CA_3\",\n", + " \"CA_4\",\n", + " \"TX_1\",\n", + " \"TX_2\",\n", + " \"TX_3\",\n", + " \"WI_1\",\n", + " \"WI_2\",\n", + " \"WI_3\",\n", + "]\n", + "DEPTS = [\n", + " \"HOBBIES_1\",\n", + " \"HOBBIES_2\",\n", + " 
\"HOUSEHOLD_1\",\n", + " \"HOUSEHOLD_2\",\n", + " \"FOODS_1\",\n", + " \"FOODS_2\",\n", + " \"FOODS_3\",\n", + "]" ] }, { @@ -110,7 +131,7 @@ " ([0, 1114], [1114, 1314]),\n", " ([0, 1314], [1314, 1514]),\n", " ([0, 1514], [1514, 1714]),\n", - " ([0, 1714], [1714, 1914])\n", + " ([0, 1714], [1714, 1914]),\n", "]\n", "n_folds = len(cv_folds)\n", "holdout = [1914, 1942]\n", @@ -162,7 +183,7 @@ " lw=20,\n", " cmap=cv_cmap,\n", " vmin=-1.5,\n", - " vmax=1.5\n", + " vmax=1.5,\n", " )\n", "\n", "idx = np.array([np.nan] * time_horizon)\n", @@ -177,24 +198,26 @@ " lw=20,\n", " cmap=cv_cmap,\n", " vmin=-1.5,\n", - " vmax=1.5\n", + " vmax=1.5,\n", ")\n", "\n", "plt.xlabel(\"Time\")\n", "plt.yticks(\n", " ticks=np.arange(n_folds + 1) + 0.5,\n", - " labels=[f\"Fold {i}\" for i in range(n_folds)] + [\"Holdout\"]\n", + " labels=[f\"Fold {i}\" for i in range(n_folds)] + [\"Holdout\"],\n", ")\n", "plt.ylim([len(cv_folds) + 1.2, -0.2])\n", "\n", "norm = matplotlib.colors.Normalize(vmin=-1.5, vmax=1.5)\n", "plt.legend(\n", - " [Patch(color=cv_cmap(norm(1))),\n", - " Patch(color=cv_cmap(norm(0))),\n", - " Patch(color=cv_cmap(norm(-1)))],\n", + " [\n", + " Patch(color=cv_cmap(norm(1))),\n", + " Patch(color=cv_cmap(norm(0))),\n", + " Patch(color=cv_cmap(norm(-1))),\n", + " ],\n", " [\"Training set\", \"Validation set\", \"Held-out test set\"],\n", " ncol=3,\n", - " loc=\"best\"\n", + " loc=\"best\",\n", ")\n", "plt.tight_layout()" ] @@ -249,12 +272,14 @@ } ], "source": [ - "cluster = KubeCluster(name=\"rapids-dask\",\n", - " image=rapids_image,\n", - " worker_command=\"dask-cuda-worker\",\n", - " n_workers=n_workers,\n", - " resources={\"limits\": {\"nvidia.com/gpu\": \"1\"}},\n", - " env={\"DISABLE_JUPYTER\": \"true\", \"EXTRA_PIP_PACKAGES\": \"optuna gcsfs\"})" + "cluster = KubeCluster(\n", + " name=\"rapids-dask\",\n", + " image=rapids_image,\n", + " worker_command=\"dask-cuda-worker\",\n", + " n_workers=n_workers,\n", + " resources={\"limits\": {\"nvidia.com/gpu\": \"1\"}},\n", + " 
env={\"DISABLE_JUPYTER\": \"true\", \"EXTRA_PIP_PACKAGES\": \"optuna gcsfs\"},\n", + ")" ] }, { @@ -777,28 +802,39 @@ " df_valid = df[(df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])]\n", "\n", " # Compute denominator: 1/(n-1) * sum( (y(t) - y(t-1))**2 )\n", - " diff = df_train.sort_values([\"item_id\", \"day_id\"]).groupby([\"item_id\"])[[\"sales\"]].diff(1)\n", - " x = df_train[[\"item_id\", \"day_id\"]].join(diff, how=\"left\").rename(columns={\"sales\": \"diff\"}).sort_values([\"item_id\", \"day_id\"])\n", + " diff = (\n", + " df_train.sort_values([\"item_id\", \"day_id\"])\n", + " .groupby([\"item_id\"])[[\"sales\"]]\n", + " .diff(1)\n", + " )\n", + " x = (\n", + " df_train[[\"item_id\", \"day_id\"]]\n", + " .join(diff, how=\"left\")\n", + " .rename(columns={\"sales\": \"diff\"})\n", + " .sort_values([\"item_id\", \"day_id\"])\n", + " )\n", " x[\"diff\"] = x[\"diff\"] ** 2\n", " xx = x.groupby([\"item_id\"])[[\"diff\"]].agg([\"sum\", \"count\"]).sort_index()\n", " xx.columns = xx.columns.map(\"_\".join)\n", " xx[\"denominator\"] = xx[\"diff_sum\"] / xx[\"diff_count\"]\n", " t = xx.reset_index()\n", - " \n", + "\n", " # Compute numerator: 1/h * sum( (y(t) - y_pred(t))**2 )\n", " X_valid = df_valid.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", " if \"dept_id\" in X_valid.columns:\n", " X_valid = X_valid.drop(columns=[\"dept_id\"])\n", - " df_pred = cudf.DataFrame({\n", - " \"item_id\": df_valid[\"item_id\"].copy(),\n", - " \"pred_sales\": pred_sales,\n", - " \"sales\": df_valid[\"sales\"].copy()\n", - " })\n", - " df_pred[\"diff\"] = (df_pred[\"sales\"] - df_pred[\"pred_sales\"])**2\n", + " df_pred = cudf.DataFrame(\n", + " {\n", + " \"item_id\": df_valid[\"item_id\"].copy(),\n", + " \"pred_sales\": pred_sales,\n", + " \"sales\": df_valid[\"sales\"].copy(),\n", + " }\n", + " )\n", + " df_pred[\"diff\"] = (df_pred[\"sales\"] - df_pred[\"pred_sales\"]) ** 2\n", " yy = df_pred.groupby([\"item_id\"])[[\"diff\"]].agg([\"sum\", 
\"count\"]).sort_index()\n", " yy.columns = yy.columns.map(\"_\".join)\n", " yy[\"numerator\"] = yy[\"diff_sum\"] / yy[\"diff_count\"]\n", - " \n", + "\n", " zz = yy[[\"numerator\"]].join(xx[[\"denominator\"]], how=\"left\")\n", " zz = zz.join(product_weights, how=\"left\").sort_index()\n", " # Filter out zero denominator.\n", @@ -853,32 +889,47 @@ " \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 100.0, log=True),\n", " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n", " \"max_depth\": trial.suggest_int(\"max_depth\", 2, 6, step=1),\n", - " \"min_child_weight\": trial.suggest_float(\"min_child_weight\", 1e-8, 100, log=True),\n", + " \"min_child_weight\": trial.suggest_float(\n", + " \"min_child_weight\", 1e-8, 100, log=True\n", + " ),\n", " \"gamma\": trial.suggest_float(\"gamma\", 1e-8, 1.0, log=True),\n", " \"tweedie_variance_power\": trial.suggest_float(\"tweedie_variance_power\", 1, 2),\n", " }\n", " scores = [[] for store in STORES]\n", - " \n", + "\n", " for store_id, store in enumerate(STORES):\n", " print(f\"Processing store {store}...\")\n", " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", " df = cudf.DataFrame(pd.read_pickle(f))\n", " for train_mask, valid_mask in cv_folds:\n", - " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", - " df_valid = df[(df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])]\n", + " df_train = df[\n", + " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", + " ]\n", + " df_valid = df[\n", + " (df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])\n", + " ]\n", "\n", - " X_train, y_train = df_train.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]), df_train[\"sales\"]\n", - " X_valid = df_valid.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " X_train, y_train = (\n", + " df_train.drop(\n", + " columns=[\"item_id\", 
\"dept_id\", \"cat_id\", \"day_id\", \"sales\"]\n", + " ),\n", + " df_train[\"sales\"],\n", + " )\n", + " X_valid = df_valid.drop(\n", + " columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]\n", + " )\n", "\n", " clf = xgb.XGBRegressor(**params)\n", " clf.fit(X_train, y_train)\n", " pred_sales = clf.predict(X_valid)\n", - " scores[store_id].append(wrmsse(product_weights, df, pred_sales, train_mask, valid_mask))\n", + " scores[store_id].append(\n", + " wrmsse(product_weights, df, pred_sales, train_mask, valid_mask)\n", + " )\n", " del df_train, df_valid, X_train, y_train, clf\n", " gc.collect()\n", " del df\n", " gc.collect()\n", - " \n", + "\n", " # We can sum WRMSSE scores over data segments because data segments contain disjoint sets of time series\n", " return np.array(scores).sum(axis=0).mean()" ] @@ -924,14 +975,16 @@ "source": [ "##### Number of hyperparameter combinations to try in parallel\n", "n_trials = 9 # Using a small n_trials so that the demo can finish quickly\n", - "#n_trials = 100\n", + "# n_trials = 100\n", "\n", "# Optimize in parallel on your Dask cluster\n", "backend_storage = optuna.storages.InMemoryStorage()\n", "dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n", - "study = optuna.create_study(direction=\"minimize\",\n", - " sampler=optuna.samplers.RandomSampler(seed=0),\n", - " storage=dask_storage)\n", + "study = optuna.create_study(\n", + " direction=\"minimize\",\n", + " sampler=optuna.samplers.RandomSampler(seed=0),\n", + " storage=dask_storage,\n", + ")\n", "futures = []\n", "for i in range(0, n_trials, n_workers):\n", " iter_range = (i, min([i + n_workers, n_trials]))\n", @@ -941,7 +994,7 @@ " \"futures\": [\n", " client.submit(study.optimize, objective, n_trials=1, pure=False)\n", " for _ in range(*iter_range)\n", - " ]\n", + " ],\n", " }\n", " )\n", "\n", @@ -953,7 +1006,9 @@ " for fut in partition[\"futures\"]:\n", " _ = fut.result() # Ensure that the training job was 
successful\n", " tnow = time.perf_counter()\n", - " print(f\"Best cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\")\n", + " print(\n", + " f\"Best cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\"\n", + " )\n", "tend = time.perf_counter()\n", "print(f\"Total time elapsed = {tend - tstart}\")" ] @@ -1126,14 +1181,17 @@ " df = cudf.DataFrame(pd.read_pickle(f))\n", "\n", " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", - " X_train, y_train = df_train.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]), df_train[\"sales\"]\n", + " X_train, y_train = (\n", + " df_train.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", + " df_train[\"sales\"],\n", + " )\n", "\n", " clf = xgb.XGBRegressor(**params)\n", " clf.fit(X_train, y_train)\n", " model[store] = clf\n", " del df\n", " gc.collect()\n", - " \n", + "\n", " return model" ] }, @@ -1198,7 +1256,9 @@ " df_test = df[(df[\"day_id\"] >= holdout[0]) & (df[\"day_id\"] < holdout[1])]\n", " X_test = df_test.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"])\n", " pred_sales = model[store].predict(X_test)\n", - " test_wrmsse += wrmsse(product_weights, df, pred_sales, train_mask=[0, 1914], valid_mask=holdout)\n", + " test_wrmsse += wrmsse(\n", + " product_weights, df, pred_sales, train_mask=[0, 1914], valid_mask=holdout\n", + " )\n", "print(f\"WRMSSE metric on the held-out test set: {test_wrmsse}\")" ] }, @@ -1260,33 +1320,48 @@ " \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 100.0, log=True),\n", " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n", " \"max_depth\": trial.suggest_int(\"max_depth\", 2, 6, step=1),\n", - " \"min_child_weight\": trial.suggest_float(\"min_child_weight\", 1e-8, 100, log=True),\n", + " \"min_child_weight\": trial.suggest_float(\n", + " \"min_child_weight\", 1e-8, 100, log=True\n", + " ),\n", " 
\"gamma\": trial.suggest_float(\"gamma\", 1e-8, 1.0, log=True),\n", " \"tweedie_variance_power\": trial.suggest_float(\"tweedie_variance_power\", 1, 2),\n", " }\n", " scores = [[] for i in range(len(STORES) * len(DEPTS))]\n", - " \n", + "\n", " for store_id, store in enumerate(STORES):\n", " for dept_id, dept in enumerate(DEPTS):\n", " print(f\"Processing store {store}, department {dept}...\")\n", - " with fs.open(f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\") as f:\n", + " with fs.open(\n", + " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", + " ) as f:\n", " df = cudf.DataFrame(pd.read_pickle(f))\n", " for train_mask, valid_mask in cv_folds:\n", - " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", - " df_valid = df[(df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])]\n", + " df_train = df[\n", + " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", + " ]\n", + " df_valid = df[\n", + " (df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])\n", + " ]\n", "\n", - " X_train, y_train = df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]), df_train[\"sales\"]\n", - " X_valid = df_valid.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " X_train, y_train = (\n", + " df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", + " df_train[\"sales\"],\n", + " )\n", + " X_valid = df_valid.drop(\n", + " columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]\n", + " )\n", "\n", " clf = xgb.XGBRegressor(**params)\n", " clf.fit(X_train, y_train)\n", " sales_pred = clf.predict(X_valid)\n", - " scores[store_id * len(DEPTS) + dept_id].append(wrmsse(product_weights, df, sales_pred, train_mask, valid_mask))\n", + " scores[store_id * len(DEPTS) + dept_id].append(\n", + " wrmsse(product_weights, df, sales_pred, train_mask, valid_mask)\n", + " )\n", " del df_train, df_valid, X_train, y_train, 
clf\n", " gc.collect()\n", " del df\n", " gc.collect()\n", - " \n", + "\n", " # We can sum WRMSSE scores over data segments because data segments contain disjoint sets of time series\n", " return np.array(scores).sum(axis=0).mean()" ] @@ -1324,14 +1399,16 @@ "source": [ "##### Number of hyperparameter combinations to try in parallel\n", "n_trials = 9 # Using a small n_trials so that the demo can finish quickly\n", - "#n_trials = 100\n", + "# n_trials = 100\n", "\n", "# Optimize in parallel on your Dask cluster\n", "backend_storage = optuna.storages.InMemoryStorage()\n", "dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n", - "study = optuna.create_study(direction=\"minimize\",\n", - " sampler=optuna.samplers.RandomSampler(seed=0),\n", - " storage=dask_storage)\n", + "study = optuna.create_study(\n", + " direction=\"minimize\",\n", + " sampler=optuna.samplers.RandomSampler(seed=0),\n", + " storage=dask_storage,\n", + ")\n", "futures = []\n", "for i in range(0, n_trials, n_workers):\n", " iter_range = (i, min([i + n_workers, n_trials]))\n", @@ -1341,7 +1418,7 @@ " \"futures\": [\n", " client.submit(study.optimize, objective_alt, n_trials=1, pure=False)\n", " for _ in range(*iter_range)\n", - " ]\n", + " ],\n", " }\n", " )\n", "\n", @@ -1353,7 +1430,9 @@ " for fut in partition[\"futures\"]:\n", " _ = fut.result() # Ensure that the training job was successful\n", " tnow = time.perf_counter()\n", - " print(f\"Best cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\")\n", + " print(\n", + " f\"Best cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\"\n", + " )\n", "tend = time.perf_counter()\n", "print(f\"Total time elapsed = {tend - tstart}\")" ] @@ -1433,22 +1512,29 @@ " params.update(best_params)\n", " model = {}\n", " train_mask = [0, 1914]\n", - " \n", + "\n", " for store_id, store in enumerate(STORES):\n", " for dept_id, dept in enumerate(DEPTS):\n", " print(f\"Processing store 
{store}, department {dept}...\")\n", - " with fs.open(f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\") as f:\n", + " with fs.open(\n", + " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", + " ) as f:\n", " df = cudf.DataFrame(pd.read_pickle(f))\n", " for train_mask, valid_mask in cv_folds:\n", - " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", - " X_train, y_train = df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]), df_train[\"sales\"]\n", + " df_train = df[\n", + " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", + " ]\n", + " X_train, y_train = (\n", + " df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", + " df_train[\"sales\"],\n", + " )\n", "\n", " clf = xgb.XGBRegressor(**params)\n", " clf.fit(X_train, y_train)\n", " model[(store, dept)] = clf\n", " del df\n", " gc.collect()\n", - " \n", + "\n", " return model" ] }, @@ -1616,23 +1702,32 @@ " df_test = df[(df[\"day_id\"] >= holdout[0]) & (df[\"day_id\"] < holdout[1])]\n", " X_test = df_test.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"])\n", " df_test[\"pred1\"] = model[store].predict(X_test)\n", - " \n", + "\n", " # Prediction from Model 2\n", " df_test[\"pred2\"] = [np.nan] * len(df_test)\n", " df_test[\"pred2\"] = df_test[\"pred2\"].astype(\"float32\")\n", " for dept in DEPTS:\n", - " with fs.open(f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\") as f:\n", + " with fs.open(\n", + " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", + " ) as f:\n", " df2 = cudf.DataFrame(pd.read_pickle(f))\n", " df2_test = df2[(df2[\"day_id\"] >= holdout[0]) & (df2[\"day_id\"] < holdout[1])]\n", " X_test = df2_test.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", " assert np.sum(df_test[\"dept_id\"] == dept) == len(X_test)\n", - " df_test[\"pred2\"][df_test[\"dept_id\"] == dept] = 
model_alt[(store, dept)].predict(X_test)\n", - " \n", + " df_test[\"pred2\"][df_test[\"dept_id\"] == dept] = model_alt[(store, dept)].predict(\n", + " X_test\n", + " )\n", + "\n", " # Average prediction\n", " df_test[\"avg_pred\"] = (df_test[\"pred1\"] + df_test[\"pred2\"]) / 2.0\n", "\n", - " test_wrmsse += wrmsse(product_weights, df, df_test[\"avg_pred\"],\n", - " train_mask=[0, 1914], valid_mask=holdout)\n", + " test_wrmsse += wrmsse(\n", + " product_weights,\n", + " df,\n", + " df_test[\"avg_pred\"],\n", + " train_mask=[0, 1914],\n", + " valid_mask=holdout,\n", + " )\n", "print(f\"WRMSSE metric on the held-out test set: {test_wrmsse}\")" ] }, From a0c2ae61d09e5c95f03b70bae138a5e67e64adbe Mon Sep 17 00:00:00 2001 From: Hyunsu Philip Cho Date: Wed, 2 Aug 2023 13:44:58 -0700 Subject: [PATCH 03/11] Fix Sphinx warnings --- .../start_here.ipynb | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/source/examples/time-series-forecasting-with-hpo/start_here.ipynb b/source/examples/time-series-forecasting-with-hpo/start_here.ipynb index e23c4c7e..a3f23b8a 100644 --- a/source/examples/time-series-forecasting-with-hpo/start_here.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/start_here.ipynb @@ -271,13 +271,17 @@ "source": [ "We are now ready to run the preprocessing steps. 
You should run the six notebooks in order, to process the raw data into a form that can be used for model training:\n", "\n", - "* `preprocessing_part1.ipynb`\n", - "* `preprocessing_part2.ipynb`\n", - "* `preprocessing_part3.ipynb`\n", - "* `preprocessing_part4.ipynb`\n", - "* `preprocessing_part5.ipynb`\n", - "* `preprocessing_part6.ipynb`\n", - "* `training_and_evaluation.ipynb`\n" + "```{toctree}\n", + "---\n", + "maxdepth: 1\n", + "---\n", + "preprocessing_part1\n", + "preprocessing_part2\n", + "preprocessing_part3\n", + "preprocessing_part4\n", + "preprocessing_part5\n", + "preprocessing_part6\n", + "training_and_evaluation\n" ] } ], From b8bd729d34be3c468ad5f9a8302607952ee512e8 Mon Sep 17 00:00:00 2001 From: Hyunsu Philip Cho Date: Wed, 2 Aug 2023 13:52:50 -0700 Subject: [PATCH 04/11] Remove unused import --- .../time-series-forecasting-with-hpo/preprocessing_part1.ipynb | 1 - 1 file changed, 1 deletion(-) diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb index a4347733..d91f94bf 100644 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb @@ -27,7 +27,6 @@ "source": [ "import cudf\n", "import numpy as np\n", - "import pandas as pd\n", "import gc\n", "import pathlib\n", "import gcsfs" From 3054f6e478a73df1affc3ce1a44c4965a9b7173b Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 25 Aug 2023 14:05:57 +0100 Subject: [PATCH 05/11] Enable dollarmath --- source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/conf.py b/source/conf.py index 0fff7535..06cb1400 100644 --- a/source/conf.py +++ b/source/conf.py @@ -72,7 +72,7 @@ "rapids_admonitions", ] -myst_enable_extensions = ["colon_fence"] +myst_enable_extensions = ["colon_fence", "dollarmath"] # Add any paths that contain templates here, relative to this 
directory. templates_path = ["_templates"] From f3cb4e8835a5072854727997122595ef1596604a Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Fri, 25 Aug 2023 14:18:45 +0100 Subject: [PATCH 06/11] Add new lines around math --- .../training_and_evaluation.ipynb | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb b/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb index ffdfb494..556f0814 100644 --- a/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb @@ -775,13 +775,17 @@ "metadata": {}, "source": [ "The M5 forecasting competition defines a custom metric called WRMSSE as follows:\n", + "\n", "$$\n", "WRMSSE = \\sum w_i \\cdot RMSSE_i\n", "$$\n", + "\n", "i.e. WRMSEE is a weighted sum of RMSSE for all product items $i$. RMSSE is in turn defined to be\n", + "\n", "$$\n", "RMSSE = \\sqrt{\\frac{1/h \\cdot \\sum_t{\\left(Y_t - \\hat{Y}_t\\right)}^2}{1/(n-1)\\sum_t{(Y_t - Y_{t-1})}^2}}\n", "$$\n", + "\n", "where the squared error of the prediction (forecast) is normalized by the speed at which the sales amount changes per unit in the training data.\n", "\n", "Here is the implementation of the WRMSSE using cuDF. We use the product weights $w_i$ as computed in the first preprocessing notebook." 
From d455d143d4527d094ab904117a0cbb8589d9a712 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 28 Sep 2023 17:18:17 -0700 Subject: [PATCH 07/11] Combine all notebooks; incorporate feedback --- .../preprocessing_part1.ipynb | 3219 ------- .../preprocessing_part2.ipynb | 1290 --- .../preprocessing_part3.ipynb | 854 -- .../preprocessing_part4.ipynb | 628 -- .../preprocessing_part5.ipynb | 596 -- .../preprocessing_part6.ipynb | 511 -- .../start_here.ipynb | 7790 ++++++++++++++++- .../training_and_evaluation.ipynb | 1781 ---- 8 files changed, 7749 insertions(+), 8920 deletions(-) delete mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb delete mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb delete mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb delete mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb delete mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb delete mode 100644 source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb delete mode 100644 source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb deleted file mode 100644 index d91f94bf..00000000 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part1.ipynb +++ /dev/null @@ -1,3219 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "70b2ef46-9ac6-47db-a66f-353bb4a27722", - "metadata": {}, - "source": [ - "# Data preprocesing, Part 1" - ] - }, - { - "cell_type": "markdown", - "id": "62f9ff3a-b875-4798-8988-7238c1c651a6", - "metadata": {}, - "source": [ - "## Import modules and define utility functions" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": 
"fce9c954-bb2c-4c2d-bb6e-11b100cac88f", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import cudf\n", - "import numpy as np\n", - "import gc\n", - "import pathlib\n", - "import gcsfs" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "45c07679-7305-4a8a-a826-efa42ec65ba2", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def sizeof_fmt(num, suffix=\"B\"):\n", - " for unit in [\"\", \"Ki\", \"Mi\", \"Gi\", \"Ti\", \"Pi\", \"Ei\", \"Zi\"]:\n", - " if abs(num) < 1024.0:\n", - " return f\"{num:3.1f}{unit}{suffix}\"\n", - " num /= 1024.0\n", - " return \"%.1f%s%s\" % (num, \"Yi\", suffix)\n", - "\n", - "\n", - "def report_dataframe_size(df, name):\n", - " print(\n", - " \"{} takes up {} memory on GPU\".format(\n", - " name, sizeof_fmt(grid_df.memory_usage(index=True).sum())\n", - " )\n", - " )" - ] - }, - { - "cell_type": "markdown", - "id": "b981cb8f-f6f8-4a32-af08-d883ade14c0f", - "metadata": {}, - "source": [ - "## Load Data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "1fda1c8f-2697-4e04-9484-2d7072c7e904", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "TARGET = \"sales\" # Our main target\n", - "END_TRAIN = 1941 # Last day in train set" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "d3ea8261-4b3d-470d-be49-b161c8d72b04", - "metadata": {}, - "outputs": [], - "source": [ - "raw_data_dir = \"./data/\"\n", - "processed_data_dir = \"./processed_data/\"\n", - "\n", - "pathlib.Path(processed_data_dir).mkdir(exist_ok=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "1a852fe7-e6c6-45f4-9767-af7b78bc510b", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "train_df = cudf.read_csv(raw_data_dir + \"sales_train_evaluation.csv\")\n", - "prices_df = cudf.read_csv(raw_data_dir + \"sell_prices.csv\")\n", - "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(\n", - " columns={\"d\": \"day_id\"}\n", 
- ")" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "e5d049a2-2797-45b3-9b32-367b06505477", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iditem_iddept_idcat_idstore_idstate_idd_1d_2d_3d_4...d_1932d_1933d_1934d_1935d_1936d_1937d_1938d_1939d_1940d_1941
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CA0000...2400003301
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CA0000...0121100000
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CA0000...1020002301
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CA0000...1104013026
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CA0000...0002100210
..................................................................
30485FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WI0022...1030110011
30486FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WI0000...0000001010
30487FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WI0602...0012010102
30488FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WI0000...1114601110
30489FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WI0000...1205402251
\n", - "

30490 rows × 1947 columns

\n", - "
" - ], - "text/plain": [ - " id item_id dept_id cat_id \\\n", - "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", - "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", - "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", - "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", - "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", - "... ... ... ... ... \n", - "30485 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", - "30486 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", - "30487 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", - "30488 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", - "30489 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", - "\n", - " store_id state_id d_1 d_2 d_3 d_4 ... d_1932 d_1933 d_1934 \\\n", - "0 CA_1 CA 0 0 0 0 ... 2 4 0 \n", - "1 CA_1 CA 0 0 0 0 ... 0 1 2 \n", - "2 CA_1 CA 0 0 0 0 ... 1 0 2 \n", - "3 CA_1 CA 0 0 0 0 ... 1 1 0 \n", - "4 CA_1 CA 0 0 0 0 ... 0 0 0 \n", - "... ... ... ... ... ... ... ... ... ... ... \n", - "30485 WI_3 WI 0 0 2 2 ... 1 0 3 \n", - "30486 WI_3 WI 0 0 0 0 ... 0 0 0 \n", - "30487 WI_3 WI 0 6 0 2 ... 0 0 1 \n", - "30488 WI_3 WI 0 0 0 0 ... 1 1 1 \n", - "30489 WI_3 WI 0 0 0 0 ... 1 2 0 \n", - "\n", - " d_1935 d_1936 d_1937 d_1938 d_1939 d_1940 d_1941 \n", - "0 0 0 0 3 3 0 1 \n", - "1 1 1 0 0 0 0 0 \n", - "2 0 0 0 2 3 0 1 \n", - "3 4 0 1 3 0 2 6 \n", - "4 2 1 0 0 2 1 0 \n", - "... ... ... ... ... ... ... ... 
\n", - "30485 0 1 1 0 0 1 1 \n", - "30486 0 0 0 1 0 1 0 \n", - "30487 2 0 1 0 1 0 2 \n", - "30488 4 6 0 1 1 1 0 \n", - "30489 5 4 0 2 2 5 1 \n", - "\n", - "[30490 rows x 1947 columns]" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_df" - ] - }, - { - "cell_type": "markdown", - "id": "d9053e74-cddf-4ca2-b8cb-f6da58266e2a", - "metadata": {}, - "source": [ - "The columns `d_1`, `d_2`, ..., `d_1941` indicate the sales data at days 1, 2, ..., 1941 from 2011-01-29." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "efeb85c5-c72e-4cd2-be2d-17d4a240aef4", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
store_iditem_idwm_yr_wksell_price
0CA_1HOBBIES_1_001113259.58
1CA_1HOBBIES_1_001113269.58
2CA_1HOBBIES_1_001113278.26
3CA_1HOBBIES_1_001113288.26
4CA_1HOBBIES_1_001113298.26
...............
6841116WI_3FOODS_3_827116171.00
6841117WI_3FOODS_3_827116181.00
6841118WI_3FOODS_3_827116191.00
6841119WI_3FOODS_3_827116201.00
6841120WI_3FOODS_3_827116211.00
\n", - "

6841121 rows × 4 columns

\n", - "
" - ], - "text/plain": [ - " store_id item_id wm_yr_wk sell_price\n", - "0 CA_1 HOBBIES_1_001 11325 9.58\n", - "1 CA_1 HOBBIES_1_001 11326 9.58\n", - "2 CA_1 HOBBIES_1_001 11327 8.26\n", - "3 CA_1 HOBBIES_1_001 11328 8.26\n", - "4 CA_1 HOBBIES_1_001 11329 8.26\n", - "... ... ... ... ...\n", - "6841116 WI_3 FOODS_3_827 11617 1.00\n", - "6841117 WI_3 FOODS_3_827 11618 1.00\n", - "6841118 WI_3 FOODS_3_827 11619 1.00\n", - "6841119 WI_3 FOODS_3_827 11620 1.00\n", - "6841120 WI_3 FOODS_3_827 11621 1.00\n", - "\n", - "[6841121 rows x 4 columns]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "prices_df" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "77187dd8-92a7-400b-b5d8-4db31817e4df", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
datewm_yr_wkweekdaywdaymonthyearday_idevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WI
02011-01-2911101Saturday112011d_1<NA><NA><NA><NA>000
12011-01-3011101Sunday212011d_2<NA><NA><NA><NA>000
22011-01-3111101Monday312011d_3<NA><NA><NA><NA>000
32011-02-0111101Tuesday422011d_4<NA><NA><NA><NA>110
42011-02-0211101Wednesday522011d_5<NA><NA><NA><NA>101
.............................................
19642016-06-1511620Wednesday562016d_1965<NA><NA><NA><NA>011
19652016-06-1611620Thursday662016d_1966<NA><NA><NA><NA>000
19662016-06-1711620Friday762016d_1967<NA><NA><NA><NA>000
19672016-06-1811621Saturday162016d_1968<NA><NA><NA><NA>000
19682016-06-1911621Sunday262016d_1969NBAFinalsEndSportingFather's dayCultural000
\n", - "

1969 rows × 14 columns

\n", - "
" - ], - "text/plain": [ - " date wm_yr_wk weekday wday month year day_id \\\n", - "0 2011-01-29 11101 Saturday 1 1 2011 d_1 \n", - "1 2011-01-30 11101 Sunday 2 1 2011 d_2 \n", - "2 2011-01-31 11101 Monday 3 1 2011 d_3 \n", - "3 2011-02-01 11101 Tuesday 4 2 2011 d_4 \n", - "4 2011-02-02 11101 Wednesday 5 2 2011 d_5 \n", - "... ... ... ... ... ... ... ... \n", - "1964 2016-06-15 11620 Wednesday 5 6 2016 d_1965 \n", - "1965 2016-06-16 11620 Thursday 6 6 2016 d_1966 \n", - "1966 2016-06-17 11620 Friday 7 6 2016 d_1967 \n", - "1967 2016-06-18 11621 Saturday 1 6 2016 d_1968 \n", - "1968 2016-06-19 11621 Sunday 2 6 2016 d_1969 \n", - "\n", - " event_name_1 event_type_1 event_name_2 event_type_2 snap_CA snap_TX \\\n", - "0 0 0 \n", - "1 0 0 \n", - "2 0 0 \n", - "3 1 1 \n", - "4 1 0 \n", - "... ... ... ... ... ... ... \n", - "1964 0 1 \n", - "1965 0 0 \n", - "1966 0 0 \n", - "1967 0 0 \n", - "1968 NBAFinalsEnd Sporting Father's day Cultural 0 0 \n", - "\n", - " snap_WI \n", - "0 0 \n", - "1 0 \n", - "2 0 \n", - "3 0 \n", - "4 1 \n", - "... ... \n", - "1964 1 \n", - "1965 0 \n", - "1966 0 \n", - "1967 0 \n", - "1968 0 \n", - "\n", - "[1969 rows x 14 columns]" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "calendar_df" - ] - }, - { - "cell_type": "markdown", - "id": "dd65fd8d-ba0f-4120-8ebd-4bc9d1513b74", - "metadata": {}, - "source": [ - "## Reformat sales times series data" - ] - }, - { - "cell_type": "markdown", - "id": "666b07b0-bc62-4c9b-86bc-5287b1d40de2", - "metadata": {}, - "source": [ - "Pivot the columns `d_1`, `d_2`, ..., `d_1941` into separate rows using `cudf.melt`." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "4908f0c5-b83b-4c8d-81e1-7853b5bbc7af", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iditem_iddept_idcat_idstore_idstate_idday_idsales
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CAd_10
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CAd_10
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CAd_10
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CAd_10
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CAd_10
...........................
59181085FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WId_19411
59181086FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WId_19410
59181087FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WId_19412
59181088FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WId_19410
59181089FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WId_19411
\n", - "

59181090 rows × 8 columns

\n", - "
" - ], - "text/plain": [ - " id item_id dept_id cat_id \\\n", - "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", - "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", - "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", - "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", - "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", - "... ... ... ... ... \n", - "59181085 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", - "59181086 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", - "59181087 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", - "59181088 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", - "59181089 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", - "\n", - " store_id state_id day_id sales \n", - "0 CA_1 CA d_1 0 \n", - "1 CA_1 CA d_1 0 \n", - "2 CA_1 CA d_1 0 \n", - "3 CA_1 CA d_1 0 \n", - "4 CA_1 CA d_1 0 \n", - "... ... ... ... ... \n", - "59181085 WI_3 WI d_1941 1 \n", - "59181086 WI_3 WI d_1941 0 \n", - "59181087 WI_3 WI d_1941 2 \n", - "59181088 WI_3 WI d_1941 0 \n", - "59181089 WI_3 WI d_1941 1 \n", - "\n", - "[59181090 rows x 8 columns]" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "index_columns = [\"id\", \"item_id\", \"dept_id\", \"cat_id\", \"store_id\", \"state_id\"]\n", - "grid_df = cudf.melt(\n", - " train_df, id_vars=index_columns, var_name=\"day_id\", value_name=TARGET\n", - ")\n", - "grid_df" - ] - }, - { - "cell_type": "markdown", - "id": "fe51bf83-6831-4d4a-99bc-f6d5f6ec9e15", - "metadata": {}, - "source": [ - "For each time series, add 28 rows that corresponds to the future forecast horizon:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "5700f3c3-f294-4963-8cea-9235ab55b0f4", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iditem_iddept_idcat_idstore_idstate_idday_idsales
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CAd_10.0
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CAd_10.0
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CAd_10.0
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CAd_10.0
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CAd_10.0
...........................
60034805FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WId_1969NaN
60034806FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WId_1969NaN
60034807FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WId_1969NaN
60034808FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WId_1969NaN
60034809FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WId_1969NaN
\n", - "

60034810 rows × 8 columns

\n", - "
" - ], - "text/plain": [ - " id item_id dept_id cat_id \\\n", - "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", - "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", - "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", - "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", - "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", - "... ... ... ... ... \n", - "60034805 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", - "60034806 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", - "60034807 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", - "60034808 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", - "60034809 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", - "\n", - " store_id state_id day_id sales \n", - "0 CA_1 CA d_1 0.0 \n", - "1 CA_1 CA d_1 0.0 \n", - "2 CA_1 CA d_1 0.0 \n", - "3 CA_1 CA d_1 0.0 \n", - "4 CA_1 CA d_1 0.0 \n", - "... ... ... ... ... 
\n", - "60034805 WI_3 WI d_1969 NaN \n", - "60034806 WI_3 WI d_1969 NaN \n", - "60034807 WI_3 WI d_1969 NaN \n", - "60034808 WI_3 WI d_1969 NaN \n", - "60034809 WI_3 WI d_1969 NaN \n", - "\n", - "[60034810 rows x 8 columns]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "add_grid = cudf.DataFrame()\n", - "for i in range(1, 29):\n", - " temp_df = train_df[index_columns]\n", - " temp_df = temp_df.drop_duplicates()\n", - " temp_df[\"day_id\"] = \"d_\" + str(END_TRAIN + i)\n", - " temp_df[TARGET] = np.nan # Sales amount at time (n + i) is unknown\n", - " add_grid = cudf.concat([add_grid, temp_df])\n", - "add_grid[\"day_id\"] = add_grid[\"day_id\"].astype(\n", - " \"category\"\n", - ") # The day_id column is categorical, after cudf.melt\n", - "\n", - "grid_df = cudf.concat([grid_df, add_grid])\n", - "grid_df = grid_df.reset_index(drop=True)\n", - "grid_df[\"sales\"] = grid_df[\"sales\"].astype(\n", - " np.float32\n", - ") # Use float32 type for sales column, to conserve memory\n", - "grid_df" - ] - }, - { - "cell_type": "markdown", - "id": "4af8ec5b-5ad3-43c0-9e49-136f2e061019", - "metadata": {}, - "source": [ - "### Free up GPU memory" - ] - }, - { - "cell_type": "markdown", - "id": "8cac1630-1062-4d3e-be9e-77f5315d903b", - "metadata": {}, - "source": [ - "GPU memory is a precious resource, so let's try to free up some memory. 
First, delete temporary variables we no longer need:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "65e025ab-7443-4ba2-b47d-1b6494a57afe", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "8184" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Use xdel magic to scrub extra references from Jupyter notebook\n", - "%xdel temp_df\n", - "%xdel add_grid\n", - "%xdel train_df\n", - "\n", - "# Invoke the garbage collector explicitly to free up memory\n", - "gc.collect()" - ] - }, - { - "cell_type": "markdown", - "id": "15bc2d78-417f-4920-83ab-ea3d32ba6744", - "metadata": {}, - "source": [ - "Second, let's reduce the footprint of `grid_df` by converting strings into categoricals:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "7d2ef775-f68d-47d2-8f6f-be1efffc831d", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "grid_df takes up 5.2GiB memory on GPU\n" - ] - } - ], - "source": [ - "report_dataframe_size(grid_df, \"grid_df\")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "7c3fa45a-4f89-48c5-9fd7-7ac31e065ba7", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "id object\n", - "item_id object\n", - "dept_id object\n", - "cat_id object\n", - "store_id object\n", - "state_id object\n", - "day_id category\n", - "sales float32\n", - "dtype: object" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df.dtypes" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "dd0acbcd-c192-4f01-800f-ea6ac441b1ab", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "grid_df takes up 802.6MiB memory on GPU\n" - ] - } - ], - "source": [ - "for col in index_columns:\n", - " grid_df[col] 
= grid_df[col].astype(\"category\")\n", - " gc.collect()\n", - "report_dataframe_size(grid_df, \"grid_df\")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "96945261-2069-4923-a443-07766774465a", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "id category\n", - "item_id category\n", - "dept_id category\n", - "cat_id category\n", - "store_id category\n", - "state_id category\n", - "day_id category\n", - "sales float32\n", - "dtype: object" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df.dtypes" - ] - }, - { - "cell_type": "markdown", - "id": "48321684-4e32-4d6b-a0b1-6495b773b4ff", - "metadata": {}, - "source": [ - "## Identify the release week of each product" - ] - }, - { - "cell_type": "markdown", - "id": "8c3b2631-02b1-4942-b881-16c43321aad6", - "metadata": {}, - "source": [ - "Each row in the `prices_df` table contains the price of a product sold at a store for a given week." - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "dd624c0c-0460-4e58-9bd1-0a4148e21a46", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
store_iditem_idwm_yr_wksell_price
0CA_1HOBBIES_1_001113259.58
1CA_1HOBBIES_1_001113269.58
2CA_1HOBBIES_1_001113278.26
3CA_1HOBBIES_1_001113288.26
4CA_1HOBBIES_1_001113298.26
...............
6841116WI_3FOODS_3_827116171.00
6841117WI_3FOODS_3_827116181.00
6841118WI_3FOODS_3_827116191.00
6841119WI_3FOODS_3_827116201.00
6841120WI_3FOODS_3_827116211.00
\n", - "

6841121 rows × 4 columns

\n", - "
" - ], - "text/plain": [ - " store_id item_id wm_yr_wk sell_price\n", - "0 CA_1 HOBBIES_1_001 11325 9.58\n", - "1 CA_1 HOBBIES_1_001 11326 9.58\n", - "2 CA_1 HOBBIES_1_001 11327 8.26\n", - "3 CA_1 HOBBIES_1_001 11328 8.26\n", - "4 CA_1 HOBBIES_1_001 11329 8.26\n", - "... ... ... ... ...\n", - "6841116 WI_3 FOODS_3_827 11617 1.00\n", - "6841117 WI_3 FOODS_3_827 11618 1.00\n", - "6841118 WI_3 FOODS_3_827 11619 1.00\n", - "6841119 WI_3 FOODS_3_827 11620 1.00\n", - "6841120 WI_3 FOODS_3_827 11621 1.00\n", - "\n", - "[6841121 rows x 4 columns]" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "prices_df" - ] - }, - { - "cell_type": "markdown", - "id": "87cf77f2-3eb5-440d-850d-a80eee2fbce9", - "metadata": {}, - "source": [ - "Notice that not all products were sold over every week. Some products were sold only during some weeks. Let's use the groupby operation to identify the first week in which each product went on the shelf." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "b4bd0430-bffa-4d62-b15a-bf71498c19b5", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
store_iditem_idrelease_week
0CA_4FOODS_3_52911421
1TX_1HOUSEHOLD_1_40911230
2WI_2FOODS_3_14511214
3CA_4HOUSEHOLD_1_49411106
4WI_3HOBBIES_1_09311223
............
30485CA_3HOUSEHOLD_1_36911205
30486CA_2FOODS_3_10911101
30487CA_4FOODS_2_11911101
30488CA_4HOUSEHOLD_2_38411110
30489WI_3HOBBIES_1_13511328
\n", - "

30490 rows × 3 columns

\n", - "
" - ], - "text/plain": [ - " store_id item_id release_week\n", - "0 CA_4 FOODS_3_529 11421\n", - "1 TX_1 HOUSEHOLD_1_409 11230\n", - "2 WI_2 FOODS_3_145 11214\n", - "3 CA_4 HOUSEHOLD_1_494 11106\n", - "4 WI_3 HOBBIES_1_093 11223\n", - "... ... ... ...\n", - "30485 CA_3 HOUSEHOLD_1_369 11205\n", - "30486 CA_2 FOODS_3_109 11101\n", - "30487 CA_4 FOODS_2_119 11101\n", - "30488 CA_4 HOUSEHOLD_2_384 11110\n", - "30489 WI_3 HOBBIES_1_135 11328\n", - "\n", - "[30490 rows x 3 columns]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "release_df = (\n", - " prices_df.groupby([\"store_id\", \"item_id\"])[\"wm_yr_wk\"].agg(\"min\").reset_index()\n", - ")\n", - "release_df.columns = [\"store_id\", \"item_id\", \"release_week\"]\n", - "release_df" - ] - }, - { - "cell_type": "markdown", - "id": "b987d3f4-7885-4f59-86e0-bb78a9eea106", - "metadata": {}, - "source": [ - "Now that we've computed the release week for each product, let's merge it back to `grid_df`:" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "9b1baea3-8507-452d-ae2e-bc0024d3c926", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_week
0FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_13.011101
1FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_20.011101
2FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_30.011101
3FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_41.011101
4FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_54.011101
..............................
60034805HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1965NaN11101
60034806HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1966NaN11101
60034807HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1967NaN11101
60034808HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1968NaN11101
60034809HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1969NaN11101
\n", - "

60034810 rows × 9 columns

\n", - "
" - ], - "text/plain": [ - " id item_id dept_id \\\n", - "0 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "1 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "2 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "3 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "4 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "... ... ... ... \n", - "60034805 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60034806 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60034807 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60034808 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60034809 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "\n", - " cat_id store_id state_id day_id sales release_week \n", - "0 FOODS CA_1 CA d_1 3.0 11101 \n", - "1 FOODS CA_1 CA d_2 0.0 11101 \n", - "2 FOODS CA_1 CA d_3 0.0 11101 \n", - "3 FOODS CA_1 CA d_4 1.0 11101 \n", - "4 FOODS CA_1 CA d_5 4.0 11101 \n", - "... ... ... ... ... ... ... 
\n", - "60034805 HOUSEHOLD WI_3 WI d_1965 NaN 11101 \n", - "60034806 HOUSEHOLD WI_3 WI d_1966 NaN 11101 \n", - "60034807 HOUSEHOLD WI_3 WI d_1967 NaN 11101 \n", - "60034808 HOUSEHOLD WI_3 WI d_1968 NaN 11101 \n", - "60034809 HOUSEHOLD WI_3 WI d_1969 NaN 11101 \n", - "\n", - "[60034810 rows x 9 columns]" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df = grid_df.merge(release_df, on=[\"store_id\", \"item_id\"], how=\"left\")\n", - "grid_df = grid_df.sort_values(index_columns + [\"day_id\"]).reset_index(drop=True)\n", - "grid_df" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "9e2c3c9b-c753-4b49-876d-447e73a55c81", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "138" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "del release_df # No longer needed\n", - "gc.collect()" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "b973c4df-67a6-414b-8c5d-883ab276bbd9", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "grid_df takes up 1.2GiB memory on GPU\n" - ] - } - ], - "source": [ - "report_dataframe_size(grid_df, \"grid_df\")" - ] - }, - { - "cell_type": "markdown", - "id": "01e39a12-a2b8-49ce-9a08-42fa08f2d1a8", - "metadata": {}, - "source": [ - "## Filter out entries with zero sales" - ] - }, - { - "cell_type": "markdown", - "id": "8733da5f-44bf-4b23-9645-0711bb68f6ef", - "metadata": {}, - "source": [ - "We can further save space by dropping rows from `grid_df` that correspond to zero sales. 
Since each product doesn't go on the shelf until its release week, its sale must be zero during any week that's prior to the release week.\n", - "\n", - "To make use of this insight, we bring in the `wm_yr_wk` column from `calendar_df`:" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "565fee38-25a2-4ed7-a160-5be11bd68df4", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
0FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15371.01110111511
1FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15380.01110111511
2FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15392.01110111511
3FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15400.01110111511
4FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15410.01110111512
.................................
60034805HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_520.01110111108
60034806HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_530.01110111108
60034807HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_540.01110111108
60034808HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_550.01110111108
60034809HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_490.01110111107
\n", - "

60034810 rows × 10 columns

\n", - "
" - ], - "text/plain": [ - " id item_id dept_id \\\n", - "0 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "1 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "2 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "3 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "4 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "... ... ... ... \n", - "60034805 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60034806 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60034807 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60034808 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60034809 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "\n", - " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", - "0 FOODS CA_1 CA d_1537 1.0 11101 11511 \n", - "1 FOODS CA_1 CA d_1538 0.0 11101 11511 \n", - "2 FOODS CA_1 CA d_1539 2.0 11101 11511 \n", - "3 FOODS CA_1 CA d_1540 0.0 11101 11511 \n", - "4 FOODS CA_1 CA d_1541 0.0 11101 11512 \n", - "... ... ... ... ... ... ... ... 
\n", - "60034805 HOUSEHOLD WI_3 WI d_52 0.0 11101 11108 \n", - "60034806 HOUSEHOLD WI_3 WI d_53 0.0 11101 11108 \n", - "60034807 HOUSEHOLD WI_3 WI d_54 0.0 11101 11108 \n", - "60034808 HOUSEHOLD WI_3 WI d_55 0.0 11101 11108 \n", - "60034809 HOUSEHOLD WI_3 WI d_49 0.0 11101 11107 \n", - "\n", - "[60034810 rows x 10 columns]" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df = grid_df.merge(calendar_df[[\"wm_yr_wk\", \"day_id\"]], on=[\"day_id\"], how=\"left\")\n", - "grid_df" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "4cdfb852-2d25-4b60-8470-71edb3c63d14", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "grid_df takes up 1.7GiB memory on GPU\n" - ] - } - ], - "source": [ - "report_dataframe_size(grid_df, \"grid_df\")" - ] - }, - { - "cell_type": "markdown", - "id": "cb5059b7-6a1f-464b-b027-ce3b389eb15e", - "metadata": {}, - "source": [ - "The `wm_yr_wk` column identifies the week that contains the day given by the `day_id` column. Now let's filter all rows in `grid_df` for which `wm_yr_wk` is less than `release_week`:" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "79b8cf35-776e-4295-8b5b-9e293c2ff325", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
9990FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_10.01110211101
9991FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_20.01110211101
9992FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_30.01110211101
9993FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_40.01110211101
9994FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_50.01110211101
.................................
60032955HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_200.01110611103
60032956HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_210.01110611103
60032957HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_220.01110611104
60032958HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_230.01110611104
60032959HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_240.01110611104
\n", - "

12299413 rows × 10 columns

\n", - "
" - ], - "text/plain": [ - " id item_id dept_id \\\n", - "9990 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", - "9991 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", - "9992 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", - "9993 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", - "9994 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", - "... ... ... ... \n", - "60032955 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60032956 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60032957 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60032958 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "60032959 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "\n", - " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", - "9990 FOODS TX_3 TX d_1 0.0 11102 11101 \n", - "9991 FOODS TX_3 TX d_2 0.0 11102 11101 \n", - "9992 FOODS TX_3 TX d_3 0.0 11102 11101 \n", - "9993 FOODS TX_3 TX d_4 0.0 11102 11101 \n", - "9994 FOODS TX_3 TX d_5 0.0 11102 11101 \n", - "... ... ... ... ... ... ... ... \n", - "60032955 HOUSEHOLD WI_2 WI d_20 0.0 11106 11103 \n", - "60032956 HOUSEHOLD WI_2 WI d_21 0.0 11106 11103 \n", - "60032957 HOUSEHOLD WI_2 WI d_22 0.0 11106 11104 \n", - "60032958 HOUSEHOLD WI_2 WI d_23 0.0 11106 11104 \n", - "60032959 HOUSEHOLD WI_2 WI d_24 0.0 11106 11104 \n", - "\n", - "[12299413 rows x 10 columns]" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df = grid_df[grid_df[\"wm_yr_wk\"] < grid_df[\"release_week\"]]\n", - "df" - ] - }, - { - "cell_type": "markdown", - "id": "3056f24d-8bef-457e-aee0-82b43459ba92", - "metadata": {}, - "source": [ - "As we suspected, the sales amount is zero during weeks that come before the release week." 
- ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "e82c98f0-1c6d-49ae-abd4-e37c3a07bc16", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "assert (df[\"sales\"] == 0).all()" - ] - }, - { - "cell_type": "markdown", - "id": "49d6e163-e634-43c6-a021-55d5604944ff", - "metadata": {}, - "source": [ - "For the purpose of our data analysis, we can safely drop the rows with zero sales:" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "dd694a7a-86c0-44c9-b29c-8387f047500c", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
0FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15371.01110111511
1FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15380.01110111511
2FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15392.01110111511
3FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15400.01110111511
4FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15410.01110111512
.................................
47735392HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_520.01110111108
47735393HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_530.01110111108
47735394HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_540.01110111108
47735395HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_550.01110111108
47735396HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_490.01110111107
\n", - "

47735397 rows × 10 columns

\n", - "
" - ], - "text/plain": [ - " id item_id dept_id \\\n", - "0 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "1 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "2 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "3 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "4 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "... ... ... ... \n", - "47735392 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "47735393 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "47735394 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "47735395 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "47735396 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "\n", - " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", - "0 FOODS CA_1 CA d_1537 1.0 11101 11511 \n", - "1 FOODS CA_1 CA d_1538 0.0 11101 11511 \n", - "2 FOODS CA_1 CA d_1539 2.0 11101 11511 \n", - "3 FOODS CA_1 CA d_1540 0.0 11101 11511 \n", - "4 FOODS CA_1 CA d_1541 0.0 11101 11512 \n", - "... ... ... ... ... ... ... ... 
\n", - "47735392 HOUSEHOLD WI_3 WI d_52 0.0 11101 11108 \n", - "47735393 HOUSEHOLD WI_3 WI d_53 0.0 11101 11108 \n", - "47735394 HOUSEHOLD WI_3 WI d_54 0.0 11101 11108 \n", - "47735395 HOUSEHOLD WI_3 WI d_55 0.0 11101 11108 \n", - "47735396 HOUSEHOLD WI_3 WI d_49 0.0 11101 11107 \n", - "\n", - "[47735397 rows x 10 columns]" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df = grid_df[grid_df[\"wm_yr_wk\"] >= grid_df[\"release_week\"]].reset_index(drop=True)\n", - "grid_df[\"wm_yr_wk\"] = grid_df[\"wm_yr_wk\"].astype(\n", - " np.int32\n", - ") # Convert wm_yr_wk column to int32, to conserve memory\n", - "grid_df" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "b46d629b-9996-443f-95d5-9da888b00fa0", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "grid_df takes up 1.2GiB memory on GPU\n" - ] - } - ], - "source": [ - "report_dataframe_size(grid_df, \"grid_df\")" - ] - }, - { - "cell_type": "markdown", - "id": "d323fd3a-21e4-4df6-8d32-78948acbf6fd", - "metadata": {}, - "source": [ - "## Assign weights for product items" - ] - }, - { - "cell_type": "markdown", - "id": "4b55b355-ffc3-429d-98bd-30c356f70c06", - "metadata": {}, - "source": [ - "When we assess the accuracy of our machine learning model, we should assign a weight for each product item, to indicate the relative importance of the item. For the M5 competition, the weights are computed from the total sales amount (in US dollars) in the lastest 28 days." - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "0b216249-dd2d-40e5-bbb3-6a41773c7b98", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
sales_usd_sum
item_id
FOODS_1_0013516.80
FOODS_1_00212418.80
FOODS_1_0035943.20
FOODS_1_00454184.82
FOODS_1_00517877.00
......
HOUSEHOLD_2_5126034.40
HOUSEHOLD_2_5132668.80
HOUSEHOLD_2_5149574.60
HOUSEHOLD_2_515630.40
HOUSEHOLD_2_5162574.00
\n", - "

3049 rows × 1 columns

\n", - "
" - ], - "text/plain": [ - " sales_usd_sum\n", - "item_id \n", - "FOODS_1_001 3516.80\n", - "FOODS_1_002 12418.80\n", - "FOODS_1_003 5943.20\n", - "FOODS_1_004 54184.82\n", - "FOODS_1_005 17877.00\n", - "... ...\n", - "HOUSEHOLD_2_512 6034.40\n", - "HOUSEHOLD_2_513 2668.80\n", - "HOUSEHOLD_2_514 9574.60\n", - "HOUSEHOLD_2_515 630.40\n", - "HOUSEHOLD_2_516 2574.00\n", - "\n", - "[3049 rows x 1 columns]" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Convert day_id to integers\n", - "grid_df[\"day_id_int\"] = grid_df[\"day_id\"].to_pandas().apply(lambda x: x[2:]).astype(int)\n", - "\n", - "# Compute the total sales over the latest 28 days, per product item\n", - "last28 = grid_df[(grid_df[\"day_id_int\"] >= 1914) & (grid_df[\"day_id_int\"] < 1942)]\n", - "last28 = last28[[\"item_id\", \"wm_yr_wk\", \"sales\"]].merge(\n", - " prices_df[[\"item_id\", \"wm_yr_wk\", \"sell_price\"]], on=[\"item_id\", \"wm_yr_wk\"]\n", - ")\n", - "last28[\"sales_usd\"] = last28[\"sales\"] * last28[\"sell_price\"]\n", - "total_sales_usd = last28.groupby(\"item_id\")[[\"sales_usd\"]].agg([\"sum\"]).sort_index()\n", - "total_sales_usd.columns = total_sales_usd.columns.map(\"_\".join)\n", - "total_sales_usd" - ] - }, - { - "cell_type": "markdown", - "id": "ae0934ba-37de-4391-bdb4-7c10685fec6c", - "metadata": {}, - "source": [ - "To obtain weights, we normalize the sales amount for one item by the total sales for all items." - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "bfab66af-8b70-41fb-bc99-b43567b3f87a", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
weights
item_id
FOODS_1_0010.000090
FOODS_1_0020.000318
FOODS_1_0030.000152
FOODS_1_0040.001389
FOODS_1_0050.000458
......
HOUSEHOLD_2_5120.000155
HOUSEHOLD_2_5130.000068
HOUSEHOLD_2_5140.000245
HOUSEHOLD_2_5150.000016
HOUSEHOLD_2_5160.000066
\n", - "

3049 rows × 1 columns

\n", - "
" - ], - "text/plain": [ - " weights\n", - "item_id \n", - "FOODS_1_001 0.000090\n", - "FOODS_1_002 0.000318\n", - "FOODS_1_003 0.000152\n", - "FOODS_1_004 0.001389\n", - "FOODS_1_005 0.000458\n", - "... ...\n", - "HOUSEHOLD_2_512 0.000155\n", - "HOUSEHOLD_2_513 0.000068\n", - "HOUSEHOLD_2_514 0.000245\n", - "HOUSEHOLD_2_515 0.000016\n", - "HOUSEHOLD_2_516 0.000066\n", - "\n", - "[3049 rows x 1 columns]" - ] - }, - "execution_count": 28, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "weights = total_sales_usd / total_sales_usd.sum()\n", - "weights = weights.rename(columns={\"sales_usd_sum\": \"weights\"})\n", - "weights" - ] - }, - { - "cell_type": "markdown", - "id": "2e717a7d-9022-4c44-9d09-2916fa45afb6", - "metadata": {}, - "source": [ - "## Persist the processed data to disk" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "cb06bbd9-1e00-44cb-b197-2098b7c67069", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# No longer needed\n", - "del grid_df[\"day_id_int\"]\n", - "\n", - "# Persist grid_df to disk\n", - "grid_df.to_pandas().to_pickle(processed_data_dir + \"grid_df_part1.pkl\")\n", - "weights.to_pandas().to_pickle(processed_data_dir + \"product_weights.pkl\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "126198cf-e07f-46e6-bc3b-cc6170a57c08", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb 
b/source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb deleted file mode 100644 index 2c3ba598..00000000 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part2.ipynb +++ /dev/null @@ -1,1290 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "a17c37ef-5035-4540-8066-bb9741496639", - "metadata": {}, - "source": [ - "# Data preprocesing, Part 2" - ] - }, - { - "cell_type": "markdown", - "id": "879d5849-97b4-4a05-9f26-7035a2ce220c", - "metadata": {}, - "source": [ - "## Import modules" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "96a23fd4-d1e3-4487-ae73-38d456fa2408", - "metadata": {}, - "outputs": [], - "source": [ - "import cudf\n", - "import numpy as np\n", - "import pandas as pd\n", - "import gc" - ] - }, - { - "cell_type": "markdown", - "id": "ed21ca9c-5a4e-4788-a2a2-2db5b9f5cab8", - "metadata": {}, - "source": [ - "## Load data" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "239019da-02cb-4e53-bfe1-ca0bfef1f201", - "metadata": {}, - "outputs": [], - "source": [ - "raw_data_dir = \"./data/\"\n", - "processed_data_dir = \"./processed_data/\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "eb0d40de-c3f3-40fe-9651-63c0cdf5b077", - "metadata": {}, - "outputs": [], - "source": [ - "prices_df = cudf.read_csv(raw_data_dir + \"sell_prices.csv\")\n", - "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(\n", - " columns={\"d\": \"day_id\"}\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "5946cf28-0188-4c4d-b98f-6bc855e8a92d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
store_iditem_idwm_yr_wksell_price
0CA_1HOBBIES_1_001113259.58
1CA_1HOBBIES_1_001113269.58
2CA_1HOBBIES_1_001113278.26
3CA_1HOBBIES_1_001113288.26
4CA_1HOBBIES_1_001113298.26
...............
6841116WI_3FOODS_3_827116171.00
6841117WI_3FOODS_3_827116181.00
6841118WI_3FOODS_3_827116191.00
6841119WI_3FOODS_3_827116201.00
6841120WI_3FOODS_3_827116211.00
\n", - "

6841121 rows × 4 columns

\n", - "
" - ], - "text/plain": [ - " store_id item_id wm_yr_wk sell_price\n", - "0 CA_1 HOBBIES_1_001 11325 9.58\n", - "1 CA_1 HOBBIES_1_001 11326 9.58\n", - "2 CA_1 HOBBIES_1_001 11327 8.26\n", - "3 CA_1 HOBBIES_1_001 11328 8.26\n", - "4 CA_1 HOBBIES_1_001 11329 8.26\n", - "... ... ... ... ...\n", - "6841116 WI_3 FOODS_3_827 11617 1.00\n", - "6841117 WI_3 FOODS_3_827 11618 1.00\n", - "6841118 WI_3 FOODS_3_827 11619 1.00\n", - "6841119 WI_3 FOODS_3_827 11620 1.00\n", - "6841120 WI_3 FOODS_3_827 11621 1.00\n", - "\n", - "[6841121 rows x 4 columns]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "prices_df" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "d6d8e3e5-426e-413c-8461-b01dbead0de0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
datewm_yr_wkweekdaywdaymonthyearday_idevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WI
02011-01-2911101Saturday112011d_1<NA><NA><NA><NA>000
12011-01-3011101Sunday212011d_2<NA><NA><NA><NA>000
22011-01-3111101Monday312011d_3<NA><NA><NA><NA>000
32011-02-0111101Tuesday422011d_4<NA><NA><NA><NA>110
42011-02-0211101Wednesday522011d_5<NA><NA><NA><NA>101
.............................................
19642016-06-1511620Wednesday562016d_1965<NA><NA><NA><NA>011
19652016-06-1611620Thursday662016d_1966<NA><NA><NA><NA>000
19662016-06-1711620Friday762016d_1967<NA><NA><NA><NA>000
19672016-06-1811621Saturday162016d_1968<NA><NA><NA><NA>000
19682016-06-1911621Sunday262016d_1969NBAFinalsEndSportingFather's dayCultural000
\n", - "

1969 rows × 14 columns

\n", - "
" - ], - "text/plain": [ - " date wm_yr_wk weekday wday month year day_id \\\n", - "0 2011-01-29 11101 Saturday 1 1 2011 d_1 \n", - "1 2011-01-30 11101 Sunday 2 1 2011 d_2 \n", - "2 2011-01-31 11101 Monday 3 1 2011 d_3 \n", - "3 2011-02-01 11101 Tuesday 4 2 2011 d_4 \n", - "4 2011-02-02 11101 Wednesday 5 2 2011 d_5 \n", - "... ... ... ... ... ... ... ... \n", - "1964 2016-06-15 11620 Wednesday 5 6 2016 d_1965 \n", - "1965 2016-06-16 11620 Thursday 6 6 2016 d_1966 \n", - "1966 2016-06-17 11620 Friday 7 6 2016 d_1967 \n", - "1967 2016-06-18 11621 Saturday 1 6 2016 d_1968 \n", - "1968 2016-06-19 11621 Sunday 2 6 2016 d_1969 \n", - "\n", - " event_name_1 event_type_1 event_name_2 event_type_2 snap_CA snap_TX \\\n", - "0 0 0 \n", - "1 0 0 \n", - "2 0 0 \n", - "3 1 1 \n", - "4 1 0 \n", - "... ... ... ... ... ... ... \n", - "1964 0 1 \n", - "1965 0 0 \n", - "1966 0 0 \n", - "1967 0 0 \n", - "1968 NBAFinalsEnd Sporting Father's day Cultural 0 0 \n", - "\n", - " snap_WI \n", - "0 0 \n", - "1 0 \n", - "2 0 \n", - "3 0 \n", - "4 1 \n", - "... ... \n", - "1964 1 \n", - "1965 0 \n", - "1966 0 \n", - "1967 0 \n", - "1968 0 \n", - "\n", - "[1969 rows x 14 columns]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "calendar_df" - ] - }, - { - "cell_type": "markdown", - "id": "8fa43d3f-aa95-4d98-9265-0d56f0f39eba", - "metadata": {}, - "source": [ - "## Generate price-related features" - ] - }, - { - "cell_type": "markdown", - "id": "ed99b0f4-ab4a-485f-a75e-211cb6bacd00", - "metadata": {}, - "source": [ - "Let us engineer additional features that are related to the sale price. We consider the distribution of the price of a given product over time and ask how the current price compares to the historical trend." 
- ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "380f03c7-c8e7-42a5-ac43-f4aec0883880", - "metadata": {}, - "outputs": [], - "source": [ - "# Highest price over all weeks\n", - "prices_df[\"price_max\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", - " \"sell_price\"\n", - "].transform(\"max\")\n", - "# Lowest price over all weeks\n", - "prices_df[\"price_min\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", - " \"sell_price\"\n", - "].transform(\"min\")\n", - "# Standard deviation of the price\n", - "prices_df[\"price_std\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", - " \"sell_price\"\n", - "].transform(\"std\")\n", - "# Mean (average) price over all weeks\n", - "prices_df[\"price_mean\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", - " \"sell_price\"\n", - "].transform(\"mean\")" - ] - }, - { - "cell_type": "markdown", - "id": "03eb645e-3a5e-41e5-866c-b9904cdd3953", - "metadata": {}, - "source": [ - "We also consider the ratio of the current price to the max price." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "743d269e-14ea-448d-9dc3-9214c5a42ea7", - "metadata": {}, - "outputs": [], - "source": [ - "prices_df[\"price_norm\"] = prices_df[\"sell_price\"] / prices_df[\"price_max\"]" - ] - }, - { - "cell_type": "markdown", - "id": "4c256268-3c1b-4904-926c-b0b5bf13e98a", - "metadata": {}, - "source": [ - "Some items have a very stable price, whereas other items respond to inflation quickly and rise in price. To capture the price elasticity, we count the number of unique price values for a given product over time." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "6668ab89-155a-4eb0-8f5c-e2f7895afdf9", - "metadata": {}, - "outputs": [], - "source": [ - "prices_df[\"price_nunique\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", - " \"sell_price\"\n", - "].transform(\"nunique\")" - ] - }, - { - "cell_type": "markdown", - "id": "4da0b2ad-3b36-4bc1-a469-08ae9c8abb1b", - "metadata": {}, - "source": [ - "We also consider, for a given price, how many other items are being sold at the exact same price." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "7b8ef0a3-7736-4f95-8c8a-624f3f7a4faa", - "metadata": {}, - "outputs": [], - "source": [ - "prices_df[\"item_nunique\"] = prices_df.groupby([\"store_id\", \"sell_price\"])[\n", - " \"item_id\"\n", - "].transform(\"nunique\")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "d0bebd07-4238-410d-9f16-f4cbeef90819", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
store_iditem_idwm_yr_wksell_priceprice_maxprice_minprice_stdprice_meanprice_normprice_nuniqueitem_nunique
0CA_1HOBBIES_1_001113259.589.588.260.1521398.2857141.00000033
1CA_1HOBBIES_1_001113269.589.588.260.1521398.2857141.00000033
2CA_1HOBBIES_1_001113278.269.588.260.1521398.2857140.86221335
3CA_1HOBBIES_1_001113288.269.588.260.1521398.2857140.86221335
4CA_1HOBBIES_1_001113298.269.588.260.1521398.2857140.86221335
....................................
6841116WI_3FOODS_3_827116171.001.001.000.0000001.0000001.0000001142
6841117WI_3FOODS_3_827116181.001.001.000.0000001.0000001.0000001142
6841118WI_3FOODS_3_827116191.001.001.000.0000001.0000001.0000001142
6841119WI_3FOODS_3_827116201.001.001.000.0000001.0000001.0000001142
6841120WI_3FOODS_3_827116211.001.001.000.0000001.0000001.0000001142
\n", - "

6841121 rows × 11 columns

\n", - "
" - ], - "text/plain": [ - " store_id item_id wm_yr_wk sell_price price_max price_min \\\n", - "0 CA_1 HOBBIES_1_001 11325 9.58 9.58 8.26 \n", - "1 CA_1 HOBBIES_1_001 11326 9.58 9.58 8.26 \n", - "2 CA_1 HOBBIES_1_001 11327 8.26 9.58 8.26 \n", - "3 CA_1 HOBBIES_1_001 11328 8.26 9.58 8.26 \n", - "4 CA_1 HOBBIES_1_001 11329 8.26 9.58 8.26 \n", - "... ... ... ... ... ... ... \n", - "6841116 WI_3 FOODS_3_827 11617 1.00 1.00 1.00 \n", - "6841117 WI_3 FOODS_3_827 11618 1.00 1.00 1.00 \n", - "6841118 WI_3 FOODS_3_827 11619 1.00 1.00 1.00 \n", - "6841119 WI_3 FOODS_3_827 11620 1.00 1.00 1.00 \n", - "6841120 WI_3 FOODS_3_827 11621 1.00 1.00 1.00 \n", - "\n", - " price_std price_mean price_norm price_nunique item_nunique \n", - "0 0.152139 8.285714 1.000000 3 3 \n", - "1 0.152139 8.285714 1.000000 3 3 \n", - "2 0.152139 8.285714 0.862213 3 5 \n", - "3 0.152139 8.285714 0.862213 3 5 \n", - "4 0.152139 8.285714 0.862213 3 5 \n", - "... ... ... ... ... ... \n", - "6841116 0.000000 1.000000 1.000000 1 142 \n", - "6841117 0.000000 1.000000 1.000000 1 142 \n", - "6841118 0.000000 1.000000 1.000000 1 142 \n", - "6841119 0.000000 1.000000 1.000000 1 142 \n", - "6841120 0.000000 1.000000 1.000000 1 142 \n", - "\n", - "[6841121 rows x 11 columns]" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "prices_df" - ] - }, - { - "cell_type": "markdown", - "id": "98d24d6e-91b2-43a8-8051-2898ce047254", - "metadata": {}, - "source": [ - "Another useful way to put prices in context is to compare the price of a product to its historical price a week ago, a month ago, or an year ago." 
- ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "7d95506a-aeaa-445b-8202-d158ce22c182", - "metadata": {}, - "outputs": [], - "source": [ - "# Add \"month\" and \"year\" columns to prices_df\n", - "week_to_month_map = calendar_df[[\"wm_yr_wk\", \"month\", \"year\"]].drop_duplicates(\n", - " subset=[\"wm_yr_wk\"]\n", - ")\n", - "prices_df = prices_df.merge(week_to_month_map, on=[\"wm_yr_wk\"], how=\"left\")\n", - "\n", - "# Sort by wm_yr_wk. The rows will also be sorted in ascending months and years.\n", - "prices_df = prices_df.sort_values([\"store_id\", \"item_id\", \"wm_yr_wk\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "646edece-e9be-4f35-a252-9710383e5a5b", - "metadata": {}, - "outputs": [], - "source": [ - "# Compare with the average price in the previous week\n", - "prices_df[\"price_momentum\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", - " [\"store_id\", \"item_id\"]\n", - ")[\"sell_price\"].shift(1)\n", - "# Compare with the average price in the previous month\n", - "prices_df[\"price_momentum_m\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", - " [\"store_id\", \"item_id\", \"month\"]\n", - ")[\"sell_price\"].transform(\"mean\")\n", - "# Compare with the average price in the previous year\n", - "prices_df[\"price_momentum_y\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", - " [\"store_id\", \"item_id\", \"year\"]\n", - ")[\"sell_price\"].transform(\"mean\")" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "6205fac5-4bf9-4c9b-9eae-05341cda3834", - "metadata": {}, - "outputs": [], - "source": [ - "# Remove \"month\" and \"year\" columns, as we don't need them any more\n", - "del prices_df[\"month\"], prices_df[\"year\"]\n", - "\n", - "# Convert float64 columns into float32 type to save memory\n", - "columns = [\n", - " \"sell_price\",\n", - " \"price_max\",\n", - " \"price_min\",\n", - " \"price_std\",\n", - " \"price_mean\",\n", - " \"price_norm\",\n", - " 
\"price_momentum\",\n", - " \"price_momentum_m\",\n", - " \"price_momentum_y\",\n", - "]\n", - "for col in columns:\n", - " prices_df[col] = prices_df[col].astype(np.float32)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "dc7e75b8-33d0-42bf-b14a-d398012c2ce6", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "store_id object\n", - "item_id object\n", - "wm_yr_wk int64\n", - "sell_price float32\n", - "price_max float32\n", - "price_min float32\n", - "price_std float32\n", - "price_mean float32\n", - "price_norm float32\n", - "price_nunique int32\n", - "item_nunique int32\n", - "price_momentum float32\n", - "price_momentum_m float32\n", - "price_momentum_y float32\n", - "dtype: object" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "prices_df.dtypes" - ] - }, - { - "cell_type": "markdown", - "id": "dab2dc14-6273-4598-98ca-734459b66aa5", - "metadata": {}, - "source": [ - "## Bring in price-related features into `grid_df`" - ] - }, - { - "cell_type": "markdown", - "id": "cf3bb251-871c-452a-b24b-c33d56559160", - "metadata": {}, - "source": [ - "We load `grid_df` from the Part 1 notebook and bring in columns from `price_df`." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "0c5564da-8608-4640-805f-d14960dbe760", - "metadata": {}, - "outputs": [], - "source": [ - "grid_df = cudf.from_pandas(pd.read_pickle(processed_data_dir + \"grid_df_part1.pkl\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "a7901ee0-1a60-42cc-a983-fe6838b7a612", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
idday_idsell_priceprice_maxprice_minprice_stdprice_meanprice_normprice_nuniqueitem_nuniqueprice_momentumprice_momentum_mprice_momentum_y
0FOODS_1_001_CA_2_evaluationd_10402.242.242.001.095719e-012.1693621.02611.01.0198681.0
1FOODS_1_001_CA_2_evaluationd_10412.242.242.001.095719e-012.1693621.02611.01.0198681.0
2FOODS_1_001_CA_2_evaluationd_10422.242.242.001.095719e-012.1693621.02611.01.0198681.0
3FOODS_1_001_CA_2_evaluationd_10432.242.242.001.095719e-012.1693621.02611.01.0198681.0
4FOODS_1_001_CA_2_evaluationd_10442.242.242.001.095719e-012.1693621.02611.01.0249581.0
..........................................
47735392HOUSEHOLD_2_516_WI_2_evaluationd_8845.945.945.943.648122e-145.9400001.01471.01.0000001.0
47735393HOUSEHOLD_2_516_WI_2_evaluationd_8855.945.945.943.648122e-145.9400001.01471.01.0000001.0
47735394HOUSEHOLD_2_516_WI_2_evaluationd_8865.945.945.943.648122e-145.9400001.01471.01.0000001.0
47735395HOUSEHOLD_2_516_WI_2_evaluationd_8875.945.945.943.648122e-145.9400001.01471.01.0000001.0
47735396HOUSEHOLD_2_516_WI_2_evaluationd_8885.945.945.943.648122e-145.9400001.01471.01.0000001.0
\n", - "

47735397 rows × 13 columns

\n", - "
" - ], - "text/plain": [ - " id day_id sell_price price_max \\\n", - "0 FOODS_1_001_CA_2_evaluation d_1040 2.24 2.24 \n", - "1 FOODS_1_001_CA_2_evaluation d_1041 2.24 2.24 \n", - "2 FOODS_1_001_CA_2_evaluation d_1042 2.24 2.24 \n", - "3 FOODS_1_001_CA_2_evaluation d_1043 2.24 2.24 \n", - "4 FOODS_1_001_CA_2_evaluation d_1044 2.24 2.24 \n", - "... ... ... ... ... \n", - "47735392 HOUSEHOLD_2_516_WI_2_evaluation d_884 5.94 5.94 \n", - "47735393 HOUSEHOLD_2_516_WI_2_evaluation d_885 5.94 5.94 \n", - "47735394 HOUSEHOLD_2_516_WI_2_evaluation d_886 5.94 5.94 \n", - "47735395 HOUSEHOLD_2_516_WI_2_evaluation d_887 5.94 5.94 \n", - "47735396 HOUSEHOLD_2_516_WI_2_evaluation d_888 5.94 5.94 \n", - "\n", - " price_min price_std price_mean price_norm price_nunique \\\n", - "0 2.00 1.095719e-01 2.169362 1.0 2 \n", - "1 2.00 1.095719e-01 2.169362 1.0 2 \n", - "2 2.00 1.095719e-01 2.169362 1.0 2 \n", - "3 2.00 1.095719e-01 2.169362 1.0 2 \n", - "4 2.00 1.095719e-01 2.169362 1.0 2 \n", - "... ... ... ... ... ... \n", - "47735392 5.94 3.648122e-14 5.940000 1.0 1 \n", - "47735393 5.94 3.648122e-14 5.940000 1.0 1 \n", - "47735394 5.94 3.648122e-14 5.940000 1.0 1 \n", - "47735395 5.94 3.648122e-14 5.940000 1.0 1 \n", - "47735396 5.94 3.648122e-14 5.940000 1.0 1 \n", - "\n", - " item_nunique price_momentum price_momentum_m price_momentum_y \n", - "0 61 1.0 1.019868 1.0 \n", - "1 61 1.0 1.019868 1.0 \n", - "2 61 1.0 1.019868 1.0 \n", - "3 61 1.0 1.019868 1.0 \n", - "4 61 1.0 1.024958 1.0 \n", - "... ... ... ... ... 
\n", - "47735392 47 1.0 1.000000 1.0 \n", - "47735393 47 1.0 1.000000 1.0 \n", - "47735394 47 1.0 1.000000 1.0 \n", - "47735395 47 1.0 1.000000 1.0 \n", - "47735396 47 1.0 1.000000 1.0 \n", - "\n", - "[47735397 rows x 13 columns]" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# After merging price_df, keep columns id and day_id from grid_df and drop all other columns from grid_df\n", - "original_columns = list(grid_df)\n", - "grid_df = grid_df.merge(prices_df, on=[\"store_id\", \"item_id\", \"wm_yr_wk\"], how=\"left\")\n", - "columns_to_keep = [\"id\", \"day_id\"] + [\n", - " col for col in list(grid_df) if col not in original_columns\n", - "]\n", - "grid_df = grid_df[[\"id\", \"day_id\"] + columns_to_keep]\n", - "grid_df" - ] - }, - { - "cell_type": "markdown", - "id": "26dca23e-ff21-4f37-b93a-8c4edfdafe29", - "metadata": {}, - "source": [ - "We persist the combined table to disk." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "ee4b6223-c10e-409d-9d27-987e7c626b4c", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "grid_df.to_pandas().to_pickle(processed_data_dir + \"grid_df_part2.pkl\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb deleted file mode 100644 index e9f43755..00000000 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part3.ipynb +++ /dev/null @@ -1,854 +0,0 @@ -{ - "cells": [ 
- { - "cell_type": "markdown", - "id": "aaaa81be-da18-4108-aca2-1dc8e28ac34b", - "metadata": {}, - "source": [ - "# Data preprocesing, Part 3" - ] - }, - { - "cell_type": "markdown", - "id": "52de27ff-2ff7-4da8-a836-96cb109940ae", - "metadata": {}, - "source": [ - "## Import modules" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "29dd3d72-b823-4eaa-afb5-1e06c7311278", - "metadata": {}, - "outputs": [], - "source": [ - "import cudf\n", - "import numpy as np\n", - "import pandas as pd\n", - "import cupy as cp\n", - "import gc" - ] - }, - { - "cell_type": "markdown", - "id": "4b598eb9-d289-48f5-8e4d-30b84e7000d9", - "metadata": {}, - "source": [ - "## Load data" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "d9f842e3-51b8-4a06-9ad8-2c5832682c97", - "metadata": {}, - "outputs": [], - "source": [ - "raw_data_dir = \"./data/\"\n", - "processed_data_dir = \"./processed_data/\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "97f8792c-d41a-4e1d-ac9f-7a498937ead6", - "metadata": {}, - "outputs": [], - "source": [ - "calendar_df = cudf.read_csv(raw_data_dir + \"calendar.csv\").rename(\n", - " columns={\"d\": \"day_id\"}\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "8a519b84-528d-45eb-97e4-08cf53775e74", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
idday_id
0FOODS_1_001_CA_1_evaluationd_1537
1FOODS_1_001_CA_1_evaluationd_1538
2FOODS_1_001_CA_1_evaluationd_1539
3FOODS_1_001_CA_1_evaluationd_1540
4FOODS_1_001_CA_1_evaluationd_1541
.........
47735392HOUSEHOLD_2_516_WI_3_evaluationd_52
47735393HOUSEHOLD_2_516_WI_3_evaluationd_53
47735394HOUSEHOLD_2_516_WI_3_evaluationd_54
47735395HOUSEHOLD_2_516_WI_3_evaluationd_55
47735396HOUSEHOLD_2_516_WI_3_evaluationd_49
\n", - "

47735397 rows × 2 columns

\n", - "
" - ], - "text/plain": [ - " id day_id\n", - "0 FOODS_1_001_CA_1_evaluation d_1537\n", - "1 FOODS_1_001_CA_1_evaluation d_1538\n", - "2 FOODS_1_001_CA_1_evaluation d_1539\n", - "3 FOODS_1_001_CA_1_evaluation d_1540\n", - "4 FOODS_1_001_CA_1_evaluation d_1541\n", - "... ... ...\n", - "47735392 HOUSEHOLD_2_516_WI_3_evaluation d_52\n", - "47735393 HOUSEHOLD_2_516_WI_3_evaluation d_53\n", - "47735394 HOUSEHOLD_2_516_WI_3_evaluation d_54\n", - "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_55\n", - "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_49\n", - "\n", - "[47735397 rows x 2 columns]" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df = cudf.from_pandas(pd.read_pickle(processed_data_dir + \"grid_df_part1.pkl\"))\n", - "grid_df = grid_df[[\"id\", \"day_id\"]]\n", - "grid_df" - ] - }, - { - "cell_type": "markdown", - "id": "9fbca2d4-fea9-4bfe-88f9-3c6f0fd8dc50", - "metadata": {}, - "source": [ - "## Generate date-related features" - ] - }, - { - "cell_type": "markdown", - "id": "dff9fb9f-4f72-445a-8e49-0136c98481d1", - "metadata": {}, - "source": [ - "We first identify the date in each row of `grid_df` using information from `calendar_df`." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "ed28f3b3-59ef-4bc0-9e39-6639e6d3c7c7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
idday_iddateevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WI
0FOODS_1_001_TX_3_evaluationd_15222015-03-30<NA><NA><NA><NA>000
1FOODS_1_001_TX_3_evaluationd_15232015-03-31<NA><NA><NA><NA>000
2FOODS_1_001_TX_3_evaluationd_15242015-04-01<NA><NA><NA><NA>110
3FOODS_1_001_TX_3_evaluationd_15252015-04-02<NA><NA><NA><NA>101
4FOODS_1_001_TX_3_evaluationd_15262015-04-03<NA><NA><NA><NA>111
.................................
47735392HOUSEHOLD_2_516_WI_3_evaluationd_522011-03-21<NA><NA><NA><NA>000
47735393HOUSEHOLD_2_516_WI_3_evaluationd_532011-03-22<NA><NA><NA><NA>000
47735394HOUSEHOLD_2_516_WI_3_evaluationd_542011-03-23<NA><NA><NA><NA>000
47735395HOUSEHOLD_2_516_WI_3_evaluationd_552011-03-24<NA><NA><NA><NA>000
47735396HOUSEHOLD_2_516_WI_3_evaluationd_492011-03-18<NA><NA><NA><NA>000
\n", - "

47735397 rows × 10 columns

\n", - "
" - ], - "text/plain": [ - " id day_id date event_name_1 \\\n", - "0 FOODS_1_001_TX_3_evaluation d_1522 2015-03-30 \n", - "1 FOODS_1_001_TX_3_evaluation d_1523 2015-03-31 \n", - "2 FOODS_1_001_TX_3_evaluation d_1524 2015-04-01 \n", - "3 FOODS_1_001_TX_3_evaluation d_1525 2015-04-02 \n", - "4 FOODS_1_001_TX_3_evaluation d_1526 2015-04-03 \n", - "... ... ... ... ... \n", - "47735392 HOUSEHOLD_2_516_WI_3_evaluation d_52 2011-03-21 \n", - "47735393 HOUSEHOLD_2_516_WI_3_evaluation d_53 2011-03-22 \n", - "47735394 HOUSEHOLD_2_516_WI_3_evaluation d_54 2011-03-23 \n", - "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_55 2011-03-24 \n", - "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_49 2011-03-18 \n", - "\n", - " event_type_1 event_name_2 event_type_2 snap_CA snap_TX snap_WI \n", - "0 0 0 0 \n", - "1 0 0 0 \n", - "2 1 1 0 \n", - "3 1 0 1 \n", - "4 1 1 1 \n", - "... ... ... ... ... ... ... \n", - "47735392 0 0 0 \n", - "47735393 0 0 0 \n", - "47735394 0 0 0 \n", - "47735395 0 0 0 \n", - "47735396 0 0 0 \n", - "\n", - "[47735397 rows x 10 columns]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Bring in the following columns from calendar_df into grid_df\n", - "icols = [\n", - " \"date\",\n", - " \"day_id\",\n", - " \"event_name_1\",\n", - " \"event_type_1\",\n", - " \"event_name_2\",\n", - " \"event_type_2\",\n", - " \"snap_CA\",\n", - " \"snap_TX\",\n", - " \"snap_WI\",\n", - "]\n", - "grid_df = grid_df.merge(calendar_df[icols], on=[\"day_id\"], how=\"left\")\n", - "grid_df" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "dac7b398-05a9-48cc-bcb1-f53987c4d872", - "metadata": {}, - "outputs": [], - "source": [ - "# Convert columns into categorical type to save memory\n", - "for col in [\n", - " \"event_name_1\",\n", - " \"event_type_1\",\n", - " \"event_name_2\",\n", - " \"event_type_2\",\n", - " \"snap_CA\",\n", - " \"snap_TX\",\n", - " \"snap_WI\",\n", - "]:\n", - " grid_df[col] = 
grid_df[col].astype(\"category\")\n", - "# Convert \"date\" column into timestamp type\n", - "grid_df[\"date\"] = cudf.to_datetime(grid_df[\"date\"])" - ] - }, - { - "cell_type": "markdown", - "id": "9963b630-c88a-4cb9-a04e-7b0594fc0cca", - "metadata": {}, - "source": [ - "Using the `date` column, we can generate related features, such as day, week, or month." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "185fa77b-6cd9-47aa-a30c-1fe7e704fbb3", - "metadata": {}, - "outputs": [], - "source": [ - "grid_df[\"tm_d\"] = grid_df[\"date\"].dt.day.astype(np.int8)\n", - "grid_df[\"tm_w\"] = grid_df[\"date\"].dt.isocalendar().week.astype(np.int8)\n", - "grid_df[\"tm_m\"] = grid_df[\"date\"].dt.month.astype(np.int8)\n", - "grid_df[\"tm_y\"] = grid_df[\"date\"].dt.year\n", - "grid_df[\"tm_y\"] = (grid_df[\"tm_y\"] - grid_df[\"tm_y\"].min()).astype(np.int8)\n", - "grid_df[\"tm_wm\"] = cp.ceil(grid_df[\"tm_d\"].to_cupy() / 7).astype(\n", - " np.int8\n", - ") # which week in tje month?\n", - "grid_df[\"tm_dw\"] = grid_df[\"date\"].dt.dayofweek.astype(\n", - " np.int8\n", - ") # which day in the week?\n", - "grid_df[\"tm_w_end\"] = (grid_df[\"tm_dw\"] >= 5).astype(\n", - " np.int8\n", - ") # whether today is in the weekend\n", - "del grid_df[\"date\"] # no longer needed" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "6674ebe5-ef74-4e81-9ec1-41446707b3e3", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
idday_idevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WItm_dtm_wtm_mtm_ytm_wmtm_dwtm_w_end
0FOODS_1_001_TX_3_evaluationd_1522<NA><NA><NA><NA>000301434500
1FOODS_1_001_TX_3_evaluationd_1523<NA><NA><NA><NA>000311434510
2FOODS_1_001_TX_3_evaluationd_1524<NA><NA><NA><NA>11011444120
3FOODS_1_001_TX_3_evaluationd_1525<NA><NA><NA><NA>10121444130
4FOODS_1_001_TX_3_evaluationd_1526<NA><NA><NA><NA>11131444140
...................................................
47735392HOUSEHOLD_2_516_WI_3_evaluationd_52<NA><NA><NA><NA>000211230300
47735393HOUSEHOLD_2_516_WI_3_evaluationd_53<NA><NA><NA><NA>000221230410
47735394HOUSEHOLD_2_516_WI_3_evaluationd_54<NA><NA><NA><NA>000231230420
47735395HOUSEHOLD_2_516_WI_3_evaluationd_55<NA><NA><NA><NA>000241230430
47735396HOUSEHOLD_2_516_WI_3_evaluationd_49<NA><NA><NA><NA>000181130340
\n", - "

47735397 rows × 16 columns

\n", - "
" - ], - "text/plain": [ - " id day_id event_name_1 event_type_1 \\\n", - "0 FOODS_1_001_TX_3_evaluation d_1522 \n", - "1 FOODS_1_001_TX_3_evaluation d_1523 \n", - "2 FOODS_1_001_TX_3_evaluation d_1524 \n", - "3 FOODS_1_001_TX_3_evaluation d_1525 \n", - "4 FOODS_1_001_TX_3_evaluation d_1526 \n", - "... ... ... ... ... \n", - "47735392 HOUSEHOLD_2_516_WI_3_evaluation d_52 \n", - "47735393 HOUSEHOLD_2_516_WI_3_evaluation d_53 \n", - "47735394 HOUSEHOLD_2_516_WI_3_evaluation d_54 \n", - "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_55 \n", - "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_49 \n", - "\n", - " event_name_2 event_type_2 snap_CA snap_TX snap_WI tm_d tm_w tm_m \\\n", - "0 0 0 0 30 14 3 \n", - "1 0 0 0 31 14 3 \n", - "2 1 1 0 1 14 4 \n", - "3 1 0 1 2 14 4 \n", - "4 1 1 1 3 14 4 \n", - "... ... ... ... ... ... ... ... ... \n", - "47735392 0 0 0 21 12 3 \n", - "47735393 0 0 0 22 12 3 \n", - "47735394 0 0 0 23 12 3 \n", - "47735395 0 0 0 24 12 3 \n", - "47735396 0 0 0 18 11 3 \n", - "\n", - " tm_y tm_wm tm_dw tm_w_end \n", - "0 4 5 0 0 \n", - "1 4 5 1 0 \n", - "2 4 1 2 0 \n", - "3 4 1 3 0 \n", - "4 4 1 4 0 \n", - "... ... ... ... ... \n", - "47735392 0 3 0 0 \n", - "47735393 0 4 1 0 \n", - "47735394 0 4 2 0 \n", - "47735395 0 4 3 0 \n", - "47735396 0 3 4 0 \n", - "\n", - "[47735397 rows x 16 columns]" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df" - ] - }, - { - "cell_type": "markdown", - "id": "4dab0da0-77e1-4462-bae7-d8a8f6f35aa7", - "metadata": {}, - "source": [ - "Now we can persist the table to the disk." 
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "52065927-8a8b-4b3c-8fde-355b2961645f", - "metadata": {}, - "outputs": [], - "source": [ - "grid_df.to_pandas().to_pickle(processed_data_dir + \"grid_df_part3.pkl\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7b94402c-293f-4e8e-bddc-8416809da94f", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb deleted file mode 100644 index 5f101042..00000000 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part4.ipynb +++ /dev/null @@ -1,628 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "24f14a4c-4adf-4fe8-9f20-5a95f4a40fca", - "metadata": {}, - "source": [ - "# Data preprocesing, Part 4" - ] - }, - { - "cell_type": "markdown", - "id": "7c34d719-2223-4593-8eaf-95bad7b3d1d4", - "metadata": {}, - "source": [ - "## Import modules" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "5b6b0e92-e176-4b22-be2c-374ab2a1c037", - "metadata": {}, - "outputs": [], - "source": [ - "import cudf\n", - "import numpy as np\n", - "import pandas as pd\n", - "import gc" - ] - }, - { - "cell_type": "markdown", - "id": "509f32d0-9adb-4eda-9a16-02b37d1e2b8f", - "metadata": {}, - "source": [ - "## Load data" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "d013b288-07ff-4414-b5c4-985584158ed7", - "metadata": {}, - "outputs": [], - "source": [ 
- "raw_data_dir = \"./data/\"\n", - "processed_data_dir = \"./processed_data/\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "a8ec4415-1bca-4542-a28e-df810b1bc8bf", - "metadata": {}, - "outputs": [], - "source": [ - "grid_df = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part1.pkl\"))\n", - "grid_df = grid_df[[\"id\", \"day_id\", \"sales\"]]\n", - "SHIFT_DAY = 28" - ] - }, - { - "cell_type": "markdown", - "id": "7b17f489-a11c-4152-bedd-da22a5a08782", - "metadata": {}, - "source": [ - "## Generate lag features" - ] - }, - { - "cell_type": "markdown", - "id": "35c3ce2b-164c-464d-a3d1-c84311cac9ae", - "metadata": {}, - "source": [ - "**Lag features** are the value of the target variable at prior timestamps. Lag features are useful because what happens in the past often influences what would happen in the future. In our example, we generate lag features by reading the sales amount at X days prior, where X = 28, 29, ..., 42." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "7cb4b9a2-1332-41ef-8a17-ebe31037d712", - "metadata": {}, - "outputs": [], - "source": [ - "LAG_DAYS = [col for col in range(SHIFT_DAY, SHIFT_DAY + 15)]\n", - "\n", - "# Need to first ensure that rows in each time series are sorted by day_id\n", - "grid_df = grid_df.sort_values([\"id\", \"day_id\"])\n", - "\n", - "grid_df = grid_df.assign(\n", - " **{f\"sales_lag_{l}\": grid_df.groupby([\"id\"])[\"sales\"].shift(l) for l in LAG_DAYS}\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "92377656-ce8e-4101-88ba-260c14b9583b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
idday_idsalessales_lag_28sales_lag_29sales_lag_30sales_lag_31sales_lag_32sales_lag_33sales_lag_34sales_lag_35sales_lag_36sales_lag_37sales_lag_38sales_lag_39sales_lag_40sales_lag_41sales_lag_42
13225FOODS_1_001_CA_1_evaluationd_13.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
13226FOODS_1_001_CA_1_evaluationd_20.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
13227FOODS_1_001_CA_1_evaluationd_30.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
13228FOODS_1_001_CA_1_evaluationd_41.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
13229FOODS_1_001_CA_1_evaluationd_54.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
.........................................................
47734672HOUSEHOLD_2_516_WI_3_evaluationd_1965<NA>0.00.00.00.01.00.00.00.00.00.00.00.00.00.00.0
47734673HOUSEHOLD_2_516_WI_3_evaluationd_1966<NA>0.00.00.00.00.01.00.00.00.00.00.00.00.00.00.0
47734674HOUSEHOLD_2_516_WI_3_evaluationd_1967<NA>0.00.00.00.00.00.01.00.00.00.00.00.00.00.00.0
47734675HOUSEHOLD_2_516_WI_3_evaluationd_1968<NA>0.00.00.00.00.00.00.01.00.00.00.00.00.00.00.0
47734676HOUSEHOLD_2_516_WI_3_evaluationd_1969<NA>0.00.00.00.00.00.00.00.01.00.00.00.00.00.00.0
\n", - "

47735397 rows × 18 columns

\n", - "
" - ], - "text/plain": [ - " id day_id sales sales_lag_28 \\\n", - "13225 FOODS_1_001_CA_1_evaluation d_1 3.0 \n", - "13226 FOODS_1_001_CA_1_evaluation d_2 0.0 \n", - "13227 FOODS_1_001_CA_1_evaluation d_3 0.0 \n", - "13228 FOODS_1_001_CA_1_evaluation d_4 1.0 \n", - "13229 FOODS_1_001_CA_1_evaluation d_5 4.0 \n", - "... ... ... ... ... \n", - "47734672 HOUSEHOLD_2_516_WI_3_evaluation d_1965 0.0 \n", - "47734673 HOUSEHOLD_2_516_WI_3_evaluation d_1966 0.0 \n", - "47734674 HOUSEHOLD_2_516_WI_3_evaluation d_1967 0.0 \n", - "47734675 HOUSEHOLD_2_516_WI_3_evaluation d_1968 0.0 \n", - "47734676 HOUSEHOLD_2_516_WI_3_evaluation d_1969 0.0 \n", - "\n", - " sales_lag_29 sales_lag_30 sales_lag_31 sales_lag_32 sales_lag_33 \\\n", - "13225 \n", - "13226 \n", - "13227 \n", - "13228 \n", - "13229 \n", - "... ... ... ... ... ... \n", - "47734672 0.0 0.0 0.0 1.0 0.0 \n", - "47734673 0.0 0.0 0.0 0.0 1.0 \n", - "47734674 0.0 0.0 0.0 0.0 0.0 \n", - "47734675 0.0 0.0 0.0 0.0 0.0 \n", - "47734676 0.0 0.0 0.0 0.0 0.0 \n", - "\n", - " sales_lag_34 sales_lag_35 sales_lag_36 sales_lag_37 sales_lag_38 \\\n", - "13225 \n", - "13226 \n", - "13227 \n", - "13228 \n", - "13229 \n", - "... ... ... ... ... ... \n", - "47734672 0.0 0.0 0.0 0.0 0.0 \n", - "47734673 0.0 0.0 0.0 0.0 0.0 \n", - "47734674 1.0 0.0 0.0 0.0 0.0 \n", - "47734675 0.0 1.0 0.0 0.0 0.0 \n", - "47734676 0.0 0.0 1.0 0.0 0.0 \n", - "\n", - " sales_lag_39 sales_lag_40 sales_lag_41 sales_lag_42 \n", - "13225 \n", - "13226 \n", - "13227 \n", - "13228 \n", - "13229 \n", - "... ... ... ... ... 
\n", - "47734672 0.0 0.0 0.0 0.0 \n", - "47734673 0.0 0.0 0.0 0.0 \n", - "47734674 0.0 0.0 0.0 0.0 \n", - "47734675 0.0 0.0 0.0 0.0 \n", - "47734676 0.0 0.0 0.0 0.0 \n", - "\n", - "[47735397 rows x 18 columns]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df" - ] - }, - { - "cell_type": "markdown", - "id": "2778fa50-d38b-4eb9-9ce4-a52a6455ac91", - "metadata": {}, - "source": [ - "## Compute rolling window statistics" - ] - }, - { - "cell_type": "markdown", - "id": "aabcb64a-2a43-4f8f-a1f8-5425c9fec2d3", - "metadata": {}, - "source": [ - "In the previous cell, we used the value of sales at a single timestamp to generate lag features. To capture richer information about the past, let us also get the distribution of the sales value over multiple timestamps, by computing **rolling window statistics**. Rolling window statistics are statistics (e.g. mean, standard deviation) over a time duration in the past. Rolling windows statistics complement lag features and provide more information about the past behavior of the target variable.\n", - "\n", - "Read more about lag features and rolling window statistics in [Introduction to feature engineering for time series forecasting](https://medium.com/data-science-at-microsoft/introduction-to-feature-engineering-for-time-series-forecasting-620aa55fcab0)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "9cd8023a-4727-4365-8bc5-693d6dcd4979", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Shift size: 28\n", - " Window size: 7\n", - " Window size: 14\n", - " Window size: 30\n", - " Window size: 60\n", - " Window size: 180\n" - ] - } - ], - "source": [ - "# Shift by 28 days and apply windows of various sizes\n", - "print(f\"Shift size: {SHIFT_DAY}\")\n", - "for i in [7, 14, 30, 60, 180]:\n", - " print(f\" Window size: {i}\")\n", - " grid_df[f\"rolling_mean_{i}\"] = (\n", - " grid_df.groupby([\"id\"])[\"sales\"]\n", - " .shift(SHIFT_DAY)\n", - " .rolling(i)\n", - " .mean()\n", - " .astype(np.float32)\n", - " )\n", - " grid_df[f\"rolling_std_{i}\"] = (\n", - " grid_df.groupby([\"id\"])[\"sales\"]\n", - " .shift(SHIFT_DAY)\n", - " .rolling(i)\n", - " .std()\n", - " .astype(np.float32)\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "6b4197ca-23a4-4b17-981f-1d35b5f6b671", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Index(['id', 'day_id', 'sales', 'sales_lag_28', 'sales_lag_29', 'sales_lag_30',\n", - " 'sales_lag_31', 'sales_lag_32', 'sales_lag_33', 'sales_lag_34',\n", - " 'sales_lag_35', 'sales_lag_36', 'sales_lag_37', 'sales_lag_38',\n", - " 'sales_lag_39', 'sales_lag_40', 'sales_lag_41', 'sales_lag_42',\n", - " 'rolling_mean_7', 'rolling_std_7', 'rolling_mean_14', 'rolling_std_14',\n", - " 'rolling_mean_30', 'rolling_std_30', 'rolling_mean_60',\n", - " 'rolling_std_60', 'rolling_mean_180', 'rolling_std_180'],\n", - " dtype='object')" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df.columns" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "f6b439b6-3b44-44f4-84bd-4c99edc59448", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "id category\n", - "day_id category\n", - "sales float32\n", - 
"sales_lag_28 float32\n", - "sales_lag_29 float32\n", - "sales_lag_30 float32\n", - "sales_lag_31 float32\n", - "sales_lag_32 float32\n", - "sales_lag_33 float32\n", - "sales_lag_34 float32\n", - "sales_lag_35 float32\n", - "sales_lag_36 float32\n", - "sales_lag_37 float32\n", - "sales_lag_38 float32\n", - "sales_lag_39 float32\n", - "sales_lag_40 float32\n", - "sales_lag_41 float32\n", - "sales_lag_42 float32\n", - "rolling_mean_7 float32\n", - "rolling_std_7 float32\n", - "rolling_mean_14 float32\n", - "rolling_std_14 float32\n", - "rolling_mean_30 float32\n", - "rolling_std_30 float32\n", - "rolling_mean_60 float32\n", - "rolling_std_60 float32\n", - "rolling_mean_180 float32\n", - "rolling_std_180 float32\n", - "dtype: object" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df.dtypes" - ] - }, - { - "cell_type": "markdown", - "id": "79a2ae07-fa57-4579-9443-0cc2bb15a922", - "metadata": {}, - "source": [ - "Once lag features and rolling window statistics are computed, persist them to the disk." 
- ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "a6cd2826-0d33-42b8-8616-227a717639d0", - "metadata": {}, - "outputs": [], - "source": [ - "grid_df.to_pandas().to_pickle(processed_data_dir + \"lags_df_28.pkl\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f4979567-1c59-47cf-83c2-21d61d06b701", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb deleted file mode 100644 index 4cdb522e..00000000 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part5.ipynb +++ /dev/null @@ -1,596 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "0929302e-1b8d-49ed-b4ba-be26f7579ec0", - "metadata": {}, - "source": [ - "# Data preprocesing, Part 5" - ] - }, - { - "cell_type": "markdown", - "id": "5ccecbb0-e81e-41e8-bb71-539d3e02bb73", - "metadata": {}, - "source": [ - "## Import modules" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "c6032bfd-3dd7-4d5c-a043-51945b068826", - "metadata": {}, - "outputs": [], - "source": [ - "import cudf\n", - "import numpy as np\n", - "import pandas as pd\n", - "import gc" - ] - }, - { - "cell_type": "markdown", - "id": "a14ddfe9-82b5-41ab-8612-a3e626e1792c", - "metadata": {}, - "source": [ - "## Load data" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "a4c537de-a406-42d4-8ad8-b2f9aad6bb80", - "metadata": {}, - "outputs": [], - "source": [ - 
"raw_data_dir = \"./data/\"\n", - "processed_data_dir = \"./processed_data/\"" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "882bfc6a-24f5-4690-a5ec-543f578a0843", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
0FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15371.01110111511
1FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15380.01110111511
2FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15392.01110111511
3FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15400.01110111511
4FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_15410.01110111512
.................................
47735392HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_520.01110111108
47735393HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_530.01110111108
47735394HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_540.01110111108
47735395HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_550.01110111108
47735396HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_490.01110111107
\n", - "

47735397 rows × 10 columns

\n", - "
" - ], - "text/plain": [ - " id item_id dept_id \\\n", - "0 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "1 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "2 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "3 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "4 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", - "... ... ... ... \n", - "47735392 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "47735393 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "47735394 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "47735395 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "47735396 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", - "\n", - " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", - "0 FOODS CA_1 CA d_1537 1.0 11101 11511 \n", - "1 FOODS CA_1 CA d_1538 0.0 11101 11511 \n", - "2 FOODS CA_1 CA d_1539 2.0 11101 11511 \n", - "3 FOODS CA_1 CA d_1540 0.0 11101 11511 \n", - "4 FOODS CA_1 CA d_1541 0.0 11101 11512 \n", - "... ... ... ... ... ... ... ... \n", - "47735392 HOUSEHOLD WI_3 WI d_52 0.0 11101 11108 \n", - "47735393 HOUSEHOLD WI_3 WI d_53 0.0 11101 11108 \n", - "47735394 HOUSEHOLD WI_3 WI d_54 0.0 11101 11108 \n", - "47735395 HOUSEHOLD WI_3 WI d_55 0.0 11101 11108 \n", - "47735396 HOUSEHOLD WI_3 WI d_49 0.0 11101 11107 \n", - "\n", - "[47735397 rows x 10 columns]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part1.pkl\"))\n", - "grid_df" - ] - }, - { - "cell_type": "markdown", - "id": "f338a3cd-2504-435a-8ac7-7fb5f71cea76", - "metadata": {}, - "source": [ - "## Target encoding" - ] - }, - { - "cell_type": "markdown", - "id": "0db75d1c-dcf6-4404-a4fe-b954d3388add", - "metadata": {}, - "source": [ - "Categorical variables present challenges to many machine learning algorithms such as XGBoost. 
One way to overcome the challenge is to use **target encoding**, where we encode categorical variables by replacing them with a statistic for the target variable. In this example, we will use the mean and the standard deviation.\n", - "\n", - "Read more about target encoding in [Target-encoding Categorical Variables](https://towardsdatascience.com/dealing-with-categorical-variables-by-using-target-encoder-a0f1733a4c69)." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "4dba104c-8d86-47c5-a044-eab5796a6f2e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Encoding columns ['store_id', 'dept_id']\n", - "Encoding columns ['item_id', 'state_id']\n" - ] - } - ], - "source": [ - "icols = [[\"store_id\", \"dept_id\"], [\"item_id\", \"state_id\"]]\n", - "new_columns = []\n", - "\n", - "for col in icols:\n", - " print(f\"Encoding columns {col}\")\n", - " col_name = \"_\" + \"_\".join(col) + \"_\"\n", - " grid_df[\"enc\" + col_name + \"mean\"] = (\n", - " grid_df.groupby(col)[\"sales\"].transform(\"mean\").astype(np.float32)\n", - " )\n", - " grid_df[\"enc\" + col_name + \"std\"] = (\n", - " grid_df.groupby(col)[\"sales\"].transform(\"std\").astype(np.float32)\n", - " )\n", - " new_columns.extend([\"enc\" + col_name + \"mean\", \"enc\" + col_name + \"std\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "0cd78500-2c68-449b-8c8c-93eb57a860f5", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
idday_idenc_store_id_dept_id_meanenc_store_id_dept_id_stdenc_item_id_state_id_meanenc_item_id_state_id_std
0FOODS_1_001_CA_1_evaluationd_15371.6131123.2166720.8733901.666305
1FOODS_1_001_CA_1_evaluationd_15381.6131123.2166720.8733901.666305
2FOODS_1_001_CA_1_evaluationd_15391.6131123.2166720.8733901.666305
3FOODS_1_001_CA_1_evaluationd_15401.6131123.2166720.8733901.666305
4FOODS_1_001_CA_1_evaluationd_15411.6131123.2166720.8733901.666305
.....................
47735392HOUSEHOLD_2_516_WI_3_evaluationd_520.2614860.6663800.0832760.301445
47735393HOUSEHOLD_2_516_WI_3_evaluationd_530.2614860.6663800.0832760.301445
47735394HOUSEHOLD_2_516_WI_3_evaluationd_540.2614860.6663800.0832760.301445
47735395HOUSEHOLD_2_516_WI_3_evaluationd_550.2614860.6663800.0832760.301445
47735396HOUSEHOLD_2_516_WI_3_evaluationd_490.2614860.6663800.0832760.301445
\n", - "

47735397 rows × 6 columns

\n", - "
" - ], - "text/plain": [ - " id day_id enc_store_id_dept_id_mean \\\n", - "0 FOODS_1_001_CA_1_evaluation d_1537 1.613112 \n", - "1 FOODS_1_001_CA_1_evaluation d_1538 1.613112 \n", - "2 FOODS_1_001_CA_1_evaluation d_1539 1.613112 \n", - "3 FOODS_1_001_CA_1_evaluation d_1540 1.613112 \n", - "4 FOODS_1_001_CA_1_evaluation d_1541 1.613112 \n", - "... ... ... ... \n", - "47735392 HOUSEHOLD_2_516_WI_3_evaluation d_52 0.261486 \n", - "47735393 HOUSEHOLD_2_516_WI_3_evaluation d_53 0.261486 \n", - "47735394 HOUSEHOLD_2_516_WI_3_evaluation d_54 0.261486 \n", - "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_55 0.261486 \n", - "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_49 0.261486 \n", - "\n", - " enc_store_id_dept_id_std enc_item_id_state_id_mean \\\n", - "0 3.216672 0.873390 \n", - "1 3.216672 0.873390 \n", - "2 3.216672 0.873390 \n", - "3 3.216672 0.873390 \n", - "4 3.216672 0.873390 \n", - "... ... ... \n", - "47735392 0.666380 0.083276 \n", - "47735393 0.666380 0.083276 \n", - "47735394 0.666380 0.083276 \n", - "47735395 0.666380 0.083276 \n", - "47735396 0.666380 0.083276 \n", - "\n", - " enc_item_id_state_id_std \n", - "0 1.666305 \n", - "1 1.666305 \n", - "2 1.666305 \n", - "3 1.666305 \n", - "4 1.666305 \n", - "... ... 
\n", - "47735392 0.301445 \n", - "47735393 0.301445 \n", - "47735394 0.301445 \n", - "47735395 0.301445 \n", - "47735396 0.301445 \n", - "\n", - "[47735397 rows x 6 columns]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df = grid_df[[\"id\", \"day_id\"] + new_columns]\n", - "grid_df" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "44fc899c-9567-4d70-a5d8-cd37648900b3", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "id category\n", - "day_id category\n", - "enc_store_id_dept_id_mean float32\n", - "enc_store_id_dept_id_std float32\n", - "enc_item_id_state_id_mean float32\n", - "enc_item_id_state_id_std float32\n", - "dtype: object" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "grid_df.dtypes" - ] - }, - { - "cell_type": "markdown", - "id": "8ebcfb76-d807-4fb1-93ab-e30b2e38c173", - "metadata": {}, - "source": [ - "Once we computed the target encoding, we persist the table to the disk." 
- ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "c5cfa397-a31f-4f28-aecd-255bf94ea90b", - "metadata": {}, - "outputs": [], - "source": [ - "grid_df.to_pandas().to_pickle(processed_data_dir + \"target_encoding_df.pkl\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "438484e3-deb4-437f-89cd-7603c1f8b9f7", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb b/source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb deleted file mode 100644 index 149c77ca..00000000 --- a/source/examples/time-series-forecasting-with-hpo/preprocessing_part6.ipynb +++ /dev/null @@ -1,511 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "a7f7926c-e8f7-41f0-98a4-1211da546adc", - "metadata": {}, - "source": [ - "# Data preprocesing, Part 6" - ] - }, - { - "cell_type": "markdown", - "id": "d21c08f5-07b8-4f01-b584-d0c6d2bdc86e", - "metadata": {}, - "source": [ - "## Import modules" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "594c2c7f-3fff-4cba-86e7-8f3031ea21d9", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import cudf\n", - "import numpy as np\n", - "import pandas as pd\n", - "import gc\n", - "import glob\n", - "import pathlib\n", - "import gcsfs" - ] - }, - { - "cell_type": "markdown", - "id": "a33cb0b4-9e25-4965-8313-42707160e4fd", - "metadata": {}, - "source": [ - "Enter the name of the Cloud Storage bucket you used in `start_here.ipynb`." 
- ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "9fc13697-f9f2-4931-8fee-85148c600d89", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "bucket_name = \"\"" - ] - }, - { - "cell_type": "markdown", - "id": "23de1e65-8e82-4339-baa6-696abd247f22", - "metadata": {}, - "source": [ - "## Filter by store and product department and create data segments" - ] - }, - { - "cell_type": "markdown", - "id": "bfe723f6-7c62-4de5-99c6-c3a31253be61", - "metadata": {}, - "source": [ - "After combining all columns produced in the previous notebooks, we filter the rows in the data set by `store_id` and `dept_id` and create a segment. Each segment is saved as a pickle file and then upload to Cloud Storage." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "dc332147-56fb-4f11-95f1-12102aa6f1cf", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "processed_data_dir = \"./processed_data/\"\n", - "segmented_data_dir = \"./segmented_data/\"\n", - "pathlib.Path(segmented_data_dir).mkdir(exist_ok=True)\n", - "\n", - "STORES = [\n", - " \"CA_1\",\n", - " \"CA_2\",\n", - " \"CA_3\",\n", - " \"CA_4\",\n", - " \"TX_1\",\n", - " \"TX_2\",\n", - " \"TX_3\",\n", - " \"WI_1\",\n", - " \"WI_2\",\n", - " \"WI_3\",\n", - "]\n", - "DEPTS = [\n", - " \"HOBBIES_1\",\n", - " \"HOBBIES_2\",\n", - " \"HOUSEHOLD_1\",\n", - " \"HOUSEHOLD_2\",\n", - " \"FOODS_1\",\n", - " \"FOODS_2\",\n", - " \"FOODS_3\",\n", - "]\n", - "\n", - "grid2_colnm = [\n", - " \"sell_price\",\n", - " \"price_max\",\n", - " \"price_min\",\n", - " \"price_std\",\n", - " \"price_mean\",\n", - " \"price_norm\",\n", - " \"price_nunique\",\n", - " \"item_nunique\",\n", - " \"price_momentum\",\n", - " \"price_momentum_m\",\n", - " \"price_momentum_y\",\n", - "]\n", - "\n", - "grid3_colnm = [\n", - " \"event_name_1\",\n", - " \"event_type_1\",\n", - " \"event_name_2\",\n", - " \"event_type_2\",\n", - " \"snap_CA\",\n", - " \"snap_TX\",\n", - " \"snap_WI\",\n", - " 
\"tm_d\",\n", - " \"tm_w\",\n", - " \"tm_m\",\n", - " \"tm_y\",\n", - " \"tm_wm\",\n", - " \"tm_dw\",\n", - " \"tm_w_end\",\n", - "]\n", - "\n", - "lag_colnm = [\n", - " \"sales_lag_28\",\n", - " \"sales_lag_29\",\n", - " \"sales_lag_30\",\n", - " \"sales_lag_31\",\n", - " \"sales_lag_32\",\n", - " \"sales_lag_33\",\n", - " \"sales_lag_34\",\n", - " \"sales_lag_35\",\n", - " \"sales_lag_36\",\n", - " \"sales_lag_37\",\n", - " \"sales_lag_38\",\n", - " \"sales_lag_39\",\n", - " \"sales_lag_40\",\n", - " \"sales_lag_41\",\n", - " \"sales_lag_42\",\n", - " \"rolling_mean_7\",\n", - " \"rolling_std_7\",\n", - " \"rolling_mean_14\",\n", - " \"rolling_std_14\",\n", - " \"rolling_mean_30\",\n", - " \"rolling_std_30\",\n", - " \"rolling_mean_60\",\n", - " \"rolling_std_60\",\n", - " \"rolling_mean_180\",\n", - " \"rolling_std_180\",\n", - "]\n", - "\n", - "target_enc_colnm = [\n", - " \"enc_store_id_dept_id_mean\",\n", - " \"enc_store_id_dept_id_std\",\n", - " \"enc_item_id_state_id_mean\",\n", - " \"enc_item_id_state_id_std\",\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "36fbfd16-83e6-42ab-a337-5dd9a009cd7a", - "metadata": {}, - "outputs": [], - "source": [ - "def prepare_data(store, dept=None):\n", - " \"\"\"\n", - " Filter and clean data according to stores and product departments\n", - "\n", - " Parameters\n", - " ----------\n", - " store: Filter data by retaining rows whose store_id matches this parameter.\n", - " dept: Filter data by retaining rows whose dept_id matches this parameter.\n", - " This parameter can be set to None to indicate that we shouldn't filter by dept_id.\n", - " \"\"\"\n", - " if store is None:\n", - " raise ValueError(f\"store parameter must not be None\")\n", - "\n", - " grid1 = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part1.pkl\"))\n", - "\n", - " if dept is None:\n", - " grid1 = grid1[grid1[\"store_id\"] == store]\n", - " else:\n", - " grid1 = grid1[(grid1[\"store_id\"] == store) & 
(grid1[\"dept_id\"] == dept)].drop(\n", - " columns=[\"dept_id\"]\n", - " )\n", - " grid1 = grid1.drop(columns=[\"release_week\", \"wm_yr_wk\", \"store_id\", \"state_id\"])\n", - "\n", - " grid2 = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part2.pkl\"))[\n", - " [\"id\", \"day_id\"] + grid2_colnm\n", - " ]\n", - " grid_df = grid1.merge(grid2, on=[\"id\", \"day_id\"], how=\"left\")\n", - " del grid1, grid2\n", - "\n", - " grid3 = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"grid_df_part3.pkl\"))[\n", - " [\"id\", \"day_id\"] + grid3_colnm\n", - " ]\n", - " grid_df = grid_df.merge(grid3, on=[\"id\", \"day_id\"], how=\"left\")\n", - " del grid3\n", - "\n", - " lag_df = cudf.DataFrame(pd.read_pickle(processed_data_dir + \"lags_df_28.pkl\"))[\n", - " [\"id\", \"day_id\"] + lag_colnm\n", - " ]\n", - "\n", - " grid_df = grid_df.merge(lag_df, on=[\"id\", \"day_id\"], how=\"left\")\n", - " del lag_df\n", - "\n", - " target_enc_df = cudf.DataFrame(\n", - " pd.read_pickle(processed_data_dir + \"target_encoding_df.pkl\")\n", - " )[[\"id\", \"day_id\"] + target_enc_colnm]\n", - "\n", - " grid_df = grid_df.merge(target_enc_df, on=[\"id\", \"day_id\"], how=\"left\")\n", - " del target_enc_df\n", - " gc.collect()\n", - "\n", - " grid_df = grid_df.drop(columns=[\"id\"])\n", - " grid_df[\"day_id\"] = (\n", - " grid_df[\"day_id\"]\n", - " .to_pandas()\n", - " .astype(\"str\")\n", - " .apply(lambda x: x[2:])\n", - " .astype(np.int16)\n", - " )\n", - "\n", - " return grid_df" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "9e61b762-2722-4b83-9220-326006880acd", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processing store CA_1...\n", - "Processing store CA_2...\n", - "Processing store CA_3...\n", - "Processing store CA_4...\n", - "Processing store TX_1...\n", - "Processing store TX_2...\n", - "Processing store TX_3...\n", - "Processing store 
WI_1...\n", - "Processing store WI_2...\n", - "Processing store WI_3...\n", - "Processing (store CA_1, department HOBBIES_1)...\n", - "Processing (store CA_1, department HOBBIES_2)...\n", - "Processing (store CA_1, department HOUSEHOLD_1)...\n", - "Processing (store CA_1, department HOUSEHOLD_2)...\n", - "Processing (store CA_1, department FOODS_1)...\n", - "Processing (store CA_1, department FOODS_2)...\n", - "Processing (store CA_1, department FOODS_3)...\n", - "Processing (store CA_2, department HOBBIES_1)...\n", - "Processing (store CA_2, department HOBBIES_2)...\n", - "Processing (store CA_2, department HOUSEHOLD_1)...\n", - "Processing (store CA_2, department HOUSEHOLD_2)...\n", - "Processing (store CA_2, department FOODS_1)...\n", - "Processing (store CA_2, department FOODS_2)...\n", - "Processing (store CA_2, department FOODS_3)...\n", - "Processing (store CA_3, department HOBBIES_1)...\n", - "Processing (store CA_3, department HOBBIES_2)...\n", - "Processing (store CA_3, department HOUSEHOLD_1)...\n", - "Processing (store CA_3, department HOUSEHOLD_2)...\n", - "Processing (store CA_3, department FOODS_1)...\n", - "Processing (store CA_3, department FOODS_2)...\n", - "Processing (store CA_3, department FOODS_3)...\n", - "Processing (store CA_4, department HOBBIES_1)...\n", - "Processing (store CA_4, department HOBBIES_2)...\n", - "Processing (store CA_4, department HOUSEHOLD_1)...\n", - "Processing (store CA_4, department HOUSEHOLD_2)...\n", - "Processing (store CA_4, department FOODS_1)...\n", - "Processing (store CA_4, department FOODS_2)...\n", - "Processing (store CA_4, department FOODS_3)...\n", - "Processing (store TX_1, department HOBBIES_1)...\n", - "Processing (store TX_1, department HOBBIES_2)...\n", - "Processing (store TX_1, department HOUSEHOLD_1)...\n", - "Processing (store TX_1, department HOUSEHOLD_2)...\n", - "Processing (store TX_1, department FOODS_1)...\n", - "Processing (store TX_1, department FOODS_2)...\n", - "Processing (store TX_1, 
department FOODS_3)...\n", - "Processing (store TX_2, department HOBBIES_1)...\n", - "Processing (store TX_2, department HOBBIES_2)...\n", - "Processing (store TX_2, department HOUSEHOLD_1)...\n", - "Processing (store TX_2, department HOUSEHOLD_2)...\n", - "Processing (store TX_2, department FOODS_1)...\n", - "Processing (store TX_2, department FOODS_2)...\n", - "Processing (store TX_2, department FOODS_3)...\n", - "Processing (store TX_3, department HOBBIES_1)...\n", - "Processing (store TX_3, department HOBBIES_2)...\n", - "Processing (store TX_3, department HOUSEHOLD_1)...\n", - "Processing (store TX_3, department HOUSEHOLD_2)...\n", - "Processing (store TX_3, department FOODS_1)...\n", - "Processing (store TX_3, department FOODS_2)...\n", - "Processing (store TX_3, department FOODS_3)...\n", - "Processing (store WI_1, department HOBBIES_1)...\n", - "Processing (store WI_1, department HOBBIES_2)...\n", - "Processing (store WI_1, department HOUSEHOLD_1)...\n", - "Processing (store WI_1, department HOUSEHOLD_2)...\n", - "Processing (store WI_1, department FOODS_1)...\n", - "Processing (store WI_1, department FOODS_2)...\n", - "Processing (store WI_1, department FOODS_3)...\n", - "Processing (store WI_2, department HOBBIES_1)...\n", - "Processing (store WI_2, department HOBBIES_2)...\n", - "Processing (store WI_2, department HOUSEHOLD_1)...\n", - "Processing (store WI_2, department HOUSEHOLD_2)...\n", - "Processing (store WI_2, department FOODS_1)...\n", - "Processing (store WI_2, department FOODS_2)...\n", - "Processing (store WI_2, department FOODS_3)...\n", - "Processing (store WI_3, department HOBBIES_1)...\n", - "Processing (store WI_3, department HOBBIES_2)...\n", - "Processing (store WI_3, department HOUSEHOLD_1)...\n", - "Processing (store WI_3, department HOUSEHOLD_2)...\n", - "Processing (store WI_3, department FOODS_1)...\n", - "Processing (store WI_3, department FOODS_2)...\n", - "Processing (store WI_3, department FOODS_3)...\n" - ] - } - ], - 
"source": [ - "# First save the segment to the disk\n", - "for store in STORES:\n", - " print(f\"Processing store {store}...\")\n", - " grid_df = prepare_data(store=store)\n", - " grid_df.to_pandas().to_pickle(segmented_data_dir + f\"combined_df_store_{store}.pkl\")\n", - " del grid_df\n", - " gc.collect()\n", - "\n", - "for store in STORES:\n", - " for dept in DEPTS:\n", - " print(f\"Processing (store {store}, department {dept})...\")\n", - " grid_df = prepare_data(store=store, dept=dept)\n", - " grid_df.to_pandas().to_pickle(\n", - " segmented_data_dir + f\"combined_df_store_{store}_dept_{dept}.pkl\"\n", - " )\n", - " del grid_df\n", - " gc.collect()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "6e12341b-c879-4012-8c73-8316278b9a6f", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Uploading ./segmented_data/combined_df_store_WI_1_dept_FOODS_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_2_dept_FOODS_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_2_dept_HOBBIES_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_2_dept_HOBBIES_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_1_dept_FOODS_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_1_dept_FOODS_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_3_dept_HOBBIES_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_3_dept_HOUSEHOLD_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_2_dept_HOBBIES_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_4_dept_HOBBIES_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_3_dept_FOODS_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_1_dept_HOBBIES_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_1_dept_HOUSEHOLD_2.pkl...\n", - "Uploading 
./segmented_data/combined_df_store_CA_4_dept_FOODS_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_1_dept_HOUSEHOLD_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_3_dept_HOUSEHOLD_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_3_dept_FOODS_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_1_dept_HOBBIES_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_1_dept_FOODS_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_4.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_1_dept_HOBBIES_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_2_dept_HOUSEHOLD_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_3_dept_FOODS_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_1_dept_FOODS_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_1_dept_HOBBIES_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_2_dept_HOBBIES_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_2_dept_FOODS_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_3_dept_HOUSEHOLD_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_3_dept_HOBBIES_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_2_dept_FOODS_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_2_dept_HOUSEHOLD_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_3_dept_HOBBIES_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_4_dept_FOODS_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_4_dept_FOODS_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_3_dept_HOBBIES_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_1_dept_HOUSEHOLD_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_2.pkl...\n", - "Uploading 
./segmented_data/combined_df_store_CA_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_3_dept_HOUSEHOLD_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_4_dept_HOBBIES_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_1_dept_HOUSEHOLD_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_3_dept_FOODS_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_3_dept_HOUSEHOLD_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_2_dept_HOUSEHOLD_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_2_dept_HOBBIES_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_1_dept_HOUSEHOLD_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_1_dept_FOODS_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_3_dept_FOODS_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_1_dept_HOUSEHOLD_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_2_dept_HOUSEHOLD_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_4_dept_HOUSEHOLD_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_1_dept_FOODS_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_2_dept_HOUSEHOLD_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_2_dept_HOBBIES_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_1_dept_FOODS_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_3_dept_FOODS_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_2_dept_FOODS_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_2_dept_FOODS_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_2_dept_HOUSEHOLD_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_2_dept_FOODS_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_2_dept_FOODS_1.pkl...\n", - 
"Uploading ./segmented_data/combined_df_store_TX_1_dept_HOBBIES_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_3_dept_HOBBIES_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_3_dept_FOODS_3.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_4_dept_HOUSEHOLD_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_CA_3_dept_FOODS_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_3_dept_HOUSEHOLD_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_2_dept_FOODS_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_1_dept_FOODS_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_2_dept_FOODS_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_TX_3_dept_HOBBIES_2.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_1_dept_HOBBIES_1.pkl...\n", - "Uploading ./segmented_data/combined_df_store_WI_3_dept_FOODS_3.pkl...\n" - ] - } - ], - "source": [ - "# Then copy the segment to Cloud Storage\n", - "fs = gcsfs.GCSFileSystem()\n", - "\n", - "for e in glob.glob(segmented_data_dir + \"*\"):\n", - " print(f\"Uploading {e}...\")\n", - " basename = pathlib.Path(e).name\n", - " fs.put_file(e, f\"{bucket_name}/{basename}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "58eb85e7-3d1d-4f72-8c9a-57b672344d73", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Also upload the product weights\n", - "fs = gcsfs.GCSFileSystem()\n", - "fs.put_file(\n", - " processed_data_dir + \"product_weights.pkl\", f\"{bucket_name}/product_weights.pkl\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b1cf03ee-3e08-429d-bff0-8023ee6de2af", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": 
"python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/source/examples/time-series-forecasting-with-hpo/start_here.ipynb b/source/examples/time-series-forecasting-with-hpo/start_here.ipynb index a3f23b8a..6afba35b 100644 --- a/source/examples/time-series-forecasting-with-hpo/start_here.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/start_here.ipynb @@ -39,7 +39,12 @@ "source": [ "To run the example, you will need a working Google Kubernetes Engine (GKE) cluster with access to NVIDIA GPUs. Use the following resources to set up a cluster:\n", "\n", - "* [Set up a GKE cluster with access to NVIDIA GPUs](https://docs.rapids.ai/deployment/stable/cloud/gcp/gke/)\n", + "````{docref} /cloud/gcp/gke\n", + "Set up a Google Kubernetes Engine (GKE) cluster with access to NVIDIA GPUs. Follow instructions in [Google Kubernetes Engine](../../cloud/gcp/gke).\n", + "````\n", + "\n", + "To ensure that the example runs smoothly, ensure that you have ample memory in your GPUs. This notebook has been tested with NVIDIA A100.\n", + "\n", "* [Install the Dask-Kubernetes operator](https://kubernetes.dask.org/en/latest/operator_installation.html)\n", "* [Install Kubeflow](https://www.kubeflow.org/docs/started/installing-kubeflow/)\n", "\n", @@ -49,7 +54,7 @@ "* 1 NVIDIA GPU\n", "* 40 GiB disk volume\n", "\n", - "After uploading all the notebooks in the example, run this notebook (`start_here.ipynb`) in the notebook environment.\n", + "After uploading all the notebooks in the example, run this notebook (`time_series_forecasting.ipynb`) in the notebook environment.\n", "\n", "Note: We will use the worker pods to speed up the training stage. 
The preprocessing steps will run solely on the scheduler node." ] @@ -110,7 +115,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "cdf7b111-3aba-4fae-b805-fb3063d5a621", "metadata": { "tags": [] @@ -122,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "13737fe2-3df5-4614-8675-fb20bccf7a19", "metadata": { "tags": [] @@ -134,7 +139,7 @@ "[]" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -165,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "8c2eb62a-a0f5-4801-9254-367fcef05e05", "metadata": { "tags": [] @@ -186,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "3c6fd3fe-b2d3-4bff-88f5-b86de68878a6", "metadata": { "tags": [] @@ -209,80 +214,7783 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "id": "b7bf2248-cd3b-47c5-8923-ec9a6e868a49", "metadata": { "tags": [] }, + "outputs": [], + "source": [ + "import zipfile\n", + "\n", + "with zipfile.ZipFile(\"m5-forecasting-accuracy.zip\", \"r\") as zf:\n", + " zf.extractall(path=\"./data\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "673468ae-9f6d-499f-86c4-230021bbf1b0", + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Archive: m5-forecasting-accuracy.zip\n", - " inflating: data/calendar.csv \n", - " inflating: data/sales_train_evaluation.csv \n", - " inflating: data/sales_train_validation.csv \n", - " inflating: data/sample_submission.csv \n", - " inflating: data/sell_prices.csv \n" + "-rw-r--r-- 1 rapids conda 102K Sep 28 18:59 data/calendar.csv\n", + "-rw-r--r-- 1 rapids conda 117M Sep 28 18:59 data/sales_train_evaluation.csv\n", + "-rw-r--r-- 1 rapids conda 115M Sep 28 18:59 data/sales_train_validation.csv\n", + "-rw-r--r-- 1 rapids conda 5.0M Sep 28 18:59 data/sample_submission.csv\n", 
+ "-rw-r--r-- 1 rapids conda 194M Sep 28 18:59 data/sell_prices.csv\n" ] } ], "source": [ - "!unzip m5-forecasting-accuracy.zip -d data/" + "!ls -lh data/*.csv" + ] + }, + { + "cell_type": "markdown", + "id": "f304ea68-381f-45b4-9e27-201a35e31239", + "metadata": {}, + "source": [ + "## Data preprocessing" + ] + }, + { + "cell_type": "markdown", + "id": "d9903e47-2a83-40d1-b65b-d5818e9f0647", + "metadata": {}, + "source": [ + "We are now ready to run the preprocessing steps." + ] + }, + { + "cell_type": "markdown", + "id": "be6740ef-538d-4624-aa00-95d0029ee90b", + "metadata": {}, + "source": [ + "### Import modules and define utility functions" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2de6774b-9641-437e-8a03-f578c977a6ac", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import cudf\n", + "import numpy as np\n", + "import gc\n", + "import pathlib\n", + "import gcsfs\n", + "\n", + "\n", + "def sizeof_fmt(num, suffix=\"B\"):\n", + " for unit in [\"\", \"Ki\", \"Mi\", \"Gi\", \"Ti\", \"Pi\", \"Ei\", \"Zi\"]:\n", + " if abs(num) < 1024.0:\n", + " return f\"{num:3.1f}{unit}{suffix}\"\n", + " num /= 1024.0\n", + " return \"%.1f%s%s\" % (num, \"Yi\", suffix)\n", + "\n", + "\n", + "def report_dataframe_size(df, name):\n", + " print(\n", + " \"{} takes up {} memory on GPU\".format(\n", + " name, sizeof_fmt(grid_df.memory_usage(index=True).sum())\n", + " )\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "f87688b3-49e9-4d2b-a18c-3d4185c96b49", + "metadata": {}, + "source": [ + "### Load Data" ] }, { "cell_type": "code", "execution_count": 9, - "id": "673468ae-9f6d-499f-86c4-230021bbf1b0", + "id": "a0724160-5394-4976-b5e4-2cb5bd672d59", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "TARGET = \"sales\" # Our main target\n", + "END_TRAIN = 1941 # Last day in train set" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "aebe04ff-8a3d-451d-a9ca-d7e7d7ecd739", + "metadata": { + "tags": 
[] + }, + "outputs": [], + "source": [ + "raw_data_dir = pathlib.Path(\"./data/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "914cf6bd-62b9-4bbc-aad2-1029e701d0cd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "train_df = cudf.read_csv(raw_data_dir / \"sales_train_evaluation.csv\")\n", + "prices_df = cudf.read_csv(raw_data_dir / \"sell_prices.csv\")\n", + "calendar_df = cudf.read_csv(raw_data_dir / \"calendar.csv\").rename(\n", + " columns={\"d\": \"day_id\"}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "4d9f912b-31e1-44a7-8613-92881e9e88f9", "metadata": { "tags": [] }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "-rw-r--r-- 1 root users 102K Jun 1 2020 data/calendar.csv\n", - "-rw-r--r-- 1 root users 117M Jun 1 2020 data/sales_train_evaluation.csv\n", - "-rw-r--r-- 1 root users 115M Jun 1 2020 data/sales_train_validation.csv\n", - "-rw-r--r-- 1 root users 5.0M Jun 1 2020 data/sample_submission.csv\n", - "-rw-r--r-- 1 root users 194M Jun 1 2020 data/sell_prices.csv\n" - ] + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idd_1d_2d_3d_4...d_1932d_1933d_1934d_1935d_1936d_1937d_1938d_1939d_1940d_1941
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CA0000...2400003301
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CA0000...0121100000
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CA0000...1020002301
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CA0000...1104013026
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CA0000...0002100210
..................................................................
30485FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WI0022...1030110011
30486FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WI0000...0000001010
30487FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WI0602...0012010102
30488FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WI0000...1114601110
30489FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WI0000...1205402251
\n", + "

30490 rows × 1947 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id cat_id \\\n", + "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", + "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", + "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", + "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", + "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", + "... ... ... ... ... \n", + "30485 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", + "30486 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", + "30487 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", + "30488 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", + "30489 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", + "\n", + " store_id state_id d_1 d_2 d_3 d_4 ... d_1932 d_1933 d_1934 \\\n", + "0 CA_1 CA 0 0 0 0 ... 2 4 0 \n", + "1 CA_1 CA 0 0 0 0 ... 0 1 2 \n", + "2 CA_1 CA 0 0 0 0 ... 1 0 2 \n", + "3 CA_1 CA 0 0 0 0 ... 1 1 0 \n", + "4 CA_1 CA 0 0 0 0 ... 0 0 0 \n", + "... ... ... ... ... ... ... ... ... ... ... \n", + "30485 WI_3 WI 0 0 2 2 ... 1 0 3 \n", + "30486 WI_3 WI 0 0 0 0 ... 0 0 0 \n", + "30487 WI_3 WI 0 6 0 2 ... 0 0 1 \n", + "30488 WI_3 WI 0 0 0 0 ... 1 1 1 \n", + "30489 WI_3 WI 0 0 0 0 ... 1 2 0 \n", + "\n", + " d_1935 d_1936 d_1937 d_1938 d_1939 d_1940 d_1941 \n", + "0 0 0 0 3 3 0 1 \n", + "1 1 1 0 0 0 0 0 \n", + "2 0 0 0 2 3 0 1 \n", + "3 4 0 1 3 0 2 6 \n", + "4 2 1 0 0 2 1 0 \n", + "... ... ... ... ... ... ... ... 
\n", + "30485 0 1 1 0 0 1 1 \n", + "30486 0 0 0 1 0 1 0 \n", + "30487 2 0 1 0 1 0 2 \n", + "30488 4 6 0 1 1 1 0 \n", + "30489 5 4 0 2 2 5 1 \n", + "\n", + "[30490 rows x 1947 columns]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "!ls -lh data/*.csv" + "train_df" ] }, { "cell_type": "markdown", - "id": "f304ea68-381f-45b4-9e27-201a35e31239", + "id": "7f834bdb-0c99-4d9f-b207-ac1f423756c9", + "metadata": {}, + "source": [ + "The columns `d_1`, `d_2`, ..., `d_1941` indicate the sales data at days 1, 2, ..., 1941 from 2011-01-29." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "2af6b2d9-14f9-42e4-81f9-1ad2fe3b92fe", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idwm_yr_wksell_price
0CA_1HOBBIES_1_001113259.58
1CA_1HOBBIES_1_001113269.58
2CA_1HOBBIES_1_001113278.26
3CA_1HOBBIES_1_001113288.26
4CA_1HOBBIES_1_001113298.26
...............
6841116WI_3FOODS_3_827116171.00
6841117WI_3FOODS_3_827116181.00
6841118WI_3FOODS_3_827116191.00
6841119WI_3FOODS_3_827116201.00
6841120WI_3FOODS_3_827116211.00
\n", + "

6841121 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id wm_yr_wk sell_price\n", + "0 CA_1 HOBBIES_1_001 11325 9.58\n", + "1 CA_1 HOBBIES_1_001 11326 9.58\n", + "2 CA_1 HOBBIES_1_001 11327 8.26\n", + "3 CA_1 HOBBIES_1_001 11328 8.26\n", + "4 CA_1 HOBBIES_1_001 11329 8.26\n", + "... ... ... ... ...\n", + "6841116 WI_3 FOODS_3_827 11617 1.00\n", + "6841117 WI_3 FOODS_3_827 11618 1.00\n", + "6841118 WI_3 FOODS_3_827 11619 1.00\n", + "6841119 WI_3 FOODS_3_827 11620 1.00\n", + "6841120 WI_3 FOODS_3_827 11621 1.00\n", + "\n", + "[6841121 rows x 4 columns]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "34eb6624-3960-4b80-a92a-82c26ea4d974", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
datewm_yr_wkweekdaywdaymonthyearday_idevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WI
02011-01-2911101Saturday112011d_1<NA><NA><NA><NA>000
12011-01-3011101Sunday212011d_2<NA><NA><NA><NA>000
22011-01-3111101Monday312011d_3<NA><NA><NA><NA>000
32011-02-0111101Tuesday422011d_4<NA><NA><NA><NA>110
42011-02-0211101Wednesday522011d_5<NA><NA><NA><NA>101
.............................................
19642016-06-1511620Wednesday562016d_1965<NA><NA><NA><NA>011
19652016-06-1611620Thursday662016d_1966<NA><NA><NA><NA>000
19662016-06-1711620Friday762016d_1967<NA><NA><NA><NA>000
19672016-06-1811621Saturday162016d_1968<NA><NA><NA><NA>000
19682016-06-1911621Sunday262016d_1969NBAFinalsEndSportingFather's dayCultural000
\n", + "

1969 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " date wm_yr_wk weekday wday month year day_id \\\n", + "0 2011-01-29 11101 Saturday 1 1 2011 d_1 \n", + "1 2011-01-30 11101 Sunday 2 1 2011 d_2 \n", + "2 2011-01-31 11101 Monday 3 1 2011 d_3 \n", + "3 2011-02-01 11101 Tuesday 4 2 2011 d_4 \n", + "4 2011-02-02 11101 Wednesday 5 2 2011 d_5 \n", + "... ... ... ... ... ... ... ... \n", + "1964 2016-06-15 11620 Wednesday 5 6 2016 d_1965 \n", + "1965 2016-06-16 11620 Thursday 6 6 2016 d_1966 \n", + "1966 2016-06-17 11620 Friday 7 6 2016 d_1967 \n", + "1967 2016-06-18 11621 Saturday 1 6 2016 d_1968 \n", + "1968 2016-06-19 11621 Sunday 2 6 2016 d_1969 \n", + "\n", + " event_name_1 event_type_1 event_name_2 event_type_2 snap_CA snap_TX \\\n", + "0 0 0 \n", + "1 0 0 \n", + "2 0 0 \n", + "3 1 1 \n", + "4 1 0 \n", + "... ... ... ... ... ... ... \n", + "1964 0 1 \n", + "1965 0 0 \n", + "1966 0 0 \n", + "1967 0 0 \n", + "1968 NBAFinalsEnd Sporting Father's day Cultural 0 0 \n", + "\n", + " snap_WI \n", + "0 0 \n", + "1 0 \n", + "2 0 \n", + "3 0 \n", + "4 1 \n", + "... ... \n", + "1964 1 \n", + "1965 0 \n", + "1966 0 \n", + "1967 0 \n", + "1968 0 \n", + "\n", + "[1969 rows x 14 columns]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "calendar_df" + ] + }, + { + "cell_type": "markdown", + "id": "c957eb3a-93a5-4800-bc9d-dfd3722c3ccd", "metadata": {}, "source": [ - "# Next steps" + "### Reformat sales times series data" ] }, { "cell_type": "markdown", - "id": "d9903e47-2a83-40d1-b65b-d5818e9f0647", + "id": "0db2e815-7de8-4496-b785-0f7d046f6fc1", + "metadata": {}, + "source": [ + "Pivot the columns `d_1`, `d_2`, ..., `d_1941` into separate rows using `cudf.melt`." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "5113af62-bea8-469c-bb3e-3bfc2e8805ee", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsales
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CAd_10
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CAd_10
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CAd_10
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CAd_10
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CAd_10
...........................
59181085FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WId_19411
59181086FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WId_19410
59181087FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WId_19412
59181088FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WId_19410
59181089FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WId_19411
\n", + "

59181090 rows × 8 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id cat_id \\\n", + "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", + "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", + "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", + "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", + "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", + "... ... ... ... ... \n", + "59181085 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", + "59181086 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", + "59181087 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", + "59181088 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", + "59181089 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", + "\n", + " store_id state_id day_id sales \n", + "0 CA_1 CA d_1 0 \n", + "1 CA_1 CA d_1 0 \n", + "2 CA_1 CA d_1 0 \n", + "3 CA_1 CA d_1 0 \n", + "4 CA_1 CA d_1 0 \n", + "... ... ... ... ... \n", + "59181085 WI_3 WI d_1941 1 \n", + "59181086 WI_3 WI d_1941 0 \n", + "59181087 WI_3 WI d_1941 2 \n", + "59181088 WI_3 WI d_1941 0 \n", + "59181089 WI_3 WI d_1941 1 \n", + "\n", + "[59181090 rows x 8 columns]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index_columns = [\"id\", \"item_id\", \"dept_id\", \"cat_id\", \"store_id\", \"state_id\"]\n", + "grid_df = cudf.melt(\n", + " train_df, id_vars=index_columns, var_name=\"day_id\", value_name=TARGET\n", + ")\n", + "grid_df" + ] + }, + { + "cell_type": "markdown", + "id": "6b0fc482-3d30-4d8a-be62-caf7a3e3b9a1", + "metadata": {}, + "source": [ + "For each time series, add 28 rows that corresponds to the future forecast horizon:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "e7c95c1e-1f33-4a53-b390-f37cffd398d1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsales
0HOBBIES_1_001_CA_1_evaluationHOBBIES_1_001HOBBIES_1HOBBIESCA_1CAd_10.0
1HOBBIES_1_002_CA_1_evaluationHOBBIES_1_002HOBBIES_1HOBBIESCA_1CAd_10.0
2HOBBIES_1_003_CA_1_evaluationHOBBIES_1_003HOBBIES_1HOBBIESCA_1CAd_10.0
3HOBBIES_1_004_CA_1_evaluationHOBBIES_1_004HOBBIES_1HOBBIESCA_1CAd_10.0
4HOBBIES_1_005_CA_1_evaluationHOBBIES_1_005HOBBIES_1HOBBIESCA_1CAd_10.0
...........................
60034805FOODS_3_823_WI_3_evaluationFOODS_3_823FOODS_3FOODSWI_3WId_1969NaN
60034806FOODS_3_824_WI_3_evaluationFOODS_3_824FOODS_3FOODSWI_3WId_1969NaN
60034807FOODS_3_825_WI_3_evaluationFOODS_3_825FOODS_3FOODSWI_3WId_1969NaN
60034808FOODS_3_826_WI_3_evaluationFOODS_3_826FOODS_3FOODSWI_3WId_1969NaN
60034809FOODS_3_827_WI_3_evaluationFOODS_3_827FOODS_3FOODSWI_3WId_1969NaN
\n", + "

60034810 rows × 8 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id cat_id \\\n", + "0 HOBBIES_1_001_CA_1_evaluation HOBBIES_1_001 HOBBIES_1 HOBBIES \n", + "1 HOBBIES_1_002_CA_1_evaluation HOBBIES_1_002 HOBBIES_1 HOBBIES \n", + "2 HOBBIES_1_003_CA_1_evaluation HOBBIES_1_003 HOBBIES_1 HOBBIES \n", + "3 HOBBIES_1_004_CA_1_evaluation HOBBIES_1_004 HOBBIES_1 HOBBIES \n", + "4 HOBBIES_1_005_CA_1_evaluation HOBBIES_1_005 HOBBIES_1 HOBBIES \n", + "... ... ... ... ... \n", + "60034805 FOODS_3_823_WI_3_evaluation FOODS_3_823 FOODS_3 FOODS \n", + "60034806 FOODS_3_824_WI_3_evaluation FOODS_3_824 FOODS_3 FOODS \n", + "60034807 FOODS_3_825_WI_3_evaluation FOODS_3_825 FOODS_3 FOODS \n", + "60034808 FOODS_3_826_WI_3_evaluation FOODS_3_826 FOODS_3 FOODS \n", + "60034809 FOODS_3_827_WI_3_evaluation FOODS_3_827 FOODS_3 FOODS \n", + "\n", + " store_id state_id day_id sales \n", + "0 CA_1 CA d_1 0.0 \n", + "1 CA_1 CA d_1 0.0 \n", + "2 CA_1 CA d_1 0.0 \n", + "3 CA_1 CA d_1 0.0 \n", + "4 CA_1 CA d_1 0.0 \n", + "... ... ... ... ... 
\n", + "60034805 WI_3 WI d_1969 NaN \n", + "60034806 WI_3 WI d_1969 NaN \n", + "60034807 WI_3 WI d_1969 NaN \n", + "60034808 WI_3 WI d_1969 NaN \n", + "60034809 WI_3 WI d_1969 NaN \n", + "\n", + "[60034810 rows x 8 columns]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "add_grid = cudf.DataFrame()\n", + "for i in range(1, 29):\n", + " temp_df = train_df[index_columns]\n", + " temp_df = temp_df.drop_duplicates()\n", + " temp_df[\"day_id\"] = \"d_\" + str(END_TRAIN + i)\n", + " temp_df[TARGET] = np.nan # Sales amount at time (n + i) is unknown\n", + " add_grid = cudf.concat([add_grid, temp_df])\n", + "add_grid[\"day_id\"] = add_grid[\"day_id\"].astype(\n", + " \"category\"\n", + ") # The day_id column is categorical, after cudf.melt\n", + "\n", + "grid_df = cudf.concat([grid_df, add_grid])\n", + "grid_df = grid_df.reset_index(drop=True)\n", + "grid_df[\"sales\"] = grid_df[\"sales\"].astype(\n", + " np.float32\n", + ") # Use float32 type for sales column, to conserve memory\n", + "grid_df" + ] + }, + { + "cell_type": "markdown", + "id": "e250945a-e251-4b79-9fa8-25fffb96815b", "metadata": {}, "source": [ - "We are now ready to run the preprocessing steps. You should run the six notebooks in order, to process the raw data into a form that can be used for model training:\n", + "### Free up GPU memory\n", + "\n", + "GPU memory is a precious resource, so let's try to free up some memory. 
First, delete temporary variables we no longer need:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "bae6a99d-1674-4937-aab4-7098eec87dff", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "8136" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Use xdel magic to scrub extra references from Jupyter notebook\n", + "%xdel temp_df\n", + "%xdel add_grid\n", + "%xdel train_df\n", "\n", - "```{toctree}\n", - "---\n", - "maxdepth: 1\n", - "---\n", - "preprocessing_part1\n", - "preprocessing_part2\n", - "preprocessing_part3\n", - "preprocessing_part4\n", - "preprocessing_part5\n", - "preprocessing_part6\n", - "training_and_evaluation\n" + "# Invoke the garbage collector explicitly to free up memory\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "b126fa0f-18fe-4ba2-9dad-ee278851f781", + "metadata": {}, + "source": [ + "Second, let's reduce the footprint of `grid_df` by converting strings into categoricals:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "5906f5c1-54ba-47fc-977a-e6516076e333", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 5.2GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "0c6db032-8798-4cf0-a9db-d483c836ead5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "id object\n", + "item_id object\n", + "dept_id object\n", + "cat_id object\n", + "store_id object\n", + "state_id object\n", + "day_id category\n", + "sales float32\n", + "dtype: object" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df.dtypes" ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": 
"82e650e3-c001-4c33-b75d-77897c7c9a2c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 802.6MiB memory on GPU\n" + ] + } + ], + "source": [ + "for col in index_columns:\n", + " grid_df[col] = grid_df[col].astype(\"category\")\n", + " gc.collect()\n", + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "1f4f4421-75fa-471a-aa02-8521d2497b02", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "id category\n", + "item_id category\n", + "dept_id category\n", + "cat_id category\n", + "store_id category\n", + "state_id category\n", + "day_id category\n", + "sales float32\n", + "dtype: object" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df.dtypes" + ] + }, + { + "cell_type": "markdown", + "id": "b62205cf-3e18-4f4c-ae98-e818e510e931", + "metadata": {}, + "source": [ + "### Identify the release week of each product\n", + "\n", + "Each row in the `prices_df` table contains the price of a product sold at a store for a given week." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "774b69c6-b3f5-44ee-b769-df78397c4b37", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idwm_yr_wksell_price
0CA_1HOBBIES_1_001113259.58
1CA_1HOBBIES_1_001113269.58
2CA_1HOBBIES_1_001113278.26
3CA_1HOBBIES_1_001113288.26
4CA_1HOBBIES_1_001113298.26
...............
6841116WI_3FOODS_3_827116171.00
6841117WI_3FOODS_3_827116181.00
6841118WI_3FOODS_3_827116191.00
6841119WI_3FOODS_3_827116201.00
6841120WI_3FOODS_3_827116211.00
\n", + "

6841121 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id wm_yr_wk sell_price\n", + "0 CA_1 HOBBIES_1_001 11325 9.58\n", + "1 CA_1 HOBBIES_1_001 11326 9.58\n", + "2 CA_1 HOBBIES_1_001 11327 8.26\n", + "3 CA_1 HOBBIES_1_001 11328 8.26\n", + "4 CA_1 HOBBIES_1_001 11329 8.26\n", + "... ... ... ... ...\n", + "6841116 WI_3 FOODS_3_827 11617 1.00\n", + "6841117 WI_3 FOODS_3_827 11618 1.00\n", + "6841118 WI_3 FOODS_3_827 11619 1.00\n", + "6841119 WI_3 FOODS_3_827 11620 1.00\n", + "6841120 WI_3 FOODS_3_827 11621 1.00\n", + "\n", + "[6841121 rows x 4 columns]" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df" + ] + }, + { + "cell_type": "markdown", + "id": "47df7d5b-5697-4fa9-8ab8-b68ac3264259", + "metadata": {}, + "source": [ + "Notice that not all products were sold over every week. Some products were sold only during some weeks. Let's use the groupby operation to identify the first week in which each product went on the shelf." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "c5d1d2e6-6f8c-4150-9c11-e5519c41a461", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idrelease_week
0CA_4FOODS_3_52911421
1TX_1HOUSEHOLD_1_40911230
2WI_2FOODS_3_14511214
3CA_4HOUSEHOLD_1_49411106
4WI_3HOBBIES_1_09311223
............
30485CA_3HOUSEHOLD_1_36911205
30486CA_2FOODS_3_10911101
30487CA_4FOODS_2_11911101
30488CA_4HOUSEHOLD_2_38411110
30489WI_3HOBBIES_1_13511328
\n", + "

30490 rows × 3 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id release_week\n", + "0 CA_4 FOODS_3_529 11421\n", + "1 TX_1 HOUSEHOLD_1_409 11230\n", + "2 WI_2 FOODS_3_145 11214\n", + "3 CA_4 HOUSEHOLD_1_494 11106\n", + "4 WI_3 HOBBIES_1_093 11223\n", + "... ... ... ...\n", + "30485 CA_3 HOUSEHOLD_1_369 11205\n", + "30486 CA_2 FOODS_3_109 11101\n", + "30487 CA_4 FOODS_2_119 11101\n", + "30488 CA_4 HOUSEHOLD_2_384 11110\n", + "30489 WI_3 HOBBIES_1_135 11328\n", + "\n", + "[30490 rows x 3 columns]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "release_df = (\n", + " prices_df.groupby([\"store_id\", \"item_id\"])[\"wm_yr_wk\"].agg(\"min\").reset_index()\n", + ")\n", + "release_df.columns = [\"store_id\", \"item_id\", \"release_week\"]\n", + "release_df" + ] + }, + { + "cell_type": "markdown", + "id": "f4c4a98b-20fe-494a-8b42-b292c8c35f53", + "metadata": {}, + "source": [ + "Now that we've computed the release week for each product, let's merge it back to `grid_df`:" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "96439c30-64fe-4f8e-a872-0f2f1982a48f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_week
0FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_13.011101
1FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_20.011101
2FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_30.011101
3FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_41.011101
4FOODS_1_001_CA_1_evaluationFOODS_1_001FOODS_1FOODSCA_1CAd_54.011101
..............................
60034805HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1965NaN11101
60034806HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1966NaN11101
60034807HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1967NaN11101
60034808HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1968NaN11101
60034809HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_1969NaN11101
\n", + "

60034810 rows × 9 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "0 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "1 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "2 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "3 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "4 FOODS_1_001_CA_1_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "60034805 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034806 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034807 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034808 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034809 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week \n", + "0 FOODS CA_1 CA d_1 3.0 11101 \n", + "1 FOODS CA_1 CA d_2 0.0 11101 \n", + "2 FOODS CA_1 CA d_3 0.0 11101 \n", + "3 FOODS CA_1 CA d_4 1.0 11101 \n", + "4 FOODS CA_1 CA d_5 4.0 11101 \n", + "... ... ... ... ... ... ... 
\n", + "60034805 HOUSEHOLD WI_3 WI d_1965 NaN 11101 \n", + "60034806 HOUSEHOLD WI_3 WI d_1966 NaN 11101 \n", + "60034807 HOUSEHOLD WI_3 WI d_1967 NaN 11101 \n", + "60034808 HOUSEHOLD WI_3 WI d_1968 NaN 11101 \n", + "60034809 HOUSEHOLD WI_3 WI d_1969 NaN 11101 \n", + "\n", + "[60034810 rows x 9 columns]" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = grid_df.merge(release_df, on=[\"store_id\", \"item_id\"], how=\"left\")\n", + "grid_df = grid_df.sort_values(index_columns + [\"day_id\"]).reset_index(drop=True)\n", + "grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "18226d36-4e8c-4b49-ac99-b126790a21a4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "139" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "del release_df # No longer needed\n", + "gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "5f1a3c6e-9411-42bc-8e2c-73f9abb7c09b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 1.2GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "markdown", + "id": "cbfca2cc-c90f-4852-b9f0-d3882603a981", + "metadata": {}, + "source": [ + "### Filter out entries with zero sales\n", + "\n", + "We can further save space by dropping rows from `grid_df` that correspond to zero sales. 
Since each product doesn't go on the shelf until its release week, its sale must be zero during any week that's prior to the release week.\n", + "\n", + "To make use of this insight, we bring in the `wm_yr_wk` column from `calendar_df`:" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "acdf40b5-395d-4dc2-a620-8cdd9eec70cf", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
0FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8090.01110111312
1FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8100.01110111312
2FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8112.01110111312
3FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8120.01110111312
4FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8131.01110111313
.................................
60034805HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_520.01110111108
60034806HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_530.01110111108
60034807HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_540.01110111108
60034808HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_550.01110111108
60034809HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_490.01110111107
\n", + "

60034810 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "0 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "1 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "2 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "3 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "4 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "60034805 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034806 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034807 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034808 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60034809 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", + "0 FOODS WI_2 WI d_809 0.0 11101 11312 \n", + "1 FOODS WI_2 WI d_810 0.0 11101 11312 \n", + "2 FOODS WI_2 WI d_811 2.0 11101 11312 \n", + "3 FOODS WI_2 WI d_812 0.0 11101 11312 \n", + "4 FOODS WI_2 WI d_813 1.0 11101 11313 \n", + "... ... ... ... ... ... ... ... 
\n", + "60034805 HOUSEHOLD WI_3 WI d_52 0.0 11101 11108 \n", + "60034806 HOUSEHOLD WI_3 WI d_53 0.0 11101 11108 \n", + "60034807 HOUSEHOLD WI_3 WI d_54 0.0 11101 11108 \n", + "60034808 HOUSEHOLD WI_3 WI d_55 0.0 11101 11108 \n", + "60034809 HOUSEHOLD WI_3 WI d_49 0.0 11101 11107 \n", + "\n", + "[60034810 rows x 10 columns]" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = grid_df.merge(calendar_df[[\"wm_yr_wk\", \"day_id\"]], on=[\"day_id\"], how=\"left\")\n", + "grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "43705184-96ac-420c-94b4-ac932558484a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 1.7GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "markdown", + "id": "98a935f3-d410-4f45-a745-8c875ea31a39", + "metadata": {}, + "source": [ + "The `wm_yr_wk` column identifies the week that contains the day given by the `day_id` column. Now let's filter all rows in `grid_df` for which `wm_yr_wk` is less than `release_week`:" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "d696b5dd-fbf8-4875-903a-db50b166cfe6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
6766FOODS_1_002_TX_1_evaluationFOODS_1_002FOODS_1FOODSTX_1TXd_10.01110211101
6767FOODS_1_002_TX_1_evaluationFOODS_1_002FOODS_1FOODSTX_1TXd_20.01110211101
19686FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_10.01110211101
19687FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_20.01110211101
19688FOODS_1_001_TX_3_evaluationFOODS_1_001FOODS_1FOODSTX_3TXd_30.01110211101
.................................
60033493HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_200.01110611103
60033494HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_210.01110611103
60033495HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_220.01110611104
60033496HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_230.01110611104
60033497HOUSEHOLD_2_516_WI_2_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_2WId_240.01110611104
\n", + "

12299413 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "6766 FOODS_1_002_TX_1_evaluation FOODS_1_002 FOODS_1 \n", + "6767 FOODS_1_002_TX_1_evaluation FOODS_1_002 FOODS_1 \n", + "19686 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", + "19687 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", + "19688 FOODS_1_001_TX_3_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "60033493 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60033494 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60033495 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60033496 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "60033497 HOUSEHOLD_2_516_WI_2_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", + "6766 FOODS TX_1 TX d_1 0.0 11102 11101 \n", + "6767 FOODS TX_1 TX d_2 0.0 11102 11101 \n", + "19686 FOODS TX_3 TX d_1 0.0 11102 11101 \n", + "19687 FOODS TX_3 TX d_2 0.0 11102 11101 \n", + "19688 FOODS TX_3 TX d_3 0.0 11102 11101 \n", + "... ... ... ... ... ... ... ... \n", + "60033493 HOUSEHOLD WI_2 WI d_20 0.0 11106 11103 \n", + "60033494 HOUSEHOLD WI_2 WI d_21 0.0 11106 11103 \n", + "60033495 HOUSEHOLD WI_2 WI d_22 0.0 11106 11104 \n", + "60033496 HOUSEHOLD WI_2 WI d_23 0.0 11106 11104 \n", + "60033497 HOUSEHOLD WI_2 WI d_24 0.0 11106 11104 \n", + "\n", + "[12299413 rows x 10 columns]" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = grid_df[grid_df[\"wm_yr_wk\"] < grid_df[\"release_week\"]]\n", + "df" + ] + }, + { + "cell_type": "markdown", + "id": "b76d6141-1f3e-43ec-9f2e-ad7da2ab32bd", + "metadata": {}, + "source": [ + "As we suspected, the sales amount is zero during weeks that come before the release week." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "515404e3-ef4b-459f-a24e-00f3cf084151", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "assert (df[\"sales\"] == 0).all()" + ] + }, + { + "cell_type": "markdown", + "id": "42213544-9dd4-4045-88a5-215319c6ce5f", + "metadata": {}, + "source": [ + "For the purpose of our data analysis, we can safely drop the rows with zero sales:" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "2b697475-41c2-4a9c-bc00-c23085b71edc", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
iditem_iddept_idcat_idstore_idstate_idday_idsalesrelease_weekwm_yr_wk
0FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8090.01110111312
1FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8100.01110111312
2FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8112.01110111312
3FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8120.01110111312
4FOODS_1_001_WI_2_evaluationFOODS_1_001FOODS_1FOODSWI_2WId_8131.01110111313
.................................
47735392HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_520.01110111108
47735393HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_530.01110111108
47735394HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_540.01110111108
47735395HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_550.01110111108
47735396HOUSEHOLD_2_516_WI_3_evaluationHOUSEHOLD_2_516HOUSEHOLD_2HOUSEHOLDWI_3WId_490.01110111107
\n", + "

47735397 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id item_id dept_id \\\n", + "0 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "1 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "2 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "3 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "4 FOODS_1_001_WI_2_evaluation FOODS_1_001 FOODS_1 \n", + "... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735393 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735394 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation HOUSEHOLD_2_516 HOUSEHOLD_2 \n", + "\n", + " cat_id store_id state_id day_id sales release_week wm_yr_wk \n", + "0 FOODS WI_2 WI d_809 0.0 11101 11312 \n", + "1 FOODS WI_2 WI d_810 0.0 11101 11312 \n", + "2 FOODS WI_2 WI d_811 2.0 11101 11312 \n", + "3 FOODS WI_2 WI d_812 0.0 11101 11312 \n", + "4 FOODS WI_2 WI d_813 1.0 11101 11313 \n", + "... ... ... ... ... ... ... ... 
\n", + "47735392 HOUSEHOLD WI_3 WI d_52 0.0 11101 11108 \n", + "47735393 HOUSEHOLD WI_3 WI d_53 0.0 11101 11108 \n", + "47735394 HOUSEHOLD WI_3 WI d_54 0.0 11101 11108 \n", + "47735395 HOUSEHOLD WI_3 WI d_55 0.0 11101 11108 \n", + "47735396 HOUSEHOLD WI_3 WI d_49 0.0 11101 11107 \n", + "\n", + "[47735397 rows x 10 columns]" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df = grid_df[grid_df[\"wm_yr_wk\"] >= grid_df[\"release_week\"]].reset_index(drop=True)\n", + "grid_df[\"wm_yr_wk\"] = grid_df[\"wm_yr_wk\"].astype(\n", + " np.int32\n", + ") # Convert wm_yr_wk column to int32, to conserve memory\n", + "grid_df" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "e7db4f50-d61b-4a44-8e7e-fb5dcf2fdc9d", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "grid_df takes up 1.2GiB memory on GPU\n" + ] + } + ], + "source": [ + "report_dataframe_size(grid_df, \"grid_df\")" + ] + }, + { + "cell_type": "markdown", + "id": "b1ac8ea2-8430-4760-8fc1-0d3f45300eb0", + "metadata": {}, + "source": [ + "### Assign weights for product items\n", + "\n", + "When we assess the accuracy of our machine learning model, we should assign a weight for each product item, to indicate the relative importance of the item. For the M5 competition, the weights are computed from the total sales amount (in US dollars) in the lastest 28 days." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "3cfffc9d-2767-42fa-b67a-5056671bd2f8", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sales_usd_sum
item_id
FOODS_1_0013516.80
FOODS_1_00212418.80
FOODS_1_0035943.20
FOODS_1_00454184.82
FOODS_1_00517877.00
......
HOUSEHOLD_2_5126034.40
HOUSEHOLD_2_5132668.80
HOUSEHOLD_2_5149574.60
HOUSEHOLD_2_515630.40
HOUSEHOLD_2_5162574.00
\n", + "

3049 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " sales_usd_sum\n", + "item_id \n", + "FOODS_1_001 3516.80\n", + "FOODS_1_002 12418.80\n", + "FOODS_1_003 5943.20\n", + "FOODS_1_004 54184.82\n", + "FOODS_1_005 17877.00\n", + "... ...\n", + "HOUSEHOLD_2_512 6034.40\n", + "HOUSEHOLD_2_513 2668.80\n", + "HOUSEHOLD_2_514 9574.60\n", + "HOUSEHOLD_2_515 630.40\n", + "HOUSEHOLD_2_516 2574.00\n", + "\n", + "[3049 rows x 1 columns]" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Convert day_id to integers\n", + "grid_df[\"day_id_int\"] = grid_df[\"day_id\"].to_pandas().apply(lambda x: x[2:]).astype(int)\n", + "\n", + "# Compute the total sales over the latest 28 days, per product item\n", + "last28 = grid_df[(grid_df[\"day_id_int\"] >= 1914) & (grid_df[\"day_id_int\"] < 1942)]\n", + "last28 = last28[[\"item_id\", \"wm_yr_wk\", \"sales\"]].merge(\n", + " prices_df[[\"item_id\", \"wm_yr_wk\", \"sell_price\"]], on=[\"item_id\", \"wm_yr_wk\"]\n", + ")\n", + "last28[\"sales_usd\"] = last28[\"sales\"] * last28[\"sell_price\"]\n", + "total_sales_usd = last28.groupby(\"item_id\")[[\"sales_usd\"]].agg([\"sum\"]).sort_index()\n", + "total_sales_usd.columns = total_sales_usd.columns.map(\"_\".join)\n", + "total_sales_usd" + ] + }, + { + "cell_type": "markdown", + "id": "ea0c6f49-60af-48b5-b020-42d35b3cc2eb", + "metadata": {}, + "source": [ + "To obtain weights, we normalize the sales amount for one item by the total sales for all items." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "e651c504-de65-4765-b9b1-0087a5dab3b3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
weights
item_id
FOODS_1_0010.000090
FOODS_1_0020.000318
FOODS_1_0030.000152
FOODS_1_0040.001389
FOODS_1_0050.000458
......
HOUSEHOLD_2_5120.000155
HOUSEHOLD_2_5130.000068
HOUSEHOLD_2_5140.000245
HOUSEHOLD_2_5150.000016
HOUSEHOLD_2_5160.000066
\n", + "

3049 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " weights\n", + "item_id \n", + "FOODS_1_001 0.000090\n", + "FOODS_1_002 0.000318\n", + "FOODS_1_003 0.000152\n", + "FOODS_1_004 0.001389\n", + "FOODS_1_005 0.000458\n", + "... ...\n", + "HOUSEHOLD_2_512 0.000155\n", + "HOUSEHOLD_2_513 0.000068\n", + "HOUSEHOLD_2_514 0.000245\n", + "HOUSEHOLD_2_515 0.000016\n", + "HOUSEHOLD_2_516 0.000066\n", + "\n", + "[3049 rows x 1 columns]" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "weights = total_sales_usd / total_sales_usd.sum()\n", + "weights = weights.rename(columns={\"sales_usd_sum\": \"weights\"})\n", + "weights" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "9882bfec-f7e2-45f9-97e4-93a836b28a0a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# No longer needed\n", + "del grid_df[\"day_id_int\"]" + ] + }, + { + "cell_type": "markdown", + "id": "dd8706a0-32d5-4cb3-9a2e-bf42fcc17254", + "metadata": {}, + "source": [ + "### Generate price-related features\n", + "Let us engineer additional features that are related to the sale price. We consider the distribution of the price of a given product over time and ask how the current price compares to the historical trend." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "451c180c-13a0-4517-a25e-97557bb12718", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Highest price over all weeks\n", + "prices_df[\"price_max\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"max\")\n", + "# Lowest price over all weeks\n", + "prices_df[\"price_min\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"min\")\n", + "# Standard deviation of the price\n", + "prices_df[\"price_std\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"std\")\n", + "# Mean (average) price over all weeks\n", + "prices_df[\"price_mean\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"mean\")" + ] + }, + { + "cell_type": "markdown", + "id": "ffbadd58-5836-4668-8011-0f110e333160", + "metadata": {}, + "source": [ + "We also consider the ratio of the current price to the max price." + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "a737282a-a38b-497c-a611-64df7343d01a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "prices_df[\"price_norm\"] = prices_df[\"sell_price\"] / prices_df[\"price_max\"]" + ] + }, + { + "cell_type": "markdown", + "id": "cdde0f24-4ae3-4352-a210-7c2a2a3ac503", + "metadata": {}, + "source": [ + "Some items have a very stable price, whereas other items respond to inflation quickly and rise in price. To capture the price elasticity, we count the number of unique price values for a given product over time." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "7369eb2f-3055-47e6-a9ed-cb225697e2e5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "prices_df[\"price_nunique\"] = prices_df.groupby([\"store_id\", \"item_id\"])[\n", + " \"sell_price\"\n", + "].transform(\"nunique\")" + ] + }, + { + "cell_type": "markdown", + "id": "96bdd61a-db51-4a65-935f-65153fcb5b4a", + "metadata": {}, + "source": [ + "We also consider, for a given price, how many other items are being sold at the exact same price." + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "1c264e78-abca-411f-a8c7-0d42f3b3910d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "prices_df[\"item_nunique\"] = prices_df.groupby([\"store_id\", \"sell_price\"])[\n", + " \"item_id\"\n", + "].transform(\"nunique\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "d7eda764-0745-48c6-b368-33e93c6df618", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
store_iditem_idwm_yr_wksell_priceprice_maxprice_minprice_stdprice_meanprice_normprice_nuniqueitem_nunique
0CA_1HOBBIES_1_001113259.589.588.260.1521398.2857141.00000033
1CA_1HOBBIES_1_001113269.589.588.260.1521398.2857141.00000033
2CA_1HOBBIES_1_001113278.269.588.260.1521398.2857140.86221335
3CA_1HOBBIES_1_001113288.269.588.260.1521398.2857140.86221335
4CA_1HOBBIES_1_001113298.269.588.260.1521398.2857140.86221335
....................................
6841116WI_3FOODS_3_827116171.001.001.000.0000001.0000001.0000001142
6841117WI_3FOODS_3_827116181.001.001.000.0000001.0000001.0000001142
6841118WI_3FOODS_3_827116191.001.001.000.0000001.0000001.0000001142
6841119WI_3FOODS_3_827116201.001.001.000.0000001.0000001.0000001142
6841120WI_3FOODS_3_827116211.001.001.000.0000001.0000001.0000001142
\n", + "

6841121 rows × 11 columns

\n", + "
" + ], + "text/plain": [ + " store_id item_id wm_yr_wk sell_price price_max price_min \\\n", + "0 CA_1 HOBBIES_1_001 11325 9.58 9.58 8.26 \n", + "1 CA_1 HOBBIES_1_001 11326 9.58 9.58 8.26 \n", + "2 CA_1 HOBBIES_1_001 11327 8.26 9.58 8.26 \n", + "3 CA_1 HOBBIES_1_001 11328 8.26 9.58 8.26 \n", + "4 CA_1 HOBBIES_1_001 11329 8.26 9.58 8.26 \n", + "... ... ... ... ... ... ... \n", + "6841116 WI_3 FOODS_3_827 11617 1.00 1.00 1.00 \n", + "6841117 WI_3 FOODS_3_827 11618 1.00 1.00 1.00 \n", + "6841118 WI_3 FOODS_3_827 11619 1.00 1.00 1.00 \n", + "6841119 WI_3 FOODS_3_827 11620 1.00 1.00 1.00 \n", + "6841120 WI_3 FOODS_3_827 11621 1.00 1.00 1.00 \n", + "\n", + " price_std price_mean price_norm price_nunique item_nunique \n", + "0 0.152139 8.285714 1.000000 3 3 \n", + "1 0.152139 8.285714 1.000000 3 3 \n", + "2 0.152139 8.285714 0.862213 3 5 \n", + "3 0.152139 8.285714 0.862213 3 5 \n", + "4 0.152139 8.285714 0.862213 3 5 \n", + "... ... ... ... ... ... \n", + "6841116 0.000000 1.000000 1.000000 1 142 \n", + "6841117 0.000000 1.000000 1.000000 1 142 \n", + "6841118 0.000000 1.000000 1.000000 1 142 \n", + "6841119 0.000000 1.000000 1.000000 1 142 \n", + "6841120 0.000000 1.000000 1.000000 1 142 \n", + "\n", + "[6841121 rows x 11 columns]" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df" + ] + }, + { + "cell_type": "markdown", + "id": "56187024-fd76-4782-8e05-2ebab5daf3d3", + "metadata": {}, + "source": [ + "Another useful way to put prices in context is to compare the price of a product to its historical price a week ago, a month ago, or a year ago."
+ ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "d66e5c04-afaf-49d9-b7f6-fceb756939a9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Add \"month\" and \"year\" columns to prices_df\n", + "week_to_month_map = calendar_df[[\"wm_yr_wk\", \"month\", \"year\"]].drop_duplicates(\n", + " subset=[\"wm_yr_wk\"]\n", + ")\n", + "prices_df = prices_df.merge(week_to_month_map, on=[\"wm_yr_wk\"], how=\"left\")\n", + "\n", + "# Sort by wm_yr_wk. The rows will also be sorted in ascending months and years.\n", + "prices_df = prices_df.sort_values([\"store_id\", \"item_id\", \"wm_yr_wk\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "01a27a82-daf8-4492-b2ab-9e1d77ecbdfc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Compare with the average price in the previous week\n", + "prices_df[\"price_momentum\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", + " [\"store_id\", \"item_id\"]\n", + ")[\"sell_price\"].shift(1)\n", + "# Compare with the average price in the previous month\n", + "prices_df[\"price_momentum_m\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", + " [\"store_id\", \"item_id\", \"month\"]\n", + ")[\"sell_price\"].transform(\"mean\")\n", + "# Compare with the average price in the previous year\n", + "prices_df[\"price_momentum_y\"] = prices_df[\"sell_price\"] / prices_df.groupby(\n", + " [\"store_id\", \"item_id\", \"year\"]\n", + ")[\"sell_price\"].transform(\"mean\")" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "ad812423-bde6-4d1a-9677-222f81af6e0e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Remove \"month\" and \"year\" columns, as we don't need them any more\n", + "del prices_df[\"month\"], prices_df[\"year\"]\n", + "\n", + "# Convert float64 columns into float32 type to save memory\n", + "columns = [\n", + " \"sell_price\",\n", + " \"price_max\",\n", + " \"price_min\",\n", + " \"price_std\",\n", + " 
\"price_mean\",\n", + " \"price_norm\",\n", + " \"price_momentum\",\n", + " \"price_momentum_m\",\n", + " \"price_momentum_y\",\n", + "]\n", + "for col in columns:\n", + " prices_df[col] = prices_df[col].astype(np.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "9ad80497-bfd8-4421-ad09-dfde9d12e943", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "store_id object\n", + "item_id object\n", + "wm_yr_wk int64\n", + "sell_price float32\n", + "price_max float32\n", + "price_min float32\n", + "price_std float32\n", + "price_mean float32\n", + "price_norm float32\n", + "price_nunique int32\n", + "item_nunique int32\n", + "price_momentum float32\n", + "price_momentum_m float32\n", + "price_momentum_y float32\n", + "dtype: object" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prices_df.dtypes" + ] + }, + { + "cell_type": "markdown", + "id": "04ddd07f-f9dc-4b01-ada3-d0ea9710024d", + "metadata": {}, + "source": [ + "### Bring in price-related features into `grid_df`" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "9a5a5f3e-585d-462e-9c96-c7432461a458", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idsell_priceprice_maxprice_minprice_stdprice_meanprice_normprice_nuniqueitem_nuniqueprice_momentumprice_momentum_mprice_momentum_y
0FOODS_1_002_TX_3_evaluationd_3628.889.487.884.849956e-018.9353190.9367093101.00.9761041.0
1FOODS_1_002_TX_3_evaluationd_3638.889.487.884.849956e-018.9353190.9367093101.00.9761041.0
2FOODS_1_002_TX_3_evaluationd_3648.889.487.884.849956e-018.9353190.9367093101.00.9761041.0
3FOODS_1_002_TX_3_evaluationd_3658.889.487.884.849956e-018.9353190.9367093101.00.9761041.0
4FOODS_1_002_TX_3_evaluationd_3668.889.487.884.849956e-018.9353190.9367093101.00.9761041.0
..........................................
47735392HOUSEHOLD_2_516_WI_2_evaluationd_6955.945.945.943.648122e-145.9400001.0000001471.01.0000001.0
47735393HOUSEHOLD_2_516_WI_2_evaluationd_6965.945.945.943.648122e-145.9400001.0000001471.01.0000001.0
47735394HOUSEHOLD_2_516_WI_2_evaluationd_6905.945.945.943.648122e-145.9400001.0000001471.01.0000001.0
47735395HOUSEHOLD_2_516_WI_2_evaluationd_6915.945.945.943.648122e-145.9400001.0000001471.01.0000001.0
47735396HOUSEHOLD_2_516_WI_2_evaluationd_6945.945.945.943.648122e-145.9400001.0000001471.01.0000001.0
\n", + "

47735397 rows × 13 columns

\n", + "
" + ], + "text/plain": [ + " id day_id sell_price price_max \\\n", + "0 FOODS_1_002_TX_3_evaluation d_362 8.88 9.48 \n", + "1 FOODS_1_002_TX_3_evaluation d_363 8.88 9.48 \n", + "2 FOODS_1_002_TX_3_evaluation d_364 8.88 9.48 \n", + "3 FOODS_1_002_TX_3_evaluation d_365 8.88 9.48 \n", + "4 FOODS_1_002_TX_3_evaluation d_366 8.88 9.48 \n", + "... ... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_2_evaluation d_695 5.94 5.94 \n", + "47735393 HOUSEHOLD_2_516_WI_2_evaluation d_696 5.94 5.94 \n", + "47735394 HOUSEHOLD_2_516_WI_2_evaluation d_690 5.94 5.94 \n", + "47735395 HOUSEHOLD_2_516_WI_2_evaluation d_691 5.94 5.94 \n", + "47735396 HOUSEHOLD_2_516_WI_2_evaluation d_694 5.94 5.94 \n", + "\n", + " price_min price_std price_mean price_norm price_nunique \\\n", + "0 7.88 4.849956e-01 8.935319 0.936709 3 \n", + "1 7.88 4.849956e-01 8.935319 0.936709 3 \n", + "2 7.88 4.849956e-01 8.935319 0.936709 3 \n", + "3 7.88 4.849956e-01 8.935319 0.936709 3 \n", + "4 7.88 4.849956e-01 8.935319 0.936709 3 \n", + "... ... ... ... ... ... \n", + "47735392 5.94 3.648122e-14 5.940000 1.000000 1 \n", + "47735393 5.94 3.648122e-14 5.940000 1.000000 1 \n", + "47735394 5.94 3.648122e-14 5.940000 1.000000 1 \n", + "47735395 5.94 3.648122e-14 5.940000 1.000000 1 \n", + "47735396 5.94 3.648122e-14 5.940000 1.000000 1 \n", + "\n", + " item_nunique price_momentum price_momentum_m price_momentum_y \n", + "0 10 1.0 0.976104 1.0 \n", + "1 10 1.0 0.976104 1.0 \n", + "2 10 1.0 0.976104 1.0 \n", + "3 10 1.0 0.976104 1.0 \n", + "4 10 1.0 0.976104 1.0 \n", + "... ... ... ... ... 
\n", + "47735392 47 1.0 1.000000 1.0 \n", + "47735393 47 1.0 1.000000 1.0 \n", + "47735394 47 1.0 1.000000 1.0 \n", + "47735395 47 1.0 1.000000 1.0 \n", + "47735396 47 1.0 1.000000 1.0 \n", + "\n", + "[47735397 rows x 13 columns]" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# After merging price_df, keep columns id and day_id from grid_df and drop all other columns from grid_df\n", + "original_columns = list(grid_df)\n", + "grid_df_with_price = grid_df.copy()\n", + "grid_df_with_price = grid_df_with_price.merge(\n", + " prices_df, on=[\"store_id\", \"item_id\", \"wm_yr_wk\"], how=\"left\"\n", + ")\n", + "columns_to_keep = [\"id\", \"day_id\"] + [\n", + " col for col in list(grid_df_with_price) if col not in original_columns\n", + "]\n", + "grid_df_with_price = grid_df_with_price[[\"id\", \"day_id\"] + columns_to_keep]\n", + "grid_df_with_price" + ] + }, + { + "cell_type": "markdown", + "id": "60d7cd4d-9cf8-4097-90a2-a46f698ce4b7", + "metadata": {}, + "source": [ + "### Generate date-related features\n", + "We identify the date in each row of `grid_df` using information from `calendar_df`." + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "6bb4475a-787c-4ec4-ad2c-1c72ec369b18", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_iddateevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WI
0FOODS_1_002_WI_3_evaluationd_16782015-09-02<NA><NA><NA><NA>101
1FOODS_1_002_WI_3_evaluationd_16792015-09-03<NA><NA><NA><NA>111
2FOODS_1_002_WI_3_evaluationd_16802015-09-04<NA><NA><NA><NA>100
3FOODS_1_002_WI_3_evaluationd_16812015-09-05<NA><NA><NA><NA>111
4FOODS_1_002_WI_3_evaluationd_16822015-09-06<NA><NA><NA><NA>111
.................................
47735392HOUSEHOLD_2_516_WI_2_evaluationd_9772013-10-01<NA><NA><NA><NA>110
47735393HOUSEHOLD_2_516_WI_2_evaluationd_9782013-10-02<NA><NA><NA><NA>101
47735394HOUSEHOLD_2_516_WI_2_evaluationd_9792013-10-03<NA><NA><NA><NA>111
47735395HOUSEHOLD_2_516_WI_3_evaluationd_9032013-07-19<NA><NA><NA><NA>000
47735396HOUSEHOLD_2_516_WI_3_evaluationd_8882013-07-04IndependenceDayNational<NA><NA>100
\n", + "

47735397 rows × 10 columns

\n", + "
" + ], + "text/plain": [ + " id day_id date \\\n", + "0 FOODS_1_002_WI_3_evaluation d_1678 2015-09-02 \n", + "1 FOODS_1_002_WI_3_evaluation d_1679 2015-09-03 \n", + "2 FOODS_1_002_WI_3_evaluation d_1680 2015-09-04 \n", + "3 FOODS_1_002_WI_3_evaluation d_1681 2015-09-05 \n", + "4 FOODS_1_002_WI_3_evaluation d_1682 2015-09-06 \n", + "... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_2_evaluation d_977 2013-10-01 \n", + "47735393 HOUSEHOLD_2_516_WI_2_evaluation d_978 2013-10-02 \n", + "47735394 HOUSEHOLD_2_516_WI_2_evaluation d_979 2013-10-03 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_903 2013-07-19 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_888 2013-07-04 \n", + "\n", + " event_name_1 event_type_1 event_name_2 event_type_2 snap_CA \\\n", + "0 1 \n", + "1 1 \n", + "2 1 \n", + "3 1 \n", + "4 1 \n", + "... ... ... ... ... ... \n", + "47735392 1 \n", + "47735393 1 \n", + "47735394 1 \n", + "47735395 0 \n", + "47735396 IndependenceDay National 1 \n", + "\n", + " snap_TX snap_WI \n", + "0 0 1 \n", + "1 1 1 \n", + "2 0 0 \n", + "3 1 1 \n", + "4 1 1 \n", + "... ... ... 
\n", + "47735392 1 0 \n", + "47735393 0 1 \n", + "47735394 1 1 \n", + "47735395 0 0 \n", + "47735396 0 0 \n", + "\n", + "[47735397 rows x 10 columns]" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Bring in the following columns from calendar_df into grid_df\n", + "grid_df_id_only = grid_df[[\"id\", \"day_id\"]].copy()\n", + "\n", + "icols = [\n", + " \"date\",\n", + " \"day_id\",\n", + " \"event_name_1\",\n", + " \"event_type_1\",\n", + " \"event_name_2\",\n", + " \"event_type_2\",\n", + " \"snap_CA\",\n", + " \"snap_TX\",\n", + " \"snap_WI\",\n", + "]\n", + "grid_df_with_calendar = grid_df_id_only.merge(\n", + " calendar_df[icols], on=[\"day_id\"], how=\"left\"\n", + ")\n", + "grid_df_with_calendar" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "4dfcf0ed-d461-48d4-acb9-6aa1d0fa973e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Convert columns into categorical type to save memory\n", + "for col in [\n", + " \"event_name_1\",\n", + " \"event_type_1\",\n", + " \"event_name_2\",\n", + " \"event_type_2\",\n", + " \"snap_CA\",\n", + " \"snap_TX\",\n", + " \"snap_WI\",\n", + "]:\n", + " grid_df_with_calendar[col] = grid_df_with_calendar[col].astype(\"category\")\n", + "# Convert \"date\" column into timestamp type\n", + "grid_df_with_calendar[\"date\"] = cudf.to_datetime(grid_df_with_calendar[\"date\"])" + ] + }, + { + "cell_type": "markdown", + "id": "9ee43075-54f8-4c62-a173-1e016e9d9697", + "metadata": {}, + "source": [ + "Using the `date` column, we can generate related features, such as day, week, or month." + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "7a530f96-67b8-49f3-9f45-a1533eeef2c5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idevent_name_1event_type_1event_name_2event_type_2snap_CAsnap_TXsnap_WItm_dtm_wtm_mtm_ytm_wmtm_dwtm_w_end
0FOODS_1_002_WI_3_evaluationd_1678<NA><NA><NA><NA>10123694120
1FOODS_1_002_WI_3_evaluationd_1679<NA><NA><NA><NA>11133694130
2FOODS_1_002_WI_3_evaluationd_1680<NA><NA><NA><NA>10043694140
3FOODS_1_002_WI_3_evaluationd_1681<NA><NA><NA><NA>11153694151
4FOODS_1_002_WI_3_evaluationd_1682<NA><NA><NA><NA>11163694161
...................................................
47735392HOUSEHOLD_2_516_WI_2_evaluationd_977<NA><NA><NA><NA>110140102110
47735393HOUSEHOLD_2_516_WI_2_evaluationd_978<NA><NA><NA><NA>101240102120
47735394HOUSEHOLD_2_516_WI_2_evaluationd_979<NA><NA><NA><NA>111340102130
47735395HOUSEHOLD_2_516_WI_3_evaluationd_903<NA><NA><NA><NA>000192972340
47735396HOUSEHOLD_2_516_WI_3_evaluationd_888IndependenceDayNational<NA><NA>10042772130
\n", + "

47735397 rows × 16 columns

\n", + "
" + ], + "text/plain": [ + " id day_id event_name_1 \\\n", + "0 FOODS_1_002_WI_3_evaluation d_1678 \n", + "1 FOODS_1_002_WI_3_evaluation d_1679 \n", + "2 FOODS_1_002_WI_3_evaluation d_1680 \n", + "3 FOODS_1_002_WI_3_evaluation d_1681 \n", + "4 FOODS_1_002_WI_3_evaluation d_1682 \n", + "... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_2_evaluation d_977 \n", + "47735393 HOUSEHOLD_2_516_WI_2_evaluation d_978 \n", + "47735394 HOUSEHOLD_2_516_WI_2_evaluation d_979 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_903 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_888 IndependenceDay \n", + "\n", + " event_type_1 event_name_2 event_type_2 snap_CA snap_TX snap_WI tm_d \\\n", + "0 1 0 1 2 \n", + "1 1 1 1 3 \n", + "2 1 0 0 4 \n", + "3 1 1 1 5 \n", + "4 1 1 1 6 \n", + "... ... ... ... ... ... ... ... \n", + "47735392 1 1 0 1 \n", + "47735393 1 0 1 2 \n", + "47735394 1 1 1 3 \n", + "47735395 0 0 0 19 \n", + "47735396 National 1 0 0 4 \n", + "\n", + " tm_w tm_m tm_y tm_wm tm_dw tm_w_end \n", + "0 36 9 4 1 2 0 \n", + "1 36 9 4 1 3 0 \n", + "2 36 9 4 1 4 0 \n", + "3 36 9 4 1 5 1 \n", + "4 36 9 4 1 6 1 \n", + "... ... ... ... ... ... ... 
\n", + "47735392 40 10 2 1 1 0 \n", + "47735393 40 10 2 1 2 0 \n", + "47735394 40 10 2 1 3 0 \n", + "47735395 29 7 2 3 4 0 \n", + "47735396 27 7 2 1 3 0 \n", + "\n", + "[47735397 rows x 16 columns]" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import cupy as cp\n", + "\n", + "grid_df_with_calendar[\"tm_d\"] = grid_df_with_calendar[\"date\"].dt.day.astype(np.int8)\n", + "grid_df_with_calendar[\"tm_w\"] = (\n", + " grid_df_with_calendar[\"date\"].dt.isocalendar().week.astype(np.int8)\n", + ")\n", + "grid_df_with_calendar[\"tm_m\"] = grid_df_with_calendar[\"date\"].dt.month.astype(np.int8)\n", + "grid_df_with_calendar[\"tm_y\"] = grid_df_with_calendar[\"date\"].dt.year\n", + "grid_df_with_calendar[\"tm_y\"] = (\n", + " grid_df_with_calendar[\"tm_y\"] - grid_df_with_calendar[\"tm_y\"].min()\n", + ").astype(np.int8)\n", + "grid_df_with_calendar[\"tm_wm\"] = cp.ceil(\n", + " grid_df_with_calendar[\"tm_d\"].to_cupy() / 7\n", + ").astype(\n", + " np.int8\n", + ") # which week in the month?\n", + "grid_df_with_calendar[\"tm_dw\"] = grid_df_with_calendar[\"date\"].dt.dayofweek.astype(\n", + " np.int8\n", + ") # which day in the week?\n", + "grid_df_with_calendar[\"tm_w_end\"] = (grid_df_with_calendar[\"tm_dw\"] >= 5).astype(\n", + " np.int8\n", + ") # whether today is in the weekend\n", + "del grid_df_with_calendar[\"date\"] # no longer needed\n", + "\n", + "grid_df_with_calendar" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "f1f05d6a-7d92-4613-92fb-57c55a01f3db", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "96" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "del grid_df_id_only # No longer needed\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "0cd47558-1960-4a06-b941-884c6b0464fa", + "metadata": {}, + "source": [ + "### Generate lag features\n", + "\n", + 
"**Lag features** are the value of the target variable at prior timestamps. Lag features are useful because what happens in the past often influences what would happen in the future. In our example, we generate lag features by reading the sales amount at X days prior, where X = 28, 29, ..., 42." + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "18792591-f6e1-49d9-b3b9-721fae1675aa", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "SHIFT_DAY = 28\n", + "LAG_DAYS = [col for col in range(SHIFT_DAY, SHIFT_DAY + 15)]\n", + "\n", + "# Need to first ensure that rows in each time series are sorted by day_id\n", + "grid_df_lags = grid_df[[\"id\", \"day_id\", \"sales\"]].copy()\n", + "grid_df_lags = grid_df_lags.sort_values([\"id\", \"day_id\"])\n", + "\n", + "grid_df_lags = grid_df_lags.assign(\n", + " **{\n", + " f\"sales_lag_{l}\": grid_df_lags.groupby([\"id\"])[\"sales\"].shift(l)\n", + " for l in LAG_DAYS\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "ed47d972-f6c9-4618-986a-08649dfbfcb8", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idsalessales_lag_28sales_lag_29sales_lag_30sales_lag_31sales_lag_32sales_lag_33sales_lag_34sales_lag_35sales_lag_36sales_lag_37sales_lag_38sales_lag_39sales_lag_40sales_lag_41sales_lag_42
34023FOODS_1_001_CA_1_evaluationd_13.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34024FOODS_1_001_CA_1_evaluationd_20.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34025FOODS_1_001_CA_1_evaluationd_30.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34026FOODS_1_001_CA_1_evaluationd_41.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34027FOODS_1_001_CA_1_evaluationd_54.0<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
.........................................................
47733744HOUSEHOLD_2_516_WI_3_evaluationd_1965NaN0.00.00.00.01.00.00.00.00.00.00.00.00.00.00.0
47733745HOUSEHOLD_2_516_WI_3_evaluationd_1966NaN0.00.00.00.00.01.00.00.00.00.00.00.00.00.00.0
47733746HOUSEHOLD_2_516_WI_3_evaluationd_1967NaN0.00.00.00.00.00.01.00.00.00.00.00.00.00.00.0
47733747HOUSEHOLD_2_516_WI_3_evaluationd_1968NaN0.00.00.00.00.00.00.01.00.00.00.00.00.00.00.0
47733748HOUSEHOLD_2_516_WI_3_evaluationd_1969NaN0.00.00.00.00.00.00.00.01.00.00.00.00.00.00.0
\n", + "

47735397 rows × 18 columns

\n", + "
" + ], + "text/plain": [ + " id day_id sales sales_lag_28 \\\n", + "34023 FOODS_1_001_CA_1_evaluation d_1 3.0 \n", + "34024 FOODS_1_001_CA_1_evaluation d_2 0.0 \n", + "34025 FOODS_1_001_CA_1_evaluation d_3 0.0 \n", + "34026 FOODS_1_001_CA_1_evaluation d_4 1.0 \n", + "34027 FOODS_1_001_CA_1_evaluation d_5 4.0 \n", + "... ... ... ... ... \n", + "47733744 HOUSEHOLD_2_516_WI_3_evaluation d_1965 NaN 0.0 \n", + "47733745 HOUSEHOLD_2_516_WI_3_evaluation d_1966 NaN 0.0 \n", + "47733746 HOUSEHOLD_2_516_WI_3_evaluation d_1967 NaN 0.0 \n", + "47733747 HOUSEHOLD_2_516_WI_3_evaluation d_1968 NaN 0.0 \n", + "47733748 HOUSEHOLD_2_516_WI_3_evaluation d_1969 NaN 0.0 \n", + "\n", + " sales_lag_29 sales_lag_30 sales_lag_31 sales_lag_32 sales_lag_33 \\\n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... ... ... \n", + "47733744 0.0 0.0 0.0 1.0 0.0 \n", + "47733745 0.0 0.0 0.0 0.0 1.0 \n", + "47733746 0.0 0.0 0.0 0.0 0.0 \n", + "47733747 0.0 0.0 0.0 0.0 0.0 \n", + "47733748 0.0 0.0 0.0 0.0 0.0 \n", + "\n", + " sales_lag_34 sales_lag_35 sales_lag_36 sales_lag_37 sales_lag_38 \\\n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... ... ... \n", + "47733744 0.0 0.0 0.0 0.0 0.0 \n", + "47733745 0.0 0.0 0.0 0.0 0.0 \n", + "47733746 1.0 0.0 0.0 0.0 0.0 \n", + "47733747 0.0 1.0 0.0 0.0 0.0 \n", + "47733748 0.0 0.0 1.0 0.0 0.0 \n", + "\n", + " sales_lag_39 sales_lag_40 sales_lag_41 sales_lag_42 \n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... ... 
\n", + "47733744 0.0 0.0 0.0 0.0 \n", + "47733745 0.0 0.0 0.0 0.0 \n", + "47733746 0.0 0.0 0.0 0.0 \n", + "47733747 0.0 0.0 0.0 0.0 \n", + "47733748 0.0 0.0 0.0 0.0 \n", + "\n", + "[47735397 rows x 18 columns]" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_lags" + ] + }, + { + "cell_type": "markdown", + "id": "9e85f484-a5e9-4acd-a60b-f7cc407fe7a4", + "metadata": {}, + "source": [ + "### Compute rolling window statistics\n", + "\n", + "In the previous cell, we used the value of sales at a single timestamp to generate lag features. To capture richer information about the past, let us also get the distribution of the sales value over multiple timestamps, by computing **rolling window statistics**. Rolling window statistics are statistics (e.g. mean, standard deviation) over a time duration in the past. Rolling windows statistics complement lag features and provide more information about the past behavior of the target variable.\n", + "\n", + "Read more about lag features and rolling window statistics in [Introduction to feature engineering for time series forecasting](https://medium.com/data-science-at-microsoft/introduction-to-feature-engineering-for-time-series-forecasting-620aa55fcab0)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "485b447e-213f-4091-a0df-874c4d35ac64", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shift size: 28\n", + " Window size: 7\n", + " Window size: 14\n", + " Window size: 30\n", + " Window size: 60\n", + " Window size: 180\n" + ] + } + ], + "source": [ + "# Shift by 28 days and apply windows of various sizes\n", + "print(f\"Shift size: {SHIFT_DAY}\")\n", + "for i in [7, 14, 30, 60, 180]:\n", + " print(f\" Window size: {i}\")\n", + " grid_df_lags[f\"rolling_mean_{i}\"] = (\n", + " grid_df_lags.groupby([\"id\"])[\"sales\"]\n", + " .shift(SHIFT_DAY)\n", + " .rolling(i)\n", + " .mean()\n", + " .astype(np.float32)\n", + " )\n", + " grid_df_lags[f\"rolling_std_{i}\"] = (\n", + " grid_df_lags.groupby([\"id\"])[\"sales\"]\n", + " .shift(SHIFT_DAY)\n", + " .rolling(i)\n", + " .std()\n", + " .astype(np.float32)\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "64d0e7bc-8c79-43be-9809-192807d57ada", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Index(['id', 'day_id', 'sales', 'sales_lag_28', 'sales_lag_29', 'sales_lag_30',\n", + " 'sales_lag_31', 'sales_lag_32', 'sales_lag_33', 'sales_lag_34',\n", + " 'sales_lag_35', 'sales_lag_36', 'sales_lag_37', 'sales_lag_38',\n", + " 'sales_lag_39', 'sales_lag_40', 'sales_lag_41', 'sales_lag_42',\n", + " 'rolling_mean_7', 'rolling_std_7', 'rolling_mean_14', 'rolling_std_14',\n", + " 'rolling_mean_30', 'rolling_std_30', 'rolling_mean_60',\n", + " 'rolling_std_60', 'rolling_mean_180', 'rolling_std_180'],\n", + " dtype='object')" + ] + }, + "execution_count": 53, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_lags.columns" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "977f501a-f358-423c-9248-210405a0e51b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + 
"text/plain": [ + "id category\n", + "day_id category\n", + "sales float32\n", + "sales_lag_28 float32\n", + "sales_lag_29 float32\n", + "sales_lag_30 float32\n", + "sales_lag_31 float32\n", + "sales_lag_32 float32\n", + "sales_lag_33 float32\n", + "sales_lag_34 float32\n", + "sales_lag_35 float32\n", + "sales_lag_36 float32\n", + "sales_lag_37 float32\n", + "sales_lag_38 float32\n", + "sales_lag_39 float32\n", + "sales_lag_40 float32\n", + "sales_lag_41 float32\n", + "sales_lag_42 float32\n", + "rolling_mean_7 float32\n", + "rolling_std_7 float32\n", + "rolling_mean_14 float32\n", + "rolling_std_14 float32\n", + "rolling_mean_30 float32\n", + "rolling_std_30 float32\n", + "rolling_mean_60 float32\n", + "rolling_std_60 float32\n", + "rolling_mean_180 float32\n", + "rolling_std_180 float32\n", + "dtype: object" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_lags.dtypes" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "22d8519f-73a5-4dde-be3a-1896986057e6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idsalessales_lag_28sales_lag_29sales_lag_30sales_lag_31sales_lag_32sales_lag_33sales_lag_34...rolling_mean_7rolling_std_7rolling_mean_14rolling_std_14rolling_mean_30rolling_std_30rolling_mean_60rolling_std_60rolling_mean_180rolling_std_180
34023FOODS_1_001_CA_1_evaluationd_13.0<NA><NA><NA><NA><NA><NA><NA>...<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34024FOODS_1_001_CA_1_evaluationd_20.0<NA><NA><NA><NA><NA><NA><NA>...<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34025FOODS_1_001_CA_1_evaluationd_30.0<NA><NA><NA><NA><NA><NA><NA>...<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34026FOODS_1_001_CA_1_evaluationd_41.0<NA><NA><NA><NA><NA><NA><NA>...<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
34027FOODS_1_001_CA_1_evaluationd_54.0<NA><NA><NA><NA><NA><NA><NA>...<NA><NA><NA><NA><NA><NA><NA><NA><NA><NA>
..................................................................
47733744HOUSEHOLD_2_516_WI_3_evaluationd_1965NaN0.00.00.00.01.00.00.0...0.1428571490.3779644670.0714285750.2672612370.066666670.2537081240.0333333350.1810203340.0777777810.288621366
47733745HOUSEHOLD_2_516_WI_3_evaluationd_1966NaN0.00.00.00.00.01.00.0...0.1428571490.3779644670.0714285750.2672612370.066666670.2537081240.0333333350.1810203340.0777777810.288621366
47733746HOUSEHOLD_2_516_WI_3_evaluationd_1967NaN0.00.00.00.00.00.01.0...0.1428571490.3779644670.0714285750.2672612370.066666670.2537081240.0333333350.1810203340.0777777810.288621366
47733747HOUSEHOLD_2_516_WI_3_evaluationd_1968NaN0.00.00.00.00.00.00.0...0.00.00.0714285750.2672612370.066666670.2537081240.0333333350.1810203340.0777777810.288621366
47733748HOUSEHOLD_2_516_WI_3_evaluationd_1969NaN0.00.00.00.00.00.00.0...0.00.00.0714285750.2672612370.066666670.2537081240.0333333350.1810203340.0777777810.288621366
\n", + "

47735397 rows × 28 columns

\n", + "
" + ], + "text/plain": [ + " id day_id sales sales_lag_28 \\\n", + "34023 FOODS_1_001_CA_1_evaluation d_1 3.0 \n", + "34024 FOODS_1_001_CA_1_evaluation d_2 0.0 \n", + "34025 FOODS_1_001_CA_1_evaluation d_3 0.0 \n", + "34026 FOODS_1_001_CA_1_evaluation d_4 1.0 \n", + "34027 FOODS_1_001_CA_1_evaluation d_5 4.0 \n", + "... ... ... ... ... \n", + "47733744 HOUSEHOLD_2_516_WI_3_evaluation d_1965 NaN 0.0 \n", + "47733745 HOUSEHOLD_2_516_WI_3_evaluation d_1966 NaN 0.0 \n", + "47733746 HOUSEHOLD_2_516_WI_3_evaluation d_1967 NaN 0.0 \n", + "47733747 HOUSEHOLD_2_516_WI_3_evaluation d_1968 NaN 0.0 \n", + "47733748 HOUSEHOLD_2_516_WI_3_evaluation d_1969 NaN 0.0 \n", + "\n", + " sales_lag_29 sales_lag_30 sales_lag_31 sales_lag_32 sales_lag_33 \\\n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... ... ... \n", + "47733744 0.0 0.0 0.0 1.0 0.0 \n", + "47733745 0.0 0.0 0.0 0.0 1.0 \n", + "47733746 0.0 0.0 0.0 0.0 0.0 \n", + "47733747 0.0 0.0 0.0 0.0 0.0 \n", + "47733748 0.0 0.0 0.0 0.0 0.0 \n", + "\n", + " sales_lag_34 ... rolling_mean_7 rolling_std_7 rolling_mean_14 \\\n", + "34023 ... \n", + "34024 ... \n", + "34025 ... \n", + "34026 ... \n", + "34027 ... \n", + "... ... ... ... ... ... \n", + "47733744 0.0 ... 0.142857149 0.377964467 0.071428575 \n", + "47733745 0.0 ... 0.142857149 0.377964467 0.071428575 \n", + "47733746 1.0 ... 0.142857149 0.377964467 0.071428575 \n", + "47733747 0.0 ... 0.0 0.0 0.071428575 \n", + "47733748 0.0 ... 0.0 0.0 0.071428575 \n", + "\n", + " rolling_std_14 rolling_mean_30 rolling_std_30 rolling_mean_60 \\\n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... ... 
\n", + "47733744 0.267261237 0.06666667 0.253708124 0.033333335 \n", + "47733745 0.267261237 0.06666667 0.253708124 0.033333335 \n", + "47733746 0.267261237 0.06666667 0.253708124 0.033333335 \n", + "47733747 0.267261237 0.06666667 0.253708124 0.033333335 \n", + "47733748 0.267261237 0.06666667 0.253708124 0.033333335 \n", + "\n", + " rolling_std_60 rolling_mean_180 rolling_std_180 \n", + "34023 \n", + "34024 \n", + "34025 \n", + "34026 \n", + "34027 \n", + "... ... ... ... \n", + "47733744 0.181020334 0.077777781 0.288621366 \n", + "47733745 0.181020334 0.077777781 0.288621366 \n", + "47733746 0.181020334 0.077777781 0.288621366 \n", + "47733747 0.181020334 0.077777781 0.288621366 \n", + "47733748 0.181020334 0.077777781 0.288621366 \n", + "\n", + "[47735397 rows x 28 columns]" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_lags" + ] + }, + { + "cell_type": "markdown", + "id": "898c2f27-50b0-4160-abd0-d6554c767c7f", + "metadata": {}, + "source": [ + "### Target encoding\n", + "Categorical variables present challenges to many machine learning algorithms such as XGBoost. One way to overcome the challenge is to use **target encoding**, where we encode categorical variables by replacing them with a statistic for the target variable. In this example, we will use the mean and the standard deviation.\n", + "\n", + "Read more about target encoding in [Target-encoding Categorical Variables](https://towardsdatascience.com/dealing-with-categorical-variables-by-using-target-encoder-a0f1733a4c69)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "62cfbd3b-ede8-4bd1-967b-5a8785df49ea", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Encoding columns ['store_id', 'dept_id']\n", + "Encoding columns ['item_id', 'state_id']\n" + ] + } + ], + "source": [ + "icols = [[\"store_id\", \"dept_id\"], [\"item_id\", \"state_id\"]]\n", + "new_columns = []\n", + "\n", + "grid_df_target_enc = grid_df[\n", + " [\"id\", \"day_id\", \"item_id\", \"state_id\", \"store_id\", \"dept_id\", \"sales\"]\n", + "].copy()\n", + "grid_df_target_enc[\"sales\"].fillna(value=0, inplace=True)\n", + "\n", + "for col in icols:\n", + " print(f\"Encoding columns {col}\")\n", + " col_name = \"_\" + \"_\".join(col) + \"_\"\n", + " grid_df_target_enc[\"enc\" + col_name + \"mean\"] = (\n", + " grid_df_target_enc.groupby(col)[\"sales\"].transform(\"mean\").astype(np.float32)\n", + " )\n", + " grid_df_target_enc[\"enc\" + col_name + \"std\"] = (\n", + " grid_df_target_enc.groupby(col)[\"sales\"].transform(\"std\").astype(np.float32)\n", + " )\n", + " new_columns.extend([\"enc\" + col_name + \"mean\", \"enc\" + col_name + \"std\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "8c83b674-7ec1-4203-acc3-7f12c0419001", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idday_idenc_store_id_dept_id_meanenc_store_id_dept_id_stdenc_item_id_state_id_meanenc_item_id_state_id_std
0FOODS_1_001_WI_2_evaluationd_8091.4929883.9876570.4335530.851153
1FOODS_1_001_WI_2_evaluationd_8101.4929883.9876570.4335530.851153
2FOODS_1_001_WI_2_evaluationd_8111.4929883.9876570.4335530.851153
3FOODS_1_001_WI_2_evaluationd_8121.4929883.9876570.4335530.851153
4FOODS_1_001_WI_2_evaluationd_8131.4929883.9876570.4335530.851153
.....................
47735392HOUSEHOLD_2_516_WI_3_evaluationd_520.2570270.6615410.0820840.299445
47735393HOUSEHOLD_2_516_WI_3_evaluationd_530.2570270.6615410.0820840.299445
47735394HOUSEHOLD_2_516_WI_3_evaluationd_540.2570270.6615410.0820840.299445
47735395HOUSEHOLD_2_516_WI_3_evaluationd_550.2570270.6615410.0820840.299445
47735396HOUSEHOLD_2_516_WI_3_evaluationd_490.2570270.6615410.0820840.299445
\n", + "

47735397 rows × 6 columns

\n", + "
" + ], + "text/plain": [ + " id day_id enc_store_id_dept_id_mean \\\n", + "0 FOODS_1_001_WI_2_evaluation d_809 1.492988 \n", + "1 FOODS_1_001_WI_2_evaluation d_810 1.492988 \n", + "2 FOODS_1_001_WI_2_evaluation d_811 1.492988 \n", + "3 FOODS_1_001_WI_2_evaluation d_812 1.492988 \n", + "4 FOODS_1_001_WI_2_evaluation d_813 1.492988 \n", + "... ... ... ... \n", + "47735392 HOUSEHOLD_2_516_WI_3_evaluation d_52 0.257027 \n", + "47735393 HOUSEHOLD_2_516_WI_3_evaluation d_53 0.257027 \n", + "47735394 HOUSEHOLD_2_516_WI_3_evaluation d_54 0.257027 \n", + "47735395 HOUSEHOLD_2_516_WI_3_evaluation d_55 0.257027 \n", + "47735396 HOUSEHOLD_2_516_WI_3_evaluation d_49 0.257027 \n", + "\n", + " enc_store_id_dept_id_std enc_item_id_state_id_mean \\\n", + "0 3.987657 0.433553 \n", + "1 3.987657 0.433553 \n", + "2 3.987657 0.433553 \n", + "3 3.987657 0.433553 \n", + "4 3.987657 0.433553 \n", + "... ... ... \n", + "47735392 0.661541 0.082084 \n", + "47735393 0.661541 0.082084 \n", + "47735394 0.661541 0.082084 \n", + "47735395 0.661541 0.082084 \n", + "47735396 0.661541 0.082084 \n", + "\n", + " enc_item_id_state_id_std \n", + "0 0.851153 \n", + "1 0.851153 \n", + "2 0.851153 \n", + "3 0.851153 \n", + "4 0.851153 \n", + "... ... 
\n", + "47735392 0.299445 \n", + "47735393 0.299445 \n", + "47735394 0.299445 \n", + "47735395 0.299445 \n", + "47735396 0.299445 \n", + "\n", + "[47735397 rows x 6 columns]" + ] + }, + "execution_count": 57, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_target_enc = grid_df_target_enc[[\"id\", \"day_id\"] + new_columns]\n", + "grid_df_target_enc" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "02c6b045-331e-4c4c-baa7-07a44e44ab39", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "id category\n", + "day_id category\n", + "enc_store_id_dept_id_mean float32\n", + "enc_store_id_dept_id_std float32\n", + "enc_item_id_state_id_mean float32\n", + "enc_item_id_state_id_std float32\n", + "dtype: object" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grid_df_target_enc.dtypes" + ] + }, + { + "cell_type": "markdown", + "id": "e786fff5-0567-476d-be1a-aa440132fc3c", + "metadata": {}, + "source": [ + "### Filter by store and product department and create data segments\n", + "After combining all columns produced in the previous notebooks, we filter the rows in the data set by `store_id` and `dept_id` and create a segment. Each segment is saved as a pickle file and then upload to Cloud Storage." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "4ec8a3a1-f5a6-4ac9-aaef-2debcbe02919", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "segmented_data_dir = pathlib.Path(\"./segmented_data/\")\n", + "segmented_data_dir.mkdir(exist_ok=True)\n", + "\n", + "STORES = [\n", + " \"CA_1\",\n", + " \"CA_2\",\n", + " \"CA_3\",\n", + " \"CA_4\",\n", + " \"TX_1\",\n", + " \"TX_2\",\n", + " \"TX_3\",\n", + " \"WI_1\",\n", + " \"WI_2\",\n", + " \"WI_3\",\n", + "]\n", + "DEPTS = [\n", + " \"HOBBIES_1\",\n", + " \"HOBBIES_2\",\n", + " \"HOUSEHOLD_1\",\n", + " \"HOUSEHOLD_2\",\n", + " \"FOODS_1\",\n", + " \"FOODS_2\",\n", + " \"FOODS_3\",\n", + "]\n", + "\n", + "grid2_colnm = [\n", + " \"sell_price\",\n", + " \"price_max\",\n", + " \"price_min\",\n", + " \"price_std\",\n", + " \"price_mean\",\n", + " \"price_norm\",\n", + " \"price_nunique\",\n", + " \"item_nunique\",\n", + " \"price_momentum\",\n", + " \"price_momentum_m\",\n", + " \"price_momentum_y\",\n", + "]\n", + "\n", + "grid3_colnm = [\n", + " \"event_name_1\",\n", + " \"event_type_1\",\n", + " \"event_name_2\",\n", + " \"event_type_2\",\n", + " \"snap_CA\",\n", + " \"snap_TX\",\n", + " \"snap_WI\",\n", + " \"tm_d\",\n", + " \"tm_w\",\n", + " \"tm_m\",\n", + " \"tm_y\",\n", + " \"tm_wm\",\n", + " \"tm_dw\",\n", + " \"tm_w_end\",\n", + "]\n", + "\n", + "lag_colnm = [\n", + " \"sales_lag_28\",\n", + " \"sales_lag_29\",\n", + " \"sales_lag_30\",\n", + " \"sales_lag_31\",\n", + " \"sales_lag_32\",\n", + " \"sales_lag_33\",\n", + " \"sales_lag_34\",\n", + " \"sales_lag_35\",\n", + " \"sales_lag_36\",\n", + " \"sales_lag_37\",\n", + " \"sales_lag_38\",\n", + " \"sales_lag_39\",\n", + " \"sales_lag_40\",\n", + " \"sales_lag_41\",\n", + " \"sales_lag_42\",\n", + " \"rolling_mean_7\",\n", + " \"rolling_std_7\",\n", + " \"rolling_mean_14\",\n", + " \"rolling_std_14\",\n", + " \"rolling_mean_30\",\n", + " \"rolling_std_30\",\n", + " \"rolling_mean_60\",\n", + " \"rolling_std_60\",\n", + " 
\"rolling_mean_180\",\n", + " \"rolling_std_180\",\n", + "]\n", + "\n", + "target_enc_colnm = [\n", + " \"enc_store_id_dept_id_mean\",\n", + " \"enc_store_id_dept_id_std\",\n", + " \"enc_item_id_state_id_mean\",\n", + " \"enc_item_id_state_id_std\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "23674edc-4f63-49dd-a421-a15dc4fb0a70", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def prepare_data(store, dept=None):\n", + " \"\"\"\n", + " Filter and clean data according to stores and product departments\n", + "\n", + " Parameters\n", + " ----------\n", + " store: Filter data by retaining rows whose store_id matches this parameter.\n", + " dept: Filter data by retaining rows whose dept_id matches this parameter.\n", + " This parameter can be set to None to indicate that we shouldn't filter by dept_id.\n", + " \"\"\"\n", + " if store is None:\n", + " raise ValueError(f\"store parameter must not be None\")\n", + "\n", + " if dept is None:\n", + " grid1 = grid_df[grid_df[\"store_id\"] == store]\n", + " else:\n", + " grid1 = grid_df[\n", + " (grid_df[\"store_id\"] == store) & (grid_df[\"dept_id\"] == dept)\n", + " ].drop(columns=[\"dept_id\"])\n", + " grid1 = grid1.drop(columns=[\"release_week\", \"wm_yr_wk\", \"store_id\", \"state_id\"])\n", + "\n", + " grid2 = grid_df_with_price[[\"id\", \"day_id\"] + grid2_colnm]\n", + " grid_combined = grid1.merge(grid2, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del grid1, grid2\n", + "\n", + " grid3 = grid_df_with_calendar[[\"id\", \"day_id\"] + grid3_colnm]\n", + " grid_combined = grid_combined.merge(grid3, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del grid3\n", + "\n", + " lag_df = grid_df_lags[[\"id\", \"day_id\"] + lag_colnm]\n", + " grid_combined = grid_combined.merge(lag_df, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del lag_df\n", + "\n", + " target_enc_df = grid_df_target_enc[[\"id\", \"day_id\"] + target_enc_colnm]\n", + " grid_combined = 
grid_combined.merge(target_enc_df, on=[\"id\", \"day_id\"], how=\"left\")\n", + " del target_enc_df\n", + " gc.collect()\n", + "\n", + " grid_combined = grid_combined.drop(columns=[\"id\"])\n", + " grid_combined[\"day_id\"] = (\n", + " grid_combined[\"day_id\"]\n", + " .to_pandas()\n", + " .astype(\"str\")\n", + " .apply(lambda x: x[2:])\n", + " .astype(np.int16)\n", + " )\n", + "\n", + " return grid_combined" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "9762f74a-3842-475a-9569-231df886cfdd", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1...\n", + "Processing store CA_2...\n", + "Processing store CA_3...\n", + "Processing store CA_4...\n", + "Processing store TX_1...\n", + "Processing store TX_2...\n", + "Processing store TX_3...\n", + "Processing store WI_1...\n", + "Processing store WI_2...\n", + "Processing store WI_3...\n", + "Processing (store CA_1, department HOBBIES_1)...\n", + "Processing (store CA_1, department HOBBIES_2)...\n", + "Processing (store CA_1, department HOUSEHOLD_1)...\n", + "Processing (store CA_1, department HOUSEHOLD_2)...\n", + "Processing (store CA_1, department FOODS_1)...\n", + "Processing (store CA_1, department FOODS_2)...\n", + "Processing (store CA_1, department FOODS_3)...\n", + "Processing (store CA_2, department HOBBIES_1)...\n", + "Processing (store CA_2, department HOBBIES_2)...\n", + "Processing (store CA_2, department HOUSEHOLD_1)...\n", + "Processing (store CA_2, department HOUSEHOLD_2)...\n", + "Processing (store CA_2, department FOODS_1)...\n", + "Processing (store CA_2, department FOODS_2)...\n", + "Processing (store CA_2, department FOODS_3)...\n", + "Processing (store CA_3, department HOBBIES_1)...\n", + "Processing (store CA_3, department HOBBIES_2)...\n", + "Processing (store CA_3, department HOUSEHOLD_1)...\n", + "Processing (store CA_3, department HOUSEHOLD_2)...\n", + "Processing (store CA_3, 
department FOODS_1)...\n", + "Processing (store CA_3, department FOODS_2)...\n", + "Processing (store CA_3, department FOODS_3)...\n", + "Processing (store CA_4, department HOBBIES_1)...\n", + "Processing (store CA_4, department HOBBIES_2)...\n", + "Processing (store CA_4, department HOUSEHOLD_1)...\n", + "Processing (store CA_4, department HOUSEHOLD_2)...\n", + "Processing (store CA_4, department FOODS_1)...\n", + "Processing (store CA_4, department FOODS_2)...\n", + "Processing (store CA_4, department FOODS_3)...\n", + "Processing (store TX_1, department HOBBIES_1)...\n", + "Processing (store TX_1, department HOBBIES_2)...\n", + "Processing (store TX_1, department HOUSEHOLD_1)...\n", + "Processing (store TX_1, department HOUSEHOLD_2)...\n", + "Processing (store TX_1, department FOODS_1)...\n", + "Processing (store TX_1, department FOODS_2)...\n", + "Processing (store TX_1, department FOODS_3)...\n", + "Processing (store TX_2, department HOBBIES_1)...\n", + "Processing (store TX_2, department HOBBIES_2)...\n", + "Processing (store TX_2, department HOUSEHOLD_1)...\n", + "Processing (store TX_2, department HOUSEHOLD_2)...\n", + "Processing (store TX_2, department FOODS_1)...\n", + "Processing (store TX_2, department FOODS_2)...\n", + "Processing (store TX_2, department FOODS_3)...\n", + "Processing (store TX_3, department HOBBIES_1)...\n", + "Processing (store TX_3, department HOBBIES_2)...\n", + "Processing (store TX_3, department HOUSEHOLD_1)...\n", + "Processing (store TX_3, department HOUSEHOLD_2)...\n", + "Processing (store TX_3, department FOODS_1)...\n", + "Processing (store TX_3, department FOODS_2)...\n", + "Processing (store TX_3, department FOODS_3)...\n", + "Processing (store WI_1, department HOBBIES_1)...\n", + "Processing (store WI_1, department HOBBIES_2)...\n", + "Processing (store WI_1, department HOUSEHOLD_1)...\n", + "Processing (store WI_1, department HOUSEHOLD_2)...\n", + "Processing (store WI_1, department FOODS_1)...\n", + "Processing (store 
WI_1, department FOODS_2)...\n", + "Processing (store WI_1, department FOODS_3)...\n", + "Processing (store WI_2, department HOBBIES_1)...\n", + "Processing (store WI_2, department HOBBIES_2)...\n", + "Processing (store WI_2, department HOUSEHOLD_1)...\n", + "Processing (store WI_2, department HOUSEHOLD_2)...\n", + "Processing (store WI_2, department FOODS_1)...\n", + "Processing (store WI_2, department FOODS_2)...\n", + "Processing (store WI_2, department FOODS_3)...\n", + "Processing (store WI_3, department HOBBIES_1)...\n", + "Processing (store WI_3, department HOBBIES_2)...\n", + "Processing (store WI_3, department HOUSEHOLD_1)...\n", + "Processing (store WI_3, department HOUSEHOLD_2)...\n", + "Processing (store WI_3, department FOODS_1)...\n", + "Processing (store WI_3, department FOODS_2)...\n", + "Processing (store WI_3, department FOODS_3)...\n" + ] + } + ], + "source": [ + "# First save the segment to the disk\n", + "for store in STORES:\n", + " print(f\"Processing store {store}...\")\n", + " segment_df = prepare_data(store=store)\n", + " segment_df.to_pandas().to_pickle(\n", + " segmented_data_dir / f\"combined_df_store_{store}.pkl\"\n", + " )\n", + " del segment_df\n", + " gc.collect()\n", + "\n", + "for store in STORES:\n", + " for dept in DEPTS:\n", + " print(f\"Processing (store {store}, department {dept})...\")\n", + " segment_df = prepare_data(store=store, dept=dept)\n", + " segment_df.to_pandas().to_pickle(\n", + " segmented_data_dir / f\"combined_df_store_{store}_dept_{dept}.pkl\"\n", + " )\n", + " del segment_df\n", + " gc.collect()" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "17c6859b-6159-424b-a041-1bc6b8201716", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Uploading segmented_data/combined_df_store_CA_3_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_FOODS_3.pkl...\n", + "Uploading 
segmented_data/combined_df_store_CA_1_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_HOUSEHOLD_1.pkl...\n", + "Uploading 
segmented_data/combined_df_store_CA_3_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_FOODS_3.pkl...\n", + "Uploading 
segmented_data/combined_df_store_CA_1_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_3_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_1_dept_FOODS_3.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_WI_2_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_4_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_HOUSEHOLD_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_HOBBIES_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_2_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_FOODS_2.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_3_dept_HOUSEHOLD_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_1_dept_HOBBIES_2.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_1.pkl...\n", + "Uploading segmented_data/combined_df_store_TX_3_dept_FOODS_1.pkl...\n", + "Uploading segmented_data/combined_df_store_CA_2.pkl...\n" + ] + } + ], + "source": [ + "# Then copy the segment to Cloud Storage\n", + "fs = gcsfs.GCSFileSystem()\n", + "\n", + 
"for e in segmented_data_dir.glob(\"*.pkl\"):\n", + " print(f\"Uploading {e}...\")\n", + " basename = e.name\n", + " fs.put_file(e, f\"{bucket_name}/{basename}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "8357de8c-0b19-4914-9d41-67cf1a9359c4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Also upload the product weights\n", + "fs = gcsfs.GCSFileSystem()\n", + "\n", + "weights.to_pandas().to_pickle(\"product_weights.pkl\")\n", + "fs.put_file(\"product_weights.pkl\", f\"{bucket_name}/product_weights.pkl\")" + ] + }, + { + "cell_type": "markdown", + "id": "8777971f-a46d-4954-b9c8-45a7c182ac96", + "metadata": {}, + "source": [ + "## Training and Evaluation with Hyperparameter Optimization (HPO)\n", + "\n", + "Now that we finished processing the data, we are now ready to train a model to forecast future sales. We will leverage the worker pods to run multiple training jobs in parallel, speeding up the hyperparameter search." + ] + }, + { + "cell_type": "markdown", + "id": "d14cff4c-5761-4c08-bd36-92710b2bfc32", + "metadata": {}, + "source": [ + "### Import modules and define constants" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "b2d02ade-247c-4fff-bcd0-afd9ff30d026", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import cudf\n", + "import gcsfs\n", + "import xgboost as xgb\n", + "import pandas as pd\n", + "import numpy as np\n", + "import optuna\n", + "import gc\n", + "import time\n", + "import pickle\n", + "import copy\n", + "import json\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.patches import Patch\n", + "import matplotlib\n", + "\n", + "from dask.distributed import wait\n", + "from dask_kubernetes.operator import KubeCluster\n", + "from dask.distributed import Client" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "5343eb9a-2eb5-4e88-b18c-3cc5a1c84cec", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + 
"# Choose the same RAPIDS image you used for launching the notebook session\n", + "rapids_image = \"rapidsai/notebooks:23.10a-cuda12.0-py3.10\"\n", + "# Use the number of worker nodes in your Kubernetes cluster.\n", + "n_workers = 2\n", + "# Bucket that contains the processed data pickles\n", + "bucket_name = \"\"\n", + "bucket_name = \"phcho-m5-competition-hpo-example\"\n", + "\n", + "# List of stores and product departments\n", + "STORES = [\n", + " \"CA_1\",\n", + " \"CA_2\",\n", + " \"CA_3\",\n", + " \"CA_4\",\n", + " \"TX_1\",\n", + " \"TX_2\",\n", + " \"TX_3\",\n", + " \"WI_1\",\n", + " \"WI_2\",\n", + " \"WI_3\",\n", + "]\n", + "DEPTS = [\n", + " \"HOBBIES_1\",\n", + " \"HOBBIES_2\",\n", + " \"HOUSEHOLD_1\",\n", + " \"HOUSEHOLD_2\",\n", + " \"FOODS_1\",\n", + " \"FOODS_2\",\n", + " \"FOODS_3\",\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "9b4d2eea-7b13-4249-9b94-2573be07ad84", + "metadata": {}, + "source": [ + "### Define cross-validation folds\n", + "**[Cross-validation](https://en.wikipedia.org/wiki/Cross-validation_(statistics))** is a statistical method for estimating how well a machine learning model generalizes to an independent data set. The method is also useful for evaluating the choice of a given combination of model hyperparameters.\n", + "\n", + "To estimate the capacity to generalize, we define multiple cross-validation **folds** consisting of mulitple pairs of `(training set, validation set)`. For each fold, we fit a model using the training set and evaluate its accuracy on the validation set. The \"goodness\" score for a given hyperparameter combination is the accuracy of the model on each validation set, averaged over all cross-validation folds.\n", + "\n", + "Great care must be taken when defining cross-validation folds for time-series data. We are not allowed to use the future to predict the past, so the training set must precede (in time) the validation set. 
Consequently, we partition the data set in the time dimension and assign the training and validation sets using time ranges:" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "525c26ce-df6d-42a1-90c4-2b3369105728", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Cross-validation folds and held-out test set (in time dimension)\n", + "# The held-out test set is used for final evaluation\n", + "cv_folds = [ # (train_set, validation_set)\n", + " ([0, 1114], [1114, 1314]),\n", + " ([0, 1314], [1314, 1514]),\n", + " ([0, 1514], [1514, 1714]),\n", + " ([0, 1714], [1714, 1914]),\n", + "]\n", + "n_folds = len(cv_folds)\n", + "holdout = [1914, 1942]\n", + "time_horizon = 1942" + ] + }, + { + "cell_type": "markdown", + "id": "50cd8aea-18bd-4005-be43-6906e8b31de2", + "metadata": {}, + "source": [ + "It is helpful to visualize the cross-validation folds using Matplotlib." + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "c0dddf7f-bc70-4d12-a108-9b1ecbb784ba", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAxQAAAEiCAYAAABgP5QIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA2+UlEQVR4nO3deXxNd+L/8fdNIhHZLZEgEXtC7VvplFBLtIyOztDWIEU6Ru10MLV201W1X0Nnphp0fMuYajs/2k6lCJrY0kaVSDGxR9MaEoRE5Pz+8M0dV7ab4yY35PV8PO7j4Z7zOefzOZ987nXf95xzPxbDMAwBAAAAgAkuzm4AAAAAgLsXgQIAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJjm5uwG3A3y8/N19uxZ+fj4yGKxOLs5AAAAgMMYhqFLly6pXr16cnEp+/kGAoUdzp49q5CQEGc3AwAAACg3p06dUoMGDcq8HYHCDj4+PpJudrKvr6+TWwMAAAA4TlZWlkJCQqyfecuKQGGHgsucfH19CRQAAAC4J5m9tJ+bsgEAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKa5ObsBKF1K4jTl5V5wdjMAAHZYH1dXksXZzXA6X19fzXh2trObAaACcIbiLkCYAIC7CWFCkrKyspzdBAAVhEABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMq9SBIjIyUlOmTCmxTFhYmJYsWVIh7QEAAABgq1wDRXR0tCwWS6HH0aNHy7PaQj766CO1bNlSHh4eatmypT7++OMKrR8AAAC4V5X7GYqoqCilp6fbPBo1alTe1VolJiZq2LBhGjFihPbv368RI0Zo6NCh2r17d4W1AQAAALhXlXug8PDwUFBQkM3D1dVVkhQfH68uXbrIw8NDwcHBmjVrlvLy8ordV0ZGhgYNGiRPT081atRIa9asKbX+JUuWqG/fvpo9e7bCw8M1e/ZsPfTQQ1wmBQAAADiA0+6hOHPmjB5++GF17txZ+/fv1/Lly7VixQq9+OKLxW4THR2t48ePa8uWLfrHP/6hZcuWKSMjo8R6EhMT1a9fP5tl/fv3V0JCQrHb5OTkKCsry+YBAAAAoDC38q5g48aN8vb2tj4fMGCA1q9fr2XLlikkJERLly6VxWJReHi4zp49q5kzZ2revHlycbHNOj/88IM+//xz7dq1S127dpUkrVixQhERESXWf+7cOdWtW9dmWd26dXXu3Llit1m0aJEWLlxY1kMFAAAAqpxyDxS9evXS8uXLrc+9vLwkSSkpKerWrZssFot13QMPPKDLly/r9OnTCg0NtdlPSkqK3Nzc1KlTJ+uy8PBw+fv7l9qGW+uQJMMwCi271ezZszVt2jTr86ysLIWEhJRaDwAAAFDVlHug8PLyUtOmTQstL+pDvWEYkgoHgNLWlSQoKKjQ2YiMjIxCZy1u5eHhIQ8PjzLVAwAAAFRFTruHomXLlkpISLAGBUlKSEiQj4+P6tevX6h8RESE8vLytG/fPuuy1NRUXbx4scR6unXrps2bN9ss+/LLL9W9e/c7OwAAAAAAzgsU48eP16lTpzRx4kQdPnxYn376qebPn69p06YVun9Cklq0aKGoqCjFxMRo9+7dSkpK0tixY+Xp6VliPZMnT9aXX36pV199VYcPH9arr76quLi4UifMAwAAAFA6pwWK+vXr67PPPtOePXvUtm1bjRs3TmPGjNGcOXOK3SY2NlYhISHq2bOnhgwZoqefflqBgYEl1tO9e3etXbtWsbGxatOmjVauXKl169ZZb+wGAAAAYJ7FuPWaIxQpKytLfn5+yszMlK+vb4XXfyD+qQqvEwBgzvq4IGc3odJ4/oV
Fzm4CADvc6Wddp52hAAAAAHD3I1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUBxF3BzD3B2EwAAduPX2CU55WfWATiHm7MbgNJFdFvs7CYAAOzUuqezWwAAFYszFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMc3N2A1C6lMRpysu94OxmAABgl4076uhqjoski7Ob4lQWi0ULn3/Z2c0Ayh2B4i5AmAAA3E2u5rg6uwmVgmEYzm4CUCG45AkAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhWqQNFZGSkpkyZUmKZsLAwLVmypELaAwAAAMBWuQaK6OhoWSyWQo+jR4+WZ7U2Dh48qMcee0xhYWGyWCyEDwAAAMCByv0MRVRUlNLT020ejRo1Ku9qrbKzs9W4cWO98sorCgoKqrB6AQAAgKqg3AOFh4eHgoKCbB6urq6SpPj4eHXp0kUeHh4KDg7WrFmzlJeXV+y+MjIyNGjQIHl6eqpRo0Zas2ZNqfV37txZr7/+uh5//HF5eHg47LgAAAAASG7OqvjMmTN6+OGHFR0drdWrV+vw4cOKiYlR9erVtWDBgiK3iY6O1qlTp7Rlyxa5u7tr0qRJysjIcHjbcnJylJOTY32elZXl8DoAAACAe0G5B4qNGzfK29vb+nzAgAFav369li1bppCQEC1dulQWi0Xh4eE6e/asZs6cqXnz5snFxfbkyQ8//KDPP/9cu3btUteuXSVJK1asUEREhMPbvGjRIi1cuNDh+wUAAADuNeUeKHr16qXly5dbn3t5eUmSUlJS1K1bN1ksFuu6Bx54QJcvX9bp06cVGhpqs5+UlBS5ubmpU6dO1mXh4eHy9/d3eJtnz56tadOmWZ9nZWUpJCTE4fUAAAAAd7tyDxReXl5q2rRpoeWGYdiEiYJlkgotL22do3l4eHC/BQAAAGAHp81D0bJlSyUkJFiDgiQlJCTIx8dH9evXL1Q+IiJCeXl52rdvn3VZamqqLl68WBHNBQAAAFAEpwWK8ePH69SpU5o4caIOHz6sTz/9VPPnz9e0adMK3T8hSS1atFBUVJRiYmK0e/duJSUlaezYsfL09CyxntzcXCUnJys5OVm5ubk6c+aMkpOTK3QuDAAAAOBe5bRAUb9+fX322Wfas2eP2rZtq3HjxmnMmDGaM2dOsdvExsYqJCREPXv21JAhQ/T0008rMDCwxHrOnj2r9u3bq3379kpPT9cbb7yh9u3ba+zYsY4+JAAAAKDKsRi3XnOEImVlZcnPz0+ZmZny9fWt8PoPxD9V4XUCAGDW+jgmki3w/AuLnN0EoFR3+lnXaWcoAAAAANz9CBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1DcBdzcA5zdBAAA7ObpcUMSv0pvsVic3QSgQrg5uwEoXUS3xc5uAgAAdmvd09ktAFCROEMBAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQ3ZzcApUtJnKa83AvObgYAACiD9XF1JVmc3Qyn8/X11YxnZzu7GShHnKG4CxAmAAC4GxEmJCkrK8vZTUA5I1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTKnWgiIyM1JQpU0osExYWpiVLllRIewAAAADYKtdAER0dLYvFUuhx9OjR8qzWxl//+lc9+OCDCggIUEBAgPr06aM9e/ZUWP0AAADAvazcz1BERUUpPT3d5tGoUaPyrtZq27ZteuKJJ7R
161YlJiYqNDRU/fr105kzZyqsDQAAAMC9qtwDhYeHh4KCgmwerq6ukqT4+Hh16dJFHh4eCg4O1qxZs5SXl1fsvjIyMjRo0CB5enqqUaNGWrNmTan1r1mzRuPHj1e7du0UHh6uv/71r8rPz9dXX33lsGMEAAAAqio3Z1V85swZPfzww4qOjtbq1at1+PBhxcTEqHr16lqwYEGR20RHR+vUqVPasmWL3N3dNWnSJGVkZJSp3uzsbF2/fl01a9YstkxOTo5ycnKsz7OysspUBwAAAFBVlHug2Lhxo7y9va3PBwwYoPXr12vZsmUKCQnR0qVLZbFYFB4errNnz2rmzJmaN2+eXFxsT5788MMP+vzzz7Vr1y517dpVkrRixQpFRESUqT2zZs1S/fr11adPn2LLLFq0SAsXLizTfgEAAICqqNwDRa9evbR8+XLrcy8vL0lSSkqKunXrJovFYl33wAMP6PLlyzp9+rRCQ0Nt9pOSkiI3Nzd16tTJuiw8PFz+/v52t+W1117Thx9+qG3btql69erFlps9e7amTZtmfZ6VlaWQkBC76wEAAACqinIPFF5eXmratGmh5YZh2ISJgmWSCi0vbZ093njjDb388suKi4tTmzZtSizr4eEhDw8PU/UAAAAAVYnT5qFo2bKlEhISrEFBkhISEuTj46P69esXKh8REaG8vDzt27fPuiw1NVUXL14sta7XX39dL7zwgr744gubMxwAAAAA7ozTAsX48eN16tQpTZw4UYcPH9ann36q+fPna9q0aYXun5CkFi1aKCoqSjExMdq9e7eSkpI0duxYeXp6lljPa6+9pjlz5uj9999XWFiYzp07p3Pnzuny5cvldWgAAABAleG0QFG/fn199tln2rNnj9q2batx48ZpzJgxmjNnTrHbxMbGKiQkRD179tSQIUP09NNPKzAwsMR6li1bptzcXP36179WcHCw9fHGG284+pAAAACAKsdi3HrNEYqUlZUlPz8/ZWZmytfXt8LrPxD/VIXXCQAA7sz6uCBnN6HSeP6FRc5uAkpwp591nXaGAgAAAMDdj0ABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANALFXcDNPcDZTQAAAGXGL/NLcspP7qNiuTm7AShdRLfFzm4CAAAoo9Y9nd0CoGJwhgIAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaW7ObgBKl5I4TXm5F5zdDAAAgDLZuKOOrua4SLI4uylOZbFYtPD5l53djHJDoLgLECYAAMDd6GqOq7ObUCkYhuHsJpQrLnkCAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmVepAERkZqSlTppRYJiwsTEuWLKmQ9gAAAACwVa6BIjo6WhaLpdDj6NGj5VmtjQ0bNqhTp07y9/eXl5eX2rVrpw8++KDC6gcAAADuZeU+U3ZUVJRiY2NtltWpU6e8q7WqWbOmnnvuOYWHh8vd3V0bN27UU089pcDAQPXv37/C2gEAAADci8r9kicPDw8FBQXZPFxdb07DHh8fry5dusjDw0PBwcGaNWuW8vLyit1XRkaGBg0aJE9PTzVq1Ehr1qwptf7IyEj96le/UkREhJo0aaLJkyerTZs22rlzp8OOEQAAAKiqyv0MRXHOnDmjhx9+WNHR0Vq9erUOHz6smJgYVa9eXQsWLChym+joaJ06dUpbtmyRu7u7Jk2apIyMDLvrNAxDW7ZsUWpqql599dViy+Xk5CgnJ8f6PCsry+46AAAAgKqk3APFxo0b5e3tbX0+YMAArV+/XsuWLVNISIiWLl0qi8Wi8PBwnT17VjNnztS8efPk4mJ78uSHH37Q559/rl27dqlr166SpBUrVigiIqLUNmRmZqp
+/frKycmRq6urli1bpr59+xZbftGiRVq4cKHJIwYAAACqjnIPFL169dLy5cutz728vCRJKSkp6tatmywWi3XdAw88oMuXL+v06dMKDQ212U9KSorc3NzUqVMn67Lw8HD5+/uX2gYfHx8lJyfr8uXL+uqrrzRt2jQ1btxYkZGRRZafPXu2pk2bZn2elZWlkJAQew4XAAAAqFLKPVB4eXmpadOmhZYbhmETJgqWSSq0vLR1pXFxcbG2oV27dkpJSdGiRYuKDRQeHh7y8PAocz0AAABAVeO0eShatmyphIQEa1CQpISEBPn4+Kh+/fqFykdERCgvL0/79u2zLktNTdXFixfLXLdhGDb3SAAAAAAwx2mBYvz48Tp16pQmTpyow4cP69NPP9X8+fM1bdq0QvdPSFKLFi0UFRWlmJgY7d69W0lJSRo7dqw8PT1LrGfRokXavHmz/v3vf+vw4cNavHixVq9erd/+9rfldWgAAABAleG0X3mqX7++PvvsMz377LNq27atatasqTFjxmjOnDnFbhMbG6uxY8eqZ8+eqlu3rl588UXNnTu3xHquXLmi8ePH6/Tp0/L09FR4eLj+9re/adiwYY4+JAAAAKDKsRi3XnOEImVlZcnPz0+ZmZny9fWt8PoPxD9V4XUCAADcqfVxQc5uQqXx/AuLnN2EYt3pZ12nXfIEAAAA4O5HoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAageIu4OYe4OwmAAAAlJmnxw1JzFBgsVic3YRy5bSJ7WC/iG6Lnd0EAACAMmvd09ktQEXgDAUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA09yc3QAAAADgXta692Tl5xullju47Z0KaI3jcYYCAAAAKEf2hIm7GYECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKYRKAAAAACY5pRAERYWpiVLlpRYxmKx6JNPPqmQ9gAAAAAwp0yBIjo6Wo8++mih5du2bZPFYtHFixcd1KzyYU+QAQAAAGA/N2c34F5hGIby8vJ048YNZzcFAIB7RrVq1eTq6ursZgAoQbkEio8++kjz5s3T0aNHFRwcrIkTJ2r69OnFlj9y5IjGjBmjPXv2qHHjxnr77bcLlTlw4IAmT56sxMRE1ahRQ4899pgWL14sb29vSVJkZKTatWtncwbi0Ucflb+/v1auXKnIyEidOHFCU6dO1dSpUyXdDAGOkJubq/T0dGVnZztkfwAA4CaLxaIGDRpY/78HUPk4PFAkJSVp6NChWrBggYYNG6aEhASNHz9etWrVUnR0dKHy+fn5GjJkiGrXrq1du3YpKytLU6ZMsSmTnZ2tqKgo3X///dq7d68yMjI0duxYTZgwQStXrrSrXRs2bFDbtm319NNPKyYmpsSyOTk5ysnJsT7Pysoqtmx+fr7S0tLk6uqqevXqyd3dXRaLxa42AQCA4hmGoZ9++kmnT59Ws2bNOFMBVFJlDhQbN24s9C3BrZf5LF68WA899JDmzp0rSWrevLkOHTqk119/vchAERcXp5SUFB0/flwNGjSQJL388ssaMGCAtcyaNWt09epVrV69Wl5eXpKkpUuXatCgQXr11VdVt27dUttds2ZNubq6ysfHR0FBQSWWXbRokRYuXFjqPqWbZyfy8/MVEhKiGjVq2LUNAACwT506dXT8+HFdv36dQAFUUmX+ladevXopOTnZ5vHee+9Z16ekpOiBBx6w2eaBBx7QkSNHiry/ICUlRaGhodYwIUndunUrVKZt27bWMFGwz/z8fKWmppb1EEo1e/ZsZWZmWh+nTp0qdRsXF36BFwAAR+OsP1D5lfkMhZeXl5o2bWqz7PTp09Z/G4ZR6MVf0r0KRa0ravvi3lAKlru
4uBTa1/Xr14uttyQeHh7y8PAwtS0AAABQlTj8a/WWLVtq586dNssSEhLUvHnzIk9VtmzZUidPntTZs2etyxITEwuVSU5O1pUrV6zLvv76a7m4uKh58+aSbp4STU9Pt66/ceOGvv/+e5v9uLu78ytM5SgyMrLQ/S8lOX78uCwWi5KTk8utTbh73D5+KnK+Gua9wa2YKwkAysbhN2VPnz5dnTt31gsvvKBhw4YpMTFRS5cu1bJly4os36dPH7Vo0UIjR47Um2++qaysLD333HM2ZYYPH6758+dr1KhRWrBggX766SdNnDhRI0aMsN4/0bt3b02bNk2bNm1SkyZN9NZbbxWaFyMsLEzbt2/X448/Lg8PD9WuXdvRh28j99p53bh+qVzrKOBazUfu1WvZVba008ejRo2y+2b3W23YsEHVqlWzu3xISIjS09PL/e9wp6Kjo3Xx4kWnfni4ePGisrOvlF7QQWrU8JK/v79dZQcNGqSrV68qLi6u0LrExER1795dSUlJ6tChQ5nasHfvXpvLHB1hwYIF+uSTTwqF2PT0dAUEBDi0LkcLCwvTlClTyhTaHensj//RxcyKG4P+fl6qV7em3eWLe51u27ZNvXr10oULF+we05VJUb9gWJTyGB/21n0n7va/D4CbHB4oOnTooL///e+aN2+eXnjhBQUHB+v5558v8oZs6ealSh9//LHGjBmjLl26KCwsTO+8846ioqKsZWrUqKF//etfmjx5sjp37mzzs7EFRo8erf3792vkyJFyc3PT1KlT1atXL5u6nn/+ef3ud79TkyZNlJOT47CfjS1K7rXz+mHvbBn55i67KiuLSzU177zIrlBx65mcdevWad68eTb3onh6etqUv379ul1BoWZN+//zlyRXV9dSb5DHzTDxzttvKi8vr8LqdHNz06TJ0+36D37MmDEaMmSITpw4oYYNG9qse//999WuXbsyhwnp5lnHisI4LNnZH/+jR0a8qNzcihuD7u5u2vTBnDKFCgCAc5TpkqeVK1cW+S1tZGSkDMOwfvh47LHHdPDgQeXm5urEiROaMWOGTfnjx4/bfIvSvHlz7dixQzk5OUpNTVX//v1lGIbNrNytW7fWli1bdPXqVZ0/f15/+ctfbH5tqlq1alq2bJnOnz+vH3/8UbNmzdInn3xi8037/fffr/379+vatWvlGiYk6cb1SxUWJiTJyL9u99mQoKAg68PPz08Wi8X6/Nq1a/L399ff//53RUZGqnr16vrb3/6m8+fP64knnlCDBg1Uo0YNtW7dWh9++KHNfou6ZOXll1/W6NGj5ePjo9DQUP3lL3+xrr/9kqeCGde/+uorderUSTVq1FD37t0L3Xj/4osvKjAwUD4+Pho7dqxmzZqldu3aFXu8Fy5c0PDhw1WnTh15enqqWbNmio2Nta4/c+aMhg0bpoCAANWqVUuDBw/W8ePHJd38RnvVqlX69NNPZbFYZLFYtG3bNrv62VGys69UaJiQpLy8PLvPiAwcOFCBgYGFzmplZ2dr3bp1GjNmjF3j53a3X3Zy5MgR9ejRQ9WrV1fLli21efPmQtvMnDlTzZs3V40aNdS4cWPNnTvXei/VypUrtXDhQu3fv9/6tyxo8+2Xrxw4cEC9e/eWp6enatWqpaefflqXL1+2ro+Ojtajjz6qN954Q8HBwapVq5aeeeaZEu/b2r9/v3r16iUfHx/5+vqqY8eO2rdvn3V9QkKCevToIU9PT4WEhGjSpEnWyzxvnUenoO0V6WLmlQoNE5KUm5tXbmdESurrotgz9opy8uRJDR48WN7e3vL19dXQoUP1448/WtcXjKNbTZkyRZGRkdb18fHxevvtt61/94L3pluVND5KO9Zly5apWbNmql69uurWratf//rXZaq7pH1IN++BfO2119S4cWN5enqqbdu2+sc//iHp5v8BBV/8BQQEyGKxFPvlI4DKjZ8mQpFmzpypSZMmKSUlRf3799e1a9fUsWNHbdy4Ud9
//72efvppjRgxQrt37y5xP2+++aY6deqkb7/9VuPHj9fvf/97HT58uMRtnnvuOb355pvat2+f3NzcNHr0aOu6NWvW6KWXXtKrr76qpKQkhYaGavny5SXub+7cuTp06JA+//xzpaSkaPny5dbLrLKzs9WrVy95e3tr+/bt2rlzp7y9vRUVFaXc3FzNmDFDQ4cOVVRUlNLT05Wenq7u3bvb2YtVg5ubm0aOHKmVK1faBPX169crNzdXw4cPNz1+ChTMV+Pq6qpdu3bp3Xff1cyZMwuV8/Hx0cqVK3Xo0CG9/fbb+utf/6q33npLkjRs2DBNnz5drVq1sv4thw0bVmgfBfPeBAQEaO/evVq/fr3i4uI0YcIEm3Jbt27VsWPHtHXrVq1atUorV64s8VLB4cOHq0GDBtq7d6+SkpI0a9Ys65m/AwcOqH///hoyZIi+++47rVu3Tjt37rTWuWHDBjVo0EDPP/+8te0wp7S+vp29Y+92BV+K/ec//1F8fLw2b96sY8eOFTnmivP222+rW7duiomJsf7dQ0JCCpUrbnyUdqz79u3TpEmT9Pzzzys1NVVffPGFevToUaa6S9qHJM2ZM0exsbFavny5Dh48qKlTp+q3v/2t4uPjFRISoo8++kiSlJqaqvT09CIntgVQ+ZXLTNm4+02ZMkVDhgyxWXbrmaaJEyfqiy++0Pr169W1a9di9/Pwww9r/Pjxkm6GlLfeekvbtm1TeHh4sdu89NJL6tmzpyRp1qxZeuSRR3Tt2jVVr15d//M//6MxY8boqaeekiTNmzdPX375pc23x7c7efKk2rdvr06dOkm6+c13gbVr18rFxUXvvfee9Vu92NhY+fv7a9u2berXr588PT2Vk5PDZTElGD16tF5//XXr9dDSzcudhgwZooCAAAUEBJgaPwXsma9GuvnhpUBYWJimT5+udevW6Q9/+IM8PT3l7e0tNze3Ev+W9s57ExAQoKVLl8rV1VXh4eF65JFH9NVXXxU7cebJkyf17LPPWsd+s2bNrOtef/11Pfnkk9YzfM2aNdM777yjnj17avny5WWaR6cqK22eJKn0vq5evbpNeXvH3u3i4uL03XffKS0tzfpB/IMPPlCrVq20d+9ede7cudTj8fPzk7u7u2rUqFHi37248VHasZ48eVJeXl4aOHCgfHx81LBhQ7Vv375MdZe0jytXrmjx4sXasmWL9efgGzdurJ07d+rPf/6zevbsab1UNjAwkHsogLsYZyhQpIIP3wVu3Lihl156SW3atFGtWrXk7e2tL7/8UidPnixxP23atLH+u+DSqoyMDLu3CQ4OliTrNqmpqerSpYtN+duf3+73v/+91q5dq3bt2ukPf/iDEhISrOuSkpJ09OhR+fj4yNvbW97e3qpZs6auXbumY8eOlbhf/Fd4eLi6d++u999/X5J07Ngx7dixw3p2yez4KWDPfDWS9I9//EO/+MUvFBQUJG9vb82dO9fuOm6ty555b1q1amXzy3XBwcElju1p06Zp7Nix6tOnj1555RWb8ZWUlKSVK1dax6C3t7f69++v/Px8paWllan9VVlp8yRJZe9re8begAEDrPtq1aqVdbuQkBCbb/Vbtmwpf39/paSkOPKwi1Xasfbt21cNGzZU48aNNWLECK1Zs0bZ2dllqqOkfRw6dEjXrl1T3759bdqwevVq3l+BewxnKFCk239d580339Rbb72lJUuWqHXr1vLy8tKUKVOUm5tb4n5uv5nbYrEoPz/f7m0Kzhrcuk1Z5jmRbv5nf+LECW3atElxcXF66KGH9Mwzz+iNN95Qfn6+OnbsqDVr1hTariJvCr4XjBkzRhMmTNCf/vQnxcbGqmHDhnrooYckmR8/BeyZr2bXrl16/PHHtXDhQvXv319+fn5au3at3nzzzTIdhz3z3khlH9sLFizQk08+qU2bNunzzz/X/PnztXbtWv3qV79Sfn6+fve732nSpEmFtgsNDS1T+6uy0uZJklTmvrZ
n7L333nu6evWqpP+Oi+LG0a3LHTl/UlFKO1Z3d3d988032rZtm7788kvNmzdPCxYs0N69e+0+W+Dj41PsPgpeD5s2bVL9+vVttmOuJ+DeQqCAXXbs2KHBgwfrt7/9raSb/1EdOXJEERERFdqOFi1aaM+ePRoxYoR12a03thanTp06io6OVnR0tB588EE9++yzeuONN9ShQwetW7dOgYGB8vX1LXJb5i+xz9ChQzV58mT97//+r1atWqWYmBjrB6c7HT+3zldTr149SYXnq/n666/VsGFDm5+dPnHihE0Ze/6WLVu21KpVq3TlyhVrsL593huzmjdvrubNm2vq1Kl64oknFBsbq1/96lfq0KGDDh48WOjDcFnbjtLZ09e3smfs3f5h+dbtTp06ZT1LcejQIWVmZlrHfZ06dQrNl5ScnGwTVu39uxdVzp5jdXNzU58+fdSnTx/Nnz9f/v7+2rJli4YMGWJ33cXto2/fvvLw8NDJkyetl7EW1W6p8KVpAO4uXPIEuzRt2lSbN29WQkKCUlJS9Lvf/U7nzp2r8HZMnDhRK1as0KpVq3TkyBG9+OKL+u6770r81Zt58+bp008/1dGjR3Xw4EFt3LjR+h/68OHDVbt2bQ0ePFg7duxQWlqa4uPjNXnyZOs3m2FhYfruu++Umpqqn3/+2aHfIN5LvL29NWzYMP3xj3/U2bNnbX6t5U7Hz63z1ezfv187duwoNF9N06ZNdfLkSa1du1bHjh3TO++8o48//timTFhYmNLS0pScnKyff/5ZOTk5heoaPny4qlevrlGjRun777/X1q1bC817U1ZXr17VhAkTtG3bNp04cUJff/219u7dax2HM2fOVGJiop555hklJyfryJEj+uc//6mJEyfatH379u06c+aMfv75Z1PtgH19fSt7xl5x27Vp00bDhw/XN998oz179mjkyJHq2bOn9ZLS3r17a9++fVq9erWOHDmi+fPnFwoYYWFh2r17t44fP66ff/652LNgRY2P0o5148aNeuedd5ScnKwTJ05o9erVys/PV4sWLeyuu6R9+Pj4aMaMGZo6dapWrVqlY8eO6dtvv9Wf/vQnrVq1SpLUsGFDWSwWbdy4UT/99FOJ98MBqLwIFLDL3Llz1aFDB/Xv31+RkZEKCgoq9HOHFWH48OGaPXu2ZsyYoQ4dOigtLU3R0dGFbqS8lbu7u2bPnq02bdqoR48ecnV11dq1ayXdnONk+/btCg0N1ZAhQxQREaHRo0fr6tWr1jMWMTExatGihTp16qQ6dero66+/rpBjvRuNGTNGFy5cUJ8+fWwuH7nT8VMwX01OTo66dOmisWPH6qWXXrIpM3jwYE2dOlUTJkxQu3btlJCQoLlz59qUeeyxxxQVFaVevXqpTp06Rf50bcG8N//5z3/UuXNn/frXv9ZDDz2kpUuXlq0zbuHq6qrz589r5MiRat68uYYOHaoBAwZo4cKFkm7eNxQfH68jR47owQcfVPv27TV37lzrPUTSzXl0jh8/riZNmnA53h2wp69vZc/YK0rBTxEHBASoR48e6tOnjxo3bqx169ZZy/Tv319z587VH/7wB3Xu3FmXLl3SyJEjbfYzY8YMubq6qmXLlqpTp06x9wQVNT5KO1Z/f39t2LBBvXv3VkREhN599119+OGH1vtA7Km7tH288MILmjdvnhYtWqSIiAj1799f/+///T81atRI0s2zOwsXLtSsWbNUt27dYn9tC7jbubhU7M99VzSLUd4TMtwDsrKy5Ofnp8zMzEKXxVy7dk1paWlq1KiRzYfayjyx3b2mb9++CgoK0gcffODsppSLyj6xHe59TGwHZyru/1kAjlPSZ117cA9FOXGvXkvNOy+ye7K5O+VazadKhIns7Gy9++676t+/v1xdXfXhhx8qLi7O7omm7kb+/v6aNHm63RPNOUKNGl6ECVjVq1tTmz6YU24TzRXF38+LMAEAdwkCRTlyr15LqgIf8iuSxWLRZ599phd
ffFE5OTlq0aKFPvroI/Xp08fZTStX/v7+fMCHU9WrW5MP+ACAIhEocFfx9PRUXFycs5sBAACA/8NN2QAAAABMI1AAAAAAMI1A4SD8WBYAAI7H/69A5UeguEMFM5pmZ2c7uSUAANx7cnNzJd2czwVA5cRN2XfI1dVV/v7+ysjIkHRzUqySZm0GAAD2yc/P108//aQaNWrIzY2PLEBlxavTAYKCgiTJGioAAIBjuLi4KDQ0lC/rgEqMQOEAFotFwcHBCgwM1PXrFTMzNgAAVYG7u7tcXLhCG6jMCBQO5OrqyjWeAAAAqFKI/AAAAABMI1AAAAAAMI1AAQAAAMA07qGwQ8GkOllZWU5uCQAAAOBYBZ9xzU4kSaCww6VLlyRJISEhTm4JAAAAUD4uXbokPz+/Mm9nMZjTvlT5+fk6e/asfHx8nPI72FlZWQoJCdGpU6fk6+tb4fXfK+hHx6Af7xx96Bj0o2PQj45BPzoG/egYZe1HwzB06dIl1atXz9TPNHOGwg4uLi5q0KCBs5shX19fXlwOQD86Bv145+hDx6AfHYN+dAz60THoR8coSz+aOTNRgJuyAQAAAJhGoAAAAABgGoHiLuDh4aH58+fLw8PD2U25q9GPjkE/3jn60DHoR8egHx2DfnQM+tExKrofuSkbAAAAgGmcoQAAAABgGoECAAAAgGkECgAAAACmESgquWXLlqlRo0aqXr26OnbsqB07dji7SZXGokWL1LlzZ/n4+CgwMFCPPvqoUlNTbcpER0fLYrHYPO6//36bMjk5OZo4caJq164tLy8v/fKXv9Tp06cr8lCcasGCBYX6KCgoyLreMAwtWLBA9erVk6enpyIjI3Xw4EGbfVT1PpSksLCwQv1osVj0zDPPSGIsFmf79u0aNGiQ6tWrJ4vFok8++cRmvaPG34ULFzRixAj5+fnJz89PI0aM0MWLF8v56CpOSf14/fp1zZw5U61bt5aXl5fq1aunkSNH6uzZszb7iIyMLDRGH3/8cZsyVbkfJce9jqt6Pxb1XmmxWPT6669by1T18WjPZ5zK9P5IoKjE1q1bpylTpui5557Tt99+qwcffFADBgzQyZMnnd20SiE+Pl7PPPOMdu3apc2bNysvL0/9+vXTlStXbMpFRUUpPT3d+vjss89s1k+ZMkUff/yx1q5dq507d+ry5csaOHCgbty4UZGH41StWrWy6aMDBw5Y17322mtavHixli5dqr179yooKEh9+/bVpUuXrGXoQ2nv3r02fbh582ZJ0m9+8xtrGcZiYVeuXFHbtm21dOnSItc7avw9+eSTSk5O1hdffKEvvvhCycnJGjFiRLkfX0UpqR+zs7P1zTffaO7cufrmm2+0YcMG/fDDD/rlL39ZqGxMTIzNGP3zn/9ss74q92MBR7yOq3o/3tp/6enpev/992WxWPTYY4/ZlKvK49GezziV6v3RQKXVpUsXY9y4cTbLwsPDjVmzZjmpRZVbRkaGIcmIj4+3Lhs1apQxePDgYre5ePGiUa1aNWPt2rXWZWfOnDFcXFyML774ojybW2nMnz/faNu2bZHr8vPzjaCgIOOVV16xLrt27Zrh5+dnvPvuu4Zh0IfFmTx5stGkSRMjPz/fMAzGoj0kGR9//LH1uaPG36FDhwxJxq5du6xlEhMTDUnG4cOHy/moKt7t/ViUPXv2GJKMEydOWJf17NnTmDx5crHb0I+OeR3Tj4UNHjzY6N27t80yxqOt2z/jVLb3R85QVFK5ublKSkpSv379bJb369dPCQkJTmpV5ZaZmSlJqlmzps3ybdu2KTAwUM2bN1dMTIwyMjKs65KSknT9+nWbfq5Xr57uu+++KtXPR44cUb169dSoUSM9/vjj+ve//y1JSktL07lz52z6x8PDQz179rT2D31YWG5urv72t79p9OjRslgs1uWMxbJx1PhLTEyUn5+funbtai1z//33y8/Pr8r2bWZmpiwWi/z9/W2Wr1mzRrVr11arVq00Y8Y
Mm2866ceb7vR1TD/a+vHHH7Vp0yaNGTOm0DrG43/d/hmnsr0/upk/NJSnn3/+WTdu3FDdunVtltetW1fnzp1zUqsqL8MwNG3aNP3iF7/QfffdZ10+YMAA/eY3v1HDhg2VlpamuXPnqnfv3kpKSpKHh4fOnTsnd3d3BQQE2OyvKvVz165dtXr1ajVv3lw//vijXnzxRXXv3l0HDx609kFR4/DEiROSRB8W4ZNPPtHFixcVHR1tXcZYLDtHjb9z584pMDCw0P4DAwOrZN9eu3ZNs2bN0pNPPilfX1/r8uHDh6tRo0YKCgrS999/r9mzZ2v//v3Wy/foR8e8julHW6tWrZKPj4+GDBlis5zx+F9FfcapbO+PBIpK7tZvN6Wbg+r2ZZAmTJig7777Tjt37rRZPmzYMOu/77vvPnXq1EkNGzbUpk2bCr153aoq9fOAAQOs/27durW6deumJk2aaNWqVdabDc2Mw6rUh7dbsWKFBgwYoHr16lmXMRbNc8T4K6p8Vezb69ev6/HHH1d+fr6WLVtmsy4mJsb67/vuu0/NmjVTp06d9M0336hDhw6S6EdHvY6rej/e6v3339fw4cNVvXp1m+WMx/8q7jOOVHneH7nkqZKqXbu2XF1dC6XDjIyMQmm0qps4caL++c9/auvWrWrQoEGJZYODg9WwYUMdOXJEkhQUFKTc3FxduHDBplxV7mcvLy+1bt1aR44csf7aU0njkD60deLECcXFxWns2LEllmMsls5R4y8oKEg//vhjof3/9NNPVapvr1+/rqFDhyotLU2bN2+2OTtRlA4dOqhatWo2Y5R+tGXmdUw//teOHTuUmppa6vulVHXHY3GfcSrb+yOBopJyd3dXx44draf2CmzevFndu3d3UqsqF8MwNGHCBG3YsEFbtmxRo0aNSt3m/PnzOnXqlIKDgyVJHTt2VLVq1Wz6OT09Xd9//32V7eecnBylpKQoODjYerr51v7Jzc1VfHy8tX/oQ1uxsbEKDAzUI488UmI5xmLpHDX+unXrpszMTO3Zs8daZvfu3crMzKwyfVsQJo4cOaK4uDjVqlWr1G0OHjyo69evW8co/ViYmdcx/fhfK1asUMeOHdW2bdtSy1a18VjaZ5xK9/5o//3lqGhr1641qlWrZqxYscI4dOiQMWXKFMPLy8s4fvy4s5tWKfz+9783/Pz8jG3bthnp6enWR3Z2tmEYhnHp0iVj+vTpRkJCgpGWlmZs3brV6Natm1G/fn0jKyvLup9x48YZDRo0MOLi4oxvvvnG6N27t9G2bVsjLy/PWYdWoaZPn25s27bN+Pe//23s2rXLGDhwoOHj42MdZ6+88orh5+dnbNiwwThw4IDxxBNPGMHBwfRhEW7cuGGEhoYaM2fOtFnOWCzepUuXjG+//db49ttvDUnG4sWLjW+//db660OOGn9RUVFGmzZtjMTERCMxMdFo3bq1MXDgwAo/3vJSUj9ev37d+OUvf2k0aNDASE5Otnm/zMnJMQzDMI4ePWosXLjQ2Lt3r5GWlmZs2rTJCA8PN9q3b08//l8/OvJ1XJX7sUBmZqZRo0YNY/ny5YW2ZzyW/hnHMCrX+yOBopL705/+ZDRs2NBwd3c3OnToYPOTqFWdpCIfsbGxhmEYRnZ2ttGvXz+jTp06RrVq1YzQ0FBj1KhRxsmTJ232c/XqVWPChAlGzZo1DU9PT2PgwIGFytzLhg0bZgQHBxvVqlUz6tWrZwwZMsQ4ePCgdX1+fr4xf/58IygoyPDw8DB69OhhHDhwwGYfVb0PC/zrX/8yJBmpqak2yxmLxdu6dWuRr+NRo0YZhuG48Xf+/Hlj+PDhho+Pj+Hj42MMHz7cuHDhQgUdZfkrqR/T0tKKfb/cunWrYRiGcfLkSaNHjx5GzZo1DXd3d6NJkybGpEmTjPPnz9vUU5X70ZGv46rcjwX+/Oc/G56ensbFixcLbc94LP0zjmFUrvdHy/81GgAAAADKjHsoAAAAAJhGoAA
AAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKYRKAAATrFgwQK1a9fO2c0AANwhZsoGADicxWIpcf2oUaO0dOlS5eTkqFatWhXUKgBAeSBQAAAc7ty5c9Z/r1u3TvPmzVNqaqp1maenp/z8/JzRNACAg3HJEwDA4YKCgqwPPz8/WSyWQstuv+QpOjpajz76qF5++WXVrVtX/v7+WrhwofLy8vTss8+qZs2aatCggd5//32bus6cOaNhw4YpICBAtWrV0uDBg3X8+PGKPWAAqMIIFACASmPLli06e/astm/frsWLF2vBggUaOHCgAgICtHv3bo0bN07jxo3TqVOnJEnZ2dnq1auXvL29tX37du3cuVPe3t6KiopSbm6uk48GAKoGAgUAoNKoWbOm3nnnHbVo0UKjR49WixYtlJ2drT/+8Y9q1qyZZs+eLXd3d3399deSpLVr18rFxUXvvfeeWrdurYiICMXGxurkyZPatm2bcw8GAKoIN2c3AACAAq1atZKLy3+/66pbt67uu+8+63NXV1fVqlVLGRkZkqSkpCQdPXpUPj4+Nvu5du2ajh07VjGNBoAqjkABAKg0qlWrZvPcYrEUuSw/P1+SlJ+fr44dO2rNmjWF9lWnTp3yaygAwIpAAQC4a3Xo0EHr1q1TYGCgfH19nd0cAKiSuIcCAHDXGj58uGrXrq3Bgwdrx44dSktLU3x8vCZPnqzTp087u3kAUCUQKAAAd60aNWpo+/btCg0N1ZAhQxQREaHRo0fr6tWrnLEAgArCxHYAAAAATOMMBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwLT/D9blFiP5mbuxAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cv_cmap = matplotlib.colormaps[\"cividis\"]\n", + "plt.figure(figsize=(8, 3))\n", + "\n", + "for i, (train_mask, valid_mask) in enumerate(cv_folds):\n", + " idx = np.array([np.nan] * time_horizon)\n", + " idx[np.arange(*train_mask)] = 1\n", + " idx[np.arange(*valid_mask)] = 0\n", + " plt.scatter(\n", + " range(time_horizon),\n", + " [i + 0.5] * time_horizon,\n", + " c=idx,\n", + " marker=\"_\",\n", + " capstyle=\"butt\",\n", + " s=1,\n", + " lw=20,\n", + " cmap=cv_cmap,\n", + " vmin=-1.5,\n", + " vmax=1.5,\n", + " )\n", + "\n", + "idx = np.array([np.nan] * time_horizon)\n", + "idx[np.arange(*holdout)] = -1\n", + "plt.scatter(\n", + " range(time_horizon),\n", + " [n_folds + 0.5] * time_horizon,\n", + " c=idx,\n", + " marker=\"_\",\n", + " capstyle=\"butt\",\n", + " s=1,\n", + " lw=20,\n", + " cmap=cv_cmap,\n", + " vmin=-1.5,\n", + " vmax=1.5,\n", + ")\n", + "\n", + "plt.xlabel(\"Time\")\n", + "plt.yticks(\n", + " ticks=np.arange(n_folds + 1) + 0.5,\n", + " labels=[f\"Fold {i}\" for i in range(n_folds)] + [\"Holdout\"],\n", + ")\n", + "plt.ylim([len(cv_folds) + 1.2, -0.2])\n", + "\n", + "norm = matplotlib.colors.Normalize(vmin=-1.5, vmax=1.5)\n", + "plt.legend(\n", + " [\n", + " Patch(color=cv_cmap(norm(1))),\n", + " Patch(color=cv_cmap(norm(0))),\n", + " Patch(color=cv_cmap(norm(-1))),\n", + " ],\n", + " [\"Training set\", \"Validation set\", \"Held-out test set\"],\n", + " ncol=3,\n", + " loc=\"best\",\n", + ")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "2b5dffdb-01ae-4703-950d-1ee7ea0b77dd", + "metadata": {}, + "source": [ + "### Launch a Dask client on Kubernetes\n", + "Let us set up a Dask cluster using the `KubeCluster` class." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "580f18c9-1e99-4c43-a244-8db3bfe3c1ab", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9a276ff2156246e0a62c3d40e862ddac", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "cluster = KubeCluster(\n",
+    "    name=\"rapids-dask\",\n",
+    "    image=rapids_image,\n",
+    "    worker_command=\"dask-cuda-worker\",\n",
+    "    n_workers=n_workers,\n",
+    "    resources={\"limits\": {\"nvidia.com/gpu\": \"1\"}},\n",
+    "    env={\"DISABLE_JUPYTER\": \"true\", \"EXTRA_PIP_PACKAGES\": \"optuna gcsfs\"},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 70,
+   "id": "54e8254c-f29e-467d-95b0-f111122deffc",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "faa42202d7cd441d9a82072fa9d25ac5",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/html": [
+       "
\n", + "
\n", + "
\n", + "
\n", + "

KubeCluster

\n", + "

rapids-dask

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", + " \n", + " Workers: 0\n", + "
\n", + " Total threads: 0\n", + " \n", + " Total memory: 0 B\n", + "
\n", + "\n", + "
\n", + " \n", + "

Scheduler Info

\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

Scheduler

\n", + "

Scheduler-ca08b554-394b-419d-b380-7b3ccdfd882f

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Comm: tcp://10.36.0.26:8786\n", + " \n", + " Workers: 0\n", + "
\n", + " Dashboard: http://10.36.0.26:8787/status\n", + " \n", + " Total threads: 0\n", + "
\n", + " Started: Just now\n", + " \n", + " Total memory: 0 B\n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "

Workers

\n", + "
\n", + "\n", + " \n", + "\n", + "
\n", + "
\n", + "\n", + "
\n", + "
\n", + "
" + ], + "text/plain": [ + "KubeCluster(rapids-dask, 'tcp://rapids-dask-scheduler.kubeflow-user-example-com:8786', workers=0, threads=0, memory=0 B)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cluster" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "id": "ad8be605-68bc-45d6-962e-dd3b4c5e8658", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "
\n", + "
\n", + "

Client

\n", + "

Client-dcc90a0f-5e32-11ee-8529-fa610e3dbb88

\n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "
Connection method: Cluster objectCluster type: dask_kubernetes.KubeCluster
\n", + " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", + "
\n", + "\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "

Cluster Info

\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

KubeCluster

\n", + "

rapids-dask

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", + " \n", + " Workers: 1\n", + "
\n", + " Total threads: 1\n", + " \n", + " Total memory: 83.48 GiB\n", + "
\n", + "\n", + "
\n", + " \n", + "

Scheduler Info

\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + "

Scheduler

\n", + "

Scheduler-ca08b554-394b-419d-b380-7b3ccdfd882f

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " Comm: tcp://10.36.0.26:8786\n", + " \n", + " Workers: 1\n", + "
\n", + " Dashboard: http://10.36.0.26:8787/status\n", + " \n", + " Total threads: 1\n", + "
\n", + " Started: Just now\n", + " \n", + " Total memory: 83.48 GiB\n", + "
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "

Workers

\n", + "
\n", + "\n", + " \n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "

Worker: rapids-dask-default-worker-4f5ea8bc10

\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + " \n", + "\n", + "
\n", + " Comm: tcp://10.36.2.23:34385\n", + " \n", + " Total threads: 1\n", + "
\n", + " Dashboard: http://10.36.2.23:8788/status\n", + " \n", + " Memory: 83.48 GiB\n", + "
\n", + " Nanny: tcp://10.36.2.23:44783\n", + "
\n", + " Local directory: /tmp/dask-scratch-space/worker-h2_cxa0d\n", + "
\n", + " GPU: NVIDIA A100-SXM4-40GB\n", + " \n", + " GPU memory: 40.00 GiB\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
\n", + "\n", + "
\n", + "
\n", + "
\n", + "
\n", + " \n", + "\n", + "
\n", + "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 73, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client = Client(cluster)\n", + "client" + ] + }, + { + "cell_type": "markdown", + "id": "fb7527b9-0039-43e5-bb25-985501480313", + "metadata": {}, + "source": [ + "### Define the custom evaluation metric\n", + "\n", + "The M5 forecasting competition defines a custom metric called WRMSSE as follows:\n", + "\n", + "$$\n", + "WRMSSE = \\sum w_i \\cdot RMSSE_i\n", + "$$\n", + "\n", + "i.e. WRMSEE is a weighted sum of RMSSE for all product items $i$. RMSSE is in turn defined to be\n", + "\n", + "$$\n", + "RMSSE = \\sqrt{\\frac{1/h \\cdot \\sum_t{\\left(Y_t - \\hat{Y}_t\\right)}^2}{1/(n-1)\\sum_t{(Y_t - Y_{t-1})}^2}}\n", + "$$\n", + "\n", + "where the squared error of the prediction (forecast) is normalized by the speed at which the sales amount changes per unit in the training data.\n", + "\n", + "Here is the implementation of the WRMSSE using cuDF. We use the product weights $w_i$ as computed in the first preprocessing notebook." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "a3d7f300-310d-4c09-b30a-4d0c54903d23", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def wrmsse(product_weights, df, pred_sales, train_mask, valid_mask):\n", + " \"\"\"Compute WRMSSE metric\"\"\"\n", + " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", + " df_valid = df[(df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])]\n", + "\n", + " # Compute denominator: 1/(n-1) * sum( (y(t) - y(t-1))**2 )\n", + " diff = (\n", + " df_train.sort_values([\"item_id\", \"day_id\"])\n", + " .groupby([\"item_id\"])[[\"sales\"]]\n", + " .diff(1)\n", + " )\n", + " x = (\n", + " df_train[[\"item_id\", \"day_id\"]]\n", + " .join(diff, how=\"left\")\n", + " .rename(columns={\"sales\": \"diff\"})\n", + " .sort_values([\"item_id\", \"day_id\"])\n", + " )\n", + " x[\"diff\"] = x[\"diff\"] ** 2\n", + " xx = x.groupby([\"item_id\"])[[\"diff\"]].agg([\"sum\", \"count\"]).sort_index()\n", + " xx.columns = xx.columns.map(\"_\".join)\n", + " xx[\"denominator\"] = xx[\"diff_sum\"] / xx[\"diff_count\"]\n", + " t = xx.reset_index()\n", + "\n", + " # Compute numerator: 1/h * sum( (y(t) - y_pred(t))**2 )\n", + " X_valid = df_valid.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " if \"dept_id\" in X_valid.columns:\n", + " X_valid = X_valid.drop(columns=[\"dept_id\"])\n", + " df_pred = cudf.DataFrame(\n", + " {\n", + " \"item_id\": df_valid[\"item_id\"].copy(),\n", + " \"pred_sales\": pred_sales,\n", + " \"sales\": df_valid[\"sales\"].copy(),\n", + " }\n", + " )\n", + " df_pred[\"diff\"] = (df_pred[\"sales\"] - df_pred[\"pred_sales\"]) ** 2\n", + " yy = df_pred.groupby([\"item_id\"])[[\"diff\"]].agg([\"sum\", \"count\"]).sort_index()\n", + " yy.columns = yy.columns.map(\"_\".join)\n", + " yy[\"numerator\"] = yy[\"diff_sum\"] / yy[\"diff_count\"]\n", + "\n", + " zz = yy[[\"numerator\"]].join(xx[[\"denominator\"]], how=\"left\")\n", + " 
zz = zz.join(product_weights, how=\"left\").sort_index()\n", + " # Filter out zero denominator.\n", + " # This can occur if the product was never on sale during the period in the training set\n", + " zz = zz[zz[\"denominator\"] != 0]\n", + " zz[\"rmsse\"] = np.sqrt(zz[\"numerator\"] / zz[\"denominator\"])\n", + " t = zz[\"rmsse\"].multiply(zz[\"weights\"])\n", + " return zz[\"rmsse\"].multiply(zz[\"weights\"]).sum()" + ] + }, + { + "cell_type": "markdown", + "id": "17ed80bd-b6db-459d-b28a-7cf59ba2bdad", + "metadata": {}, + "source": [ + "### Define the training and hyperparameter search pipeline using Optuna\n", + "Optuna lets us define the training procedure iteratively, i.e. as if we were to write an ordinary function to train a single model. Instead of a fixed hyperparameter combination, the function now takes in a `trial` object which yields different hyperparameter combinations.\n", + "\n", + "In this example, we partition the training data according to the store and then fit a separate XGBoost model per data segment." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "5f538176-2b69-4224-8570-4c0aedfde4d4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def objective(trial):\n", + " fs = gcsfs.GCSFileSystem()\n", + " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " \"lambda\": trial.suggest_float(\"lambda\", 1e-8, 100.0, log=True),\n", + " \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 100.0, log=True),\n", + " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n", + " \"max_depth\": trial.suggest_int(\"max_depth\", 2, 6, step=1),\n", + " \"min_child_weight\": trial.suggest_float(\n", + " \"min_child_weight\", 1e-8, 100, log=True\n", + " ),\n", + " \"gamma\": trial.suggest_float(\"gamma\", 1e-8, 1.0, log=True),\n", + " \"tweedie_variance_power\": trial.suggest_float(\"tweedie_variance_power\", 1, 2),\n", + " }\n", + " scores = [[] for store in STORES]\n", + "\n", + " for store_id, store in enumerate(STORES):\n", + " print(f\"Processing store {store}...\")\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " for train_mask, valid_mask in cv_folds:\n", + " df_train = df[\n", + " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", + " ]\n", + " df_valid = df[\n", + " (df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])\n", + " ]\n", + "\n", + " X_train, y_train = (\n", + " df_train.drop(\n", + " columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]\n", + " ),\n", + " df_train[\"sales\"],\n", + " 
)\n", + " X_valid = df_valid.drop(\n", + " columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]\n", + " )\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " pred_sales = clf.predict(X_valid)\n", + " scores[store_id].append(\n", + " wrmsse(product_weights, df, pred_sales, train_mask, valid_mask)\n", + " )\n", + " del df_train, df_valid, X_train, y_train, clf\n", + " gc.collect()\n", + " del df\n", + " gc.collect()\n", + "\n", + " # We can sum WRMSSE scores over data segments because data segments contain disjoint sets of time series\n", + " return np.array(scores).sum(axis=0).mean()" + ] + }, + { + "cell_type": "markdown", + "id": "f2e4df5d-17a5-4b1c-8454-b5c9c1e631cf", + "metadata": {}, + "source": [ + "Using the Dask cluster client, we execute multiple training jobs in parallel. Optuna keeps track of the progress in the hyperparameter search using in-memory Dask storage." + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "6ba27561-5a31-4568-b341-1befc29516e8", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1321/3389696366.py:7: ExperimentalWarning: DaskStorage is experimental (supported from v3.1.0). 
The interface can change in the future.\n", + " dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing hyperparameter combinations 0..2\n", + "Best cross-validation metric: 10.027767173304472, Time elapsed = 331.6198390149948\n", + "Testing hyperparameter combinations 2..4\n", + "Best cross-validation metric: 9.426913749927916, Time elapsed = 640.7606940959959\n", + "Testing hyperparameter combinations 4..6\n", + "Best cross-validation metric: 9.426913749927916, Time elapsed = 958.0816706369951\n", + "Testing hyperparameter combinations 6..8\n", + "Best cross-validation metric: 9.426913749927916, Time elapsed = 1295.700604706988\n", + "Testing hyperparameter combinations 8..9\n", + "Best cross-validation metric: 8.915009508695244, Time elapsed = 1476.1182343699911\n", + "Total time elapsed = 1476.1219055669935\n" + ] + } + ], + "source": [ + "##### Number of hyperparameter combinations to try in parallel\n", + "n_trials = 9 # Using a small n_trials so that the demo can finish quickly\n", + "# n_trials = 100\n", + "\n", + "# Optimize in parallel on your Dask cluster\n", + "backend_storage = optuna.storages.InMemoryStorage()\n", + "dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n", + "study = optuna.create_study(\n", + " direction=\"minimize\",\n", + " sampler=optuna.samplers.RandomSampler(seed=0),\n", + " storage=dask_storage,\n", + ")\n", + "futures = []\n", + "for i in range(0, n_trials, n_workers):\n", + " iter_range = (i, min([i + n_workers, n_trials]))\n", + " futures.append(\n", + " {\n", + " \"range\": iter_range,\n", + " \"futures\": [\n", + " client.submit(\n", + " # Work around bug https://github.com/optuna/optuna/issues/4859\n", + " lambda objective, n_trials: (\n", + " study.sampler.reseed_rng(),\n", + " study.optimize(objective, n_trials),\n", + " ),\n", + " objective,\n", + " n_trials=1,\n", + " 
pure=False,\n", + " )\n", + " for _ in range(*iter_range)\n", + " ],\n", + " }\n", + " )\n", + "\n", + "tstart = time.perf_counter()\n", + "for partition in futures:\n", + " iter_range = partition[\"range\"]\n", + " print(f\"Testing hyperparameter combinations {iter_range[0]}..{iter_range[1]}\")\n", + " _ = wait(partition[\"futures\"])\n", + " for fut in partition[\"futures\"]:\n", + " _ = fut.result() # Ensure that the training job was successful\n", + " tnow = time.perf_counter()\n", + " print(\n", + " f\"Best cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\"\n", + " )\n", + "tend = time.perf_counter()\n", + "print(f\"Total time elapsed = {tend - tstart}\")" + ] + }, + { + "cell_type": "markdown", + "id": "baa92e3e-73e6-4de5-92b0-f89aee77cbc8", + "metadata": {}, + "source": [ + "Once the hyperparameter search is complete, we fetch the optimal hyperparameter combination using the attributes of the `study` object." + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "2a01c20d-9581-4c36-bcb6-c9440f79b652", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'lambda': 2.6232990699579064e-06,\n", + " 'alpha': 0.004085800094564677,\n", + " 'colsample_bytree': 0.4064535567263888,\n", + " 'max_depth': 6,\n", + " 'min_child_weight': 9.652128310148716e-08,\n", + " 'gamma': 3.4446109254037165e-07,\n", + " 'tweedie_variance_power': 1.0914258082324833}" + ] + }, + "execution_count": 77, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "study.best_params" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "6de40f4c-2b29-49da-881d-76c9be3424cf", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "FrozenTrial(number=8, state=TrialState.COMPLETE, values=[8.915009508695244], datetime_start=datetime.datetime(2023, 9, 28, 19, 35, 29, 888497), datetime_complete=datetime.datetime(2023, 9, 28, 19, 38, 30, 299541), 
params={'lambda': 2.6232990699579064e-06, 'alpha': 0.004085800094564677, 'colsample_bytree': 0.4064535567263888, 'max_depth': 6, 'min_child_weight': 9.652128310148716e-08, 'gamma': 3.4446109254037165e-07, 'tweedie_variance_power': 1.0914258082324833}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'lambda': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'alpha': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'colsample_bytree': FloatDistribution(high=1.0, log=False, low=0.2, step=None), 'max_depth': IntDistribution(high=6, log=False, low=2, step=1), 'min_child_weight': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'gamma': FloatDistribution(high=1.0, log=True, low=1e-08, step=None), 'tweedie_variance_power': FloatDistribution(high=2.0, log=False, low=1.0, step=None)}, trial_id=8, value=None)" + ] + }, + "execution_count": 78, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "study.best_trial" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "id": "e0e1e8a4-846a-4d70-bc84-b9e20bebe067", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'lambda': 2.6232990699579064e-06,\n", + " 'alpha': 0.004085800094564677,\n", + " 'colsample_bytree': 0.4064535567263888,\n", + " 'max_depth': 6,\n", + " 'min_child_weight': 9.652128310148716e-08,\n", + " 'gamma': 3.4446109254037165e-07,\n", + " 'tweedie_variance_power': 1.0914258082324833}" + ] + }, + "execution_count": 79, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Make a deep copy to preserve the dictionary after deleting the Dask cluster\n", + "best_params = copy.deepcopy(study.best_params)\n", + "best_params" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "id": "354ab00c-e340-4de3-a3cf-308b476cf78a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with 
fs.open(f\"{bucket_name}/params.json\", \"w\") as f:\n", + " json.dump(best_params, f)" + ] + }, + { + "cell_type": "markdown", + "id": "a47ec686-1b1b-49ac-8b87-012e0fcbead7", + "metadata": {}, + "source": [ + "### Train the final XGBoost model and evaluate\n", + "Using the optimal hyperparameters found in the search, fit a new model using the whole training data. As in the previous section, we fit a separate XGBoost model per data segment." + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "id": "fafe5c72-576e-45f2-9f9d-3e5cf72874fc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with fs.open(f\"{bucket_name}/params.json\", \"r\") as f:\n", + " best_params = json.load(f)\n", + "with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "f1783769-b19e-4515-8ff0-5e1b662bfcee", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def final_train(best_params):\n", + " fs = gcsfs.GCSFileSystem()\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " }\n", + " params.update(best_params)\n", + " model = {}\n", + " train_mask = [0, 1914]\n", + "\n", + " for store in STORES:\n", + " print(f\"Processing store {store}...\")\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + "\n", + " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", + " X_train, y_train = (\n", + " df_train.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", + " 
df_train[\"sales\"],\n", + " )\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " model[store] = clf\n", + " del df\n", + " gc.collect()\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "3fcaadeb-fa8f-4266-9280-e0f25cf1afca", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1...\n", + "Processing store CA_2...\n", + "Processing store CA_3...\n", + "Processing store CA_4...\n", + "Processing store TX_1...\n", + "Processing store TX_2...\n", + "Processing store TX_3...\n", + "Processing store WI_1...\n", + "Processing store WI_2...\n", + "Processing store WI_3...\n" + ] + } + ], + "source": [ + "model = final_train(best_params)" + ] + }, + { + "cell_type": "markdown", + "id": "57da12bb-5801-46b5-9eef-16f62ed79737", + "metadata": {}, + "source": [ + "Let's now evaluate the final model using the held-out test set:" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "id": "ab4c5ede-cf84-4a62-a1d9-ce18d9f1d4a3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WRMSSE metric on the held-out test set: 9.478942050051291\n" + ] + } + ], + "source": [ + "test_wrmsse = 0\n", + "for store in STORES:\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " df_test = df[(df[\"day_id\"] >= holdout[0]) & (df[\"day_id\"] < holdout[1])]\n", + " X_test = df_test.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " pred_sales = model[store].predict(X_test)\n", + " test_wrmsse += wrmsse(\n", + " product_weights, df, pred_sales, train_mask=[0, 1914], valid_mask=holdout\n", + " )\n", + "print(f\"WRMSSE metric on the held-out test set: {test_wrmsse}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": 
"987923e1-d26f-44f7-8d5a-559f28f38bb3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Save the model to the Cloud Storage\n", + "with fs.open(f\"{bucket_name}/final_model.pkl\", \"wb\") as f:\n", + " pickle.dump(model, f)" + ] + }, + { + "cell_type": "markdown", + "id": "e8e55536-4540-495c-a930-d54afb7eecaa", + "metadata": {}, + "source": [ + "## Create an ensemble model using a different strategy for segmenting sales data\n", + "It is common to create an ensemble model where multiple machine learning methods are used to obtain better predictive performance. Prediction is made from an ensemble model by averaging the prediction output of the constituent models.\n", + "\n", + "In this example, we will create a second model by segmenting the sales data in a different way. Instead of splitting by stores, we will split the data by both stores and product categories." + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "f8180260-cae0-4af3-8fdc-ebd654e7683d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def objective_alt(trial):\n", + " fs = gcsfs.GCSFileSystem()\n", + " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " \"lambda\": trial.suggest_float(\"lambda\", 1e-8, 100.0, log=True),\n", + " \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 100.0, log=True),\n", + " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n", + " \"max_depth\": trial.suggest_int(\"max_depth\", 2, 6, step=1),\n", + " \"min_child_weight\": trial.suggest_float(\n", + " \"min_child_weight\", 1e-8, 100, log=True\n", + " ),\n", 
+ " \"gamma\": trial.suggest_float(\"gamma\", 1e-8, 1.0, log=True),\n", + " \"tweedie_variance_power\": trial.suggest_float(\"tweedie_variance_power\", 1, 2),\n", + " }\n", + " scores = [[] for i in range(len(STORES) * len(DEPTS))]\n", + "\n", + " for store_id, store in enumerate(STORES):\n", + " for dept_id, dept in enumerate(DEPTS):\n", + " print(f\"Processing store {store}, department {dept}...\")\n", + " with fs.open(\n", + " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", + " ) as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " for train_mask, valid_mask in cv_folds:\n", + " df_train = df[\n", + " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", + " ]\n", + " df_valid = df[\n", + " (df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])\n", + " ]\n", + "\n", + " X_train, y_train = (\n", + " df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", + " df_train[\"sales\"],\n", + " )\n", + " X_valid = df_valid.drop(\n", + " columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]\n", + " )\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " sales_pred = clf.predict(X_valid)\n", + " scores[store_id * len(DEPTS) + dept_id].append(\n", + " wrmsse(product_weights, df, sales_pred, train_mask, valid_mask)\n", + " )\n", + " del df_train, df_valid, X_train, y_train, clf\n", + " gc.collect()\n", + " del df\n", + " gc.collect()\n", + "\n", + " # We can sum WRMSSE scores over data segments because data segments contain disjoint sets of time series\n", + " return np.array(scores).sum(axis=0).mean()" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "id": "2c65c62f-97ee-4eb7-8334-6fa82eea83fd", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1321/491731696.py:7: ExperimentalWarning: DaskStorage is experimental (supported from v3.1.0). 
The interface can change in the future.\n", + " dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing hyperparameter combinations 0..2\n", + "Best cross-validation metric: 9.896445497438858, Time elapsed = 802.2191872399999\n", + "Testing hyperparameter combinations 2..4\n", + "Best cross-validation metric: 9.896445497438858, Time elapsed = 1494.0718872279976\n", + "Testing hyperparameter combinations 4..6\n", + "Best cross-validation metric: 9.835407407395302, Time elapsed = 2393.3159628150024\n", + "Testing hyperparameter combinations 6..8\n", + "Best cross-validation metric: 9.330048901795887, Time elapsed = 3092.471466117\n", + "Testing hyperparameter combinations 8..9\n", + "Best cross-validation metric: 9.330048901795887, Time elapsed = 3459.9082761530008\n", + "Total time elapsed = 3459.911843854992\n" + ] + } + ], + "source": [ + "##### Number of hyperparameter combinations to try in parallel\n", + "n_trials = 9 # Using a small n_trials so that the demo can finish quickly\n", + "# n_trials = 100\n", + "\n", + "# Optimize in parallel on your Dask cluster\n", + "backend_storage = optuna.storages.InMemoryStorage()\n", + "dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n", + "study = optuna.create_study(\n", + " direction=\"minimize\",\n", + " sampler=optuna.samplers.RandomSampler(seed=0),\n", + " storage=dask_storage,\n", + ")\n", + "futures = []\n", + "for i in range(0, n_trials, n_workers):\n", + " iter_range = (i, min([i + n_workers, n_trials]))\n", + " futures.append(\n", + " {\n", + " \"range\": iter_range,\n", + " \"futures\": [\n", + " client.submit(\n", + " # Work around bug https://github.com/optuna/optuna/issues/4859\n", + " lambda objective, n_trials: (\n", + " study.sampler.reseed_rng(),\n", + " study.optimize(objective, n_trials),\n", + " ),\n", + " objective_alt,\n", + " n_trials=1,\n", + " 
pure=False,\n", + " )\n", + " for _ in range(*iter_range)\n", + " ],\n", + " }\n", + " )\n", + "\n", + "tstart = time.perf_counter()\n", + "for partition in futures:\n", + " iter_range = partition[\"range\"]\n", + " print(f\"Testing hyperparameter combinations {iter_range[0]}..{iter_range[1]}\")\n", + " _ = wait(partition[\"futures\"])\n", + " for fut in partition[\"futures\"]:\n", + " _ = fut.result() # Ensure that the training job was successful\n", + " tnow = time.perf_counter()\n", + " print(\n", + " f\"Best cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\"\n", + " )\n", + "tend = time.perf_counter()\n", + "print(f\"Total time elapsed = {tend - tstart}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "id": "b728d1d2-7180-428a-bbdd-249e6471c070", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'lambda': 0.028794929327421122,\n", + " 'alpha': 3.3150619134761685e-07,\n", + " 'colsample_bytree': 0.42330433646728755,\n", + " 'max_depth': 2,\n", + " 'min_child_weight': 0.09713314395591004,\n", + " 'gamma': 0.0016337599227941016,\n", + " 'tweedie_variance_power': 1.1915217521234043}" + ] + }, + "execution_count": 88, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Make a deep copy to preserve the dictionary after deleting the Dask cluster\n", + "best_params_alt = copy.deepcopy(study.best_params)\n", + "best_params_alt" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "id": "8c65edfa-ed67-4c49-b823-80cb72a7f6e5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with fs.open(f\"{bucket_name}/params_alt.json\", \"w\") as f:\n", + " json.dump(best_params_alt, f)" + ] + }, + { + "cell_type": "markdown", + "id": "b2d7ba79-72fc-4298-b92e-6eb07b266bfd", + "metadata": {}, + "source": [ + "Using the optimal hyperparameters found in the search, fit a new model using the whole training data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 90, + "id": "896d9342-794d-4b55-9b64-f644e528a9c0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def final_train_alt(best_params):\n", + " fs = gcsfs.GCSFileSystem()\n", + " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", + " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", + " params = {\n", + " \"n_estimators\": 100,\n", + " \"verbosity\": 0,\n", + " \"learning_rate\": 0.01,\n", + " \"objective\": \"reg:tweedie\",\n", + " \"tree_method\": \"gpu_hist\",\n", + " \"grow_policy\": \"depthwise\",\n", + " \"predictor\": \"gpu_predictor\",\n", + " \"enable_categorical\": True,\n", + " }\n", + " params.update(best_params)\n", + " model = {}\n", + " train_mask = [0, 1914]\n", + "\n", + " for store_id, store in enumerate(STORES):\n", + " for dept_id, dept in enumerate(DEPTS):\n", + " print(f\"Processing store {store}, department {dept}...\")\n", + " with fs.open(\n", + " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", + " ) as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " for train_mask, valid_mask in cv_folds:\n", + " df_train = df[\n", + " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", + " ]\n", + " X_train, y_train = (\n", + " df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", + " df_train[\"sales\"],\n", + " )\n", + "\n", + " clf = xgb.XGBRegressor(**params)\n", + " clf.fit(X_train, y_train)\n", + " model[(store, dept)] = clf\n", + " del df\n", + " gc.collect()\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "id": "0abb7b4e-9865-42ad-824a-cf8b64c7c4f3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fs = gcsfs.GCSFileSystem()\n", + "with fs.open(f\"{bucket_name}/params_alt.json\", \"r\") as f:\n", + " best_params_alt = json.load(f)\n", + "with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", 
+ " product_weights = cudf.DataFrame(pd.read_pickle(f))" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "id": "c7b5a383-1358-4201-9588-d799cf21822a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1, department HOBBIES_1...\n", + "Processing store CA_1, department HOBBIES_2...\n", + "Processing store CA_1, department HOUSEHOLD_1...\n", + "Processing store CA_1, department HOUSEHOLD_2...\n", + "Processing store CA_1, department FOODS_1...\n", + "Processing store CA_1, department FOODS_2...\n", + "Processing store CA_1, department FOODS_3...\n", + "Processing store CA_2, department HOBBIES_1...\n", + "Processing store CA_2, department HOBBIES_2...\n", + "Processing store CA_2, department HOUSEHOLD_1...\n", + "Processing store CA_2, department HOUSEHOLD_2...\n", + "Processing store CA_2, department FOODS_1...\n", + "Processing store CA_2, department FOODS_2...\n", + "Processing store CA_2, department FOODS_3...\n", + "Processing store CA_3, department HOBBIES_1...\n", + "Processing store CA_3, department HOBBIES_2...\n", + "Processing store CA_3, department HOUSEHOLD_1...\n", + "Processing store CA_3, department HOUSEHOLD_2...\n", + "Processing store CA_3, department FOODS_1...\n", + "Processing store CA_3, department FOODS_2...\n", + "Processing store CA_3, department FOODS_3...\n", + "Processing store CA_4, department HOBBIES_1...\n", + "Processing store CA_4, department HOBBIES_2...\n", + "Processing store CA_4, department HOUSEHOLD_1...\n", + "Processing store CA_4, department HOUSEHOLD_2...\n", + "Processing store CA_4, department FOODS_1...\n", + "Processing store CA_4, department FOODS_2...\n", + "Processing store CA_4, department FOODS_3...\n", + "Processing store TX_1, department HOBBIES_1...\n", + "Processing store TX_1, department HOBBIES_2...\n", + "Processing store TX_1, department HOUSEHOLD_1...\n", + "Processing store TX_1, department 
HOUSEHOLD_2...\n", + "Processing store TX_1, department FOODS_1...\n", + "Processing store TX_1, department FOODS_2...\n", + "Processing store TX_1, department FOODS_3...\n", + "Processing store TX_2, department HOBBIES_1...\n", + "Processing store TX_2, department HOBBIES_2...\n", + "Processing store TX_2, department HOUSEHOLD_1...\n", + "Processing store TX_2, department HOUSEHOLD_2...\n", + "Processing store TX_2, department FOODS_1...\n", + "Processing store TX_2, department FOODS_2...\n", + "Processing store TX_2, department FOODS_3...\n", + "Processing store TX_3, department HOBBIES_1...\n", + "Processing store TX_3, department HOBBIES_2...\n", + "Processing store TX_3, department HOUSEHOLD_1...\n", + "Processing store TX_3, department HOUSEHOLD_2...\n", + "Processing store TX_3, department FOODS_1...\n", + "Processing store TX_3, department FOODS_2...\n", + "Processing store TX_3, department FOODS_3...\n", + "Processing store WI_1, department HOBBIES_1...\n", + "Processing store WI_1, department HOBBIES_2...\n", + "Processing store WI_1, department HOUSEHOLD_1...\n", + "Processing store WI_1, department HOUSEHOLD_2...\n", + "Processing store WI_1, department FOODS_1...\n", + "Processing store WI_1, department FOODS_2...\n", + "Processing store WI_1, department FOODS_3...\n", + "Processing store WI_2, department HOBBIES_1...\n", + "Processing store WI_2, department HOBBIES_2...\n", + "Processing store WI_2, department HOUSEHOLD_1...\n", + "Processing store WI_2, department HOUSEHOLD_2...\n", + "Processing store WI_2, department FOODS_1...\n", + "Processing store WI_2, department FOODS_2...\n", + "Processing store WI_2, department FOODS_3...\n", + "Processing store WI_3, department HOBBIES_1...\n", + "Processing store WI_3, department HOBBIES_2...\n", + "Processing store WI_3, department HOUSEHOLD_1...\n", + "Processing store WI_3, department HOUSEHOLD_2...\n", + "Processing store WI_3, department FOODS_1...\n", + "Processing store WI_3, department 
FOODS_2...\n", + "Processing store WI_3, department FOODS_3...\n" + ] + } + ], + "source": [ + "model_alt = final_train_alt(best_params_alt)" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "9b72487c-e311-40e8-ae52-b6b23a9d391e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Save the model to the Cloud Storage\n", + "with fs.open(f\"{bucket_name}/final_model_alt.pkl\", \"wb\") as f:\n", + " pickle.dump(model_alt, f)" + ] + }, + { + "cell_type": "markdown", + "id": "b4eb7307-96a3-4f1b-9356-13027b2779a3", + "metadata": {}, + "source": [ + "Now consider an ensemble consisting of the two models `model` and `model_alt`. We evaluate the ensemble by computing the WRMSSE metric for the average of the predictions of the two models." + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "9c746944-5e79-492e-b81f-8f9a556d0df9", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing store CA_1...\n", + "Processing store CA_2...\n", + "Processing store CA_3...\n", + "Processing store CA_4...\n", + "Processing store TX_1...\n", + "Processing store TX_2...\n", + "Processing store TX_3...\n", + "Processing store WI_1...\n", + "Processing store WI_2...\n", + "Processing store WI_3...\n", + "WRMSSE metric on the held-out test set: 10.69187847848366\n" + ] + } + ], + "source": [ + "test_wrmsse = 0\n", + "for store in STORES:\n", + " print(f\"Processing store {store}...\")\n", + " # Prediction from Model 1\n", + " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", + " df = cudf.DataFrame(pd.read_pickle(f))\n", + " df_test = df[(df[\"day_id\"] >= holdout[0]) & (df[\"day_id\"] < holdout[1])]\n", + " X_test = df_test.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " df_test[\"pred1\"] = model[store].predict(X_test)\n", + "\n", + " # Prediction from Model 2\n", + " df_test[\"pred2\"] = [np.nan] * 
len(df_test)\n", + " df_test[\"pred2\"] = df_test[\"pred2\"].astype(\"float32\")\n", + " for dept in DEPTS:\n", + " with fs.open(\n", + " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", + " ) as f:\n", + " df2 = cudf.DataFrame(pd.read_pickle(f))\n", + " df2_test = df2[(df2[\"day_id\"] >= holdout[0]) & (df2[\"day_id\"] < holdout[1])]\n", + " X_test = df2_test.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", + " assert np.sum(df_test[\"dept_id\"] == dept) == len(X_test)\n", + " df_test[\"pred2\"][df_test[\"dept_id\"] == dept] = model_alt[(store, dept)].predict(\n", + " X_test\n", + " )\n", + "\n", + " # Average prediction\n", + " df_test[\"avg_pred\"] = (df_test[\"pred1\"] + df_test[\"pred2\"]) / 2.0\n", + "\n", + " test_wrmsse += wrmsse(\n", + " product_weights,\n", + " df,\n", + " df_test[\"avg_pred\"],\n", + " train_mask=[0, 1914],\n", + " valid_mask=holdout,\n", + " )\n", + "print(f\"WRMSSE metric on the held-out test set: {test_wrmsse}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "af6cc91f-036b-4503-a784-0e413c681727", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Close the Dask cluster to clean up\n", + "cluster.close()" + ] + }, + { + "cell_type": "markdown", + "id": "727454aa-06fd-4724-9ee0-7874f44bb88b", + "metadata": {}, + "source": [ + "## Conclusion" + ] + }, + { + "cell_type": "markdown", + "id": "19524812-c005-4d39-bc62-64a681fb3164", + "metadata": {}, + "source": [ + "We demonstrated an end-to-end workflow where we take real-world time-series data and train a forecasting model using Google Kubernetes Engine (GKE). We were able to speed up the hyperparameter optimization (HPO) process by dispatching parallel training jobs to NVIDIA GPUs."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b83283f-0e56-4b97-b7aa-024a386fc9b2", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb b/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb deleted file mode 100644 index 556f0814..00000000 --- a/source/examples/time-series-forecasting-with-hpo/training_and_evaluation.ipynb +++ /dev/null @@ -1,1781 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "328140a4-b668-4a5b-8de1-24c4392d22f1", - "metadata": {}, - "source": [ - "# Train an XGBoost model with retail sales forecasting with hyperparameter search" - ] - }, - { - "cell_type": "markdown", - "id": "770519b0-25d2-4786-933f-f0e16b5c4b18", - "metadata": {}, - "source": [ - "Now that we finished processing the data, we are now ready to train a model to forecast future sales. We will leverage the worker pods to run multiple training jobs in parallel, speeding up the hyperparameter search." 
- ] - }, - { - "cell_type": "markdown", - "id": "da47c5a5-3777-4bf7-8312-1da91a0d81a3", - "metadata": {}, - "source": [ - "## Import modules and define constants" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "c391837b-6102-4d45-820e-b83091285a71", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import cudf\n", - "import gcsfs\n", - "import xgboost as xgb\n", - "import pandas as pd\n", - "import numpy as np\n", - "import optuna\n", - "import gc\n", - "import time\n", - "import pickle\n", - "import copy\n", - "import json\n", - "\n", - "import matplotlib.pyplot as plt\n", - "from matplotlib.patches import Patch\n", - "import matplotlib\n", - "\n", - "from dask.distributed import wait\n", - "from dask_kubernetes.operator import KubeCluster\n", - "from dask.distributed import Client" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "1954ff43-dcd5-4628-80f2-64f85f253342", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Choose the same RAPIDS image you used for launching the notebook session\n", - "rapids_image = (\n", - " \"rapidsai/rapidsai-core-nightly:23.08-cuda11.8-runtime-ubuntu22.04-py3.10\"\n", - ")\n", - "# Use the number of worker nodes in your Kubernetes cluster.\n", - "n_workers = 3\n", - "# Bucket that contains the processed data pickles, refer to start_here.ipynb\n", - "bucket_name = \"\"\n", - "\n", - "# List of stores and product departments\n", - "STORES = [\n", - " \"CA_1\",\n", - " \"CA_2\",\n", - " \"CA_3\",\n", - " \"CA_4\",\n", - " \"TX_1\",\n", - " \"TX_2\",\n", - " \"TX_3\",\n", - " \"WI_1\",\n", - " \"WI_2\",\n", - " \"WI_3\",\n", - "]\n", - "DEPTS = [\n", - " \"HOBBIES_1\",\n", - " \"HOBBIES_2\",\n", - " \"HOUSEHOLD_1\",\n", - " \"HOUSEHOLD_2\",\n", - " \"FOODS_1\",\n", - " \"FOODS_2\",\n", - " \"FOODS_3\",\n", - "]" - ] - }, - { - "cell_type": "markdown", - "id": "084ac137-1fb0-43e3-9a61-7aa259d05ca8", - "metadata": {}, - "source": [ - "## Define 
cross-validation folds" - ] - }, - { - "cell_type": "markdown", - "id": "cee3f346-afbd-41ea-a91e-12294594b6a6", - "metadata": {}, - "source": [ - "[**Cross-validation**](https://en.wikipedia.org/wiki/Cross-validation_(statistics)) is a statistical method for estimating how well a machine learning model generalizes to an independent data set. The method is also useful for evaluating the choice of a given combination of model hyperparameters.\n", - "\n", - "To estimate the capacity to generalize, we define multiple cross-validation **folds** consisting of mulitple pairs of `(training set, validation set)`. For each fold, we fit a model using the training set and evaluate its accuracy on the validation set. The \"goodness\" score for a given hyperparameter combination is the accuracy of the model on each validation set, averaged over all cross-validation folds.\n", - "\n", - "Great care must be taken when defining cross-validation folds for time-series data. We are not allowed to use the future to predict the past, so the training set must precede (in time) the validation set. 
Consequently, we partition the data set in the time dimension and assign the training and validation sets using time ranges:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "423780b7-8431-4546-85c7-84e4f269d5e1", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Cross-validation folds and held-out test set (in time dimension)\n", - "# The held-out test set is used for final evaluation\n", - "cv_folds = [ # (train_set, validation_set)\n", - " ([0, 1114], [1114, 1314]),\n", - " ([0, 1314], [1314, 1514]),\n", - " ([0, 1514], [1514, 1714]),\n", - " ([0, 1714], [1714, 1914]),\n", - "]\n", - "n_folds = len(cv_folds)\n", - "holdout = [1914, 1942]\n", - "time_horizon = 1942" - ] - }, - { - "cell_type": "markdown", - "id": "d5ed120a-ae41-4211-a517-7791e22df037", - "metadata": {}, - "source": [ - "It is helpful to visualize the cross-validation folds using Matplotlib." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "f7945265-f683-4722-8e22-df003b749eae", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAxQAAAEiCAYAAABgP5QIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA2+UlEQVR4nO3deXxNd+L/8fdNIhHZLZEgEXtC7VvplFBLtIyOztDWIEU6Ru10MLV201W1X0Nnphp0fMuYajs/2k6lCJrY0kaVSDGxR9MaEoRE5Pz+8M0dV7ab4yY35PV8PO7j4Z7zOefzOZ987nXf95xzPxbDMAwBAAAAgAkuzm4AAAAAgLsXgQIAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJjm5uwG3A3y8/N19uxZ+fj4yGKxOLs5AAAAgMMYhqFLly6pXr16cnEp+/kGAoUdzp49q5CQEGc3AwAAACg3p06dUoMGDcq8HYHCDj4+PpJudrKvr6+TWwMAAAA4TlZWlkJCQqyfecuKQGGHgsucfH19CRQAAAC4J5m9tJ+bsgEAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKa5ObsBKF1K4jTl5V5wdjMAAHZYH1dXksXZzXA6X19fzXh2trObAaACcIbiLkCYAIC7CWFCkrKyspzdBAAVhEABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMq9SBIjIyUlOmTCmxTFhYmJYsWVIh7QEAAABgq1wDRXR0tCwWS6HH0aNHy7PaQj766CO1bNlSHh4eatmypT7++OMKrR8AAAC4V5X7GYqoqCilp6fbPBo1alTe1VolJiZq2LBhGjFihPbv368RI0Zo6NCh2r17d4W1AQAAALhXlXug8PDwUFBQkM3D1dVVkhQfH68uXbrIw8NDwcHBmjVrlvLy8ordV0ZGhgYNGiRPT081atRIa9asKbX+JUuWqG/fvpo9e7bCw8M1e/ZsPfTQQ1wmBQAAADiA0+6hOHPmjB5++GF17txZ+/fv1/Lly7VixQq9+OKLxW4THR2t48ePa8uWLfrHP/6hZcuWKSMjo8R6EhMT1a9fP5tl/fv3V0JCQrHb5OTkKCsry+YBAAAAoDC38q5g48aN8vb2tj4fMGCA1q9fr2XLlikkJERLly6VxWJReHi4zp49q5kzZ2revHlycbHNOj/88IM+//xz7dq1S127dpUkrVixQhERESXWf+7cOdWtW9dmWd26dXXu3Llit1m0aJEWLlxY1kMFAAAAqpxyDxS9evXS8uXLrc+9vLwkSSkpKerWrZssFot13QMPPKDLly/r9OnTCg0NtdlPSkqK3Nzc1KlTJ+uy8PBw+fv7l9qGW+uQJMMwCi271ezZszVt2jTr86ysLIWEhJRaDwAAAFDVlHug8PLyUtOmTQstL+pDvWEYkgoHgNLWlSQoKKjQ2YiMjIxCZy1u5eHhIQ8PjzLVAwAAAFRFTruHomXLlkpISLAGBUlKSEiQj4+P6tevX6h8RESE8vLytG/fPuuy1NRUXbx4scR6unXrps2bN9ss+/LLL9W9e/c7OwAAAAAAzgsU48eP16lTpzRx4kQdPnxYn376qebPn69p06YVun9Cklq0aKGoqCjFxMRo9+7dSkpK0tixY+Xp6VliPZMnT9aXX36pV199VYcPH9arr76quLi4UifMAwAAAFA6pwWK+vXr67PPPtOePXvUtm1bjRs3TmPGjNGcOXOK3SY2NlYhISHq2bOnhgwZoqefflqBgYEl1tO9e3etXbtWsbGxatOmjVauXKl169ZZb+wGAAAAYJ7FuPWaIxQpKytLfn5+yszMlK+vb4XXfyD+qQqvEwBgzvq4IGc3odJ4/oV
Fzm4CADvc6Wddp52hAAAAAHD3I1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUBxF3BzD3B2EwAAduPX2CU55WfWATiHm7MbgNJFdFvs7CYAAOzUuqezWwAAFYszFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMc3N2A1C6lMRpysu94OxmAABgl4076uhqjoski7Ob4lQWi0ULn3/Z2c0Ayh2B4i5AmAAA3E2u5rg6uwmVgmEYzm4CUCG45AkAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhWqQNFZGSkpkyZUmKZsLAwLVmypELaAwAAAMBWuQaK6OhoWSyWQo+jR4+WZ7U2Dh48qMcee0xhYWGyWCyEDwAAAMCByv0MRVRUlNLT020ejRo1Ku9qrbKzs9W4cWO98sorCgoKqrB6AQAAgKqg3AOFh4eHgoKCbB6urq6SpPj4eHXp0kUeHh4KDg7WrFmzlJeXV+y+MjIyNGjQIHl6eqpRo0Zas2ZNqfV37txZr7/+uh5//HF5eHg47LgAAAAASG7OqvjMmTN6+OGHFR0drdWrV+vw4cOKiYlR9erVtWDBgiK3iY6O1qlTp7Rlyxa5u7tr0qRJysjIcHjbcnJylJOTY32elZXl8DoAAACAe0G5B4qNGzfK29vb+nzAgAFav369li1bppCQEC1dulQWi0Xh4eE6e/asZs6cqXnz5snFxfbkyQ8//KDPP/9cu3btUteuXSVJK1asUEREhMPbvGjRIi1cuNDh+wUAAADuNeUeKHr16qXly5dbn3t5eUmSUlJS1K1bN1ksFuu6Bx54QJcvX9bp06cVGhpqs5+UlBS5ubmpU6dO1mXh4eHy9/d3eJtnz56tadOmWZ9nZWUpJCTE4fUAAAAAd7tyDxReXl5q2rRpoeWGYdiEiYJlkgotL22do3l4eHC/BQAAAGAHp81D0bJlSyUkJFiDgiQlJCTIx8dH9evXL1Q+IiJCeXl52rdvn3VZamqqLl68WBHNBQAAAFAEpwWK8ePH69SpU5o4caIOHz6sTz/9VPPnz9e0adMK3T8hSS1atFBUVJRiYmK0e/duJSUlaezYsfL09CyxntzcXCUnJys5OVm5ubk6c+aMkpOTK3QuDAAAAOBe5bRAUb9+fX322Wfas2eP2rZtq3HjxmnMmDGaM2dOsdvExsYqJCREPXv21JAhQ/T0008rMDCwxHrOnj2r9u3bq3379kpPT9cbb7yh9u3ba+zYsY4+JAAAAKDKsRi3XnOEImVlZcnPz0+ZmZny9fWt8PoPxD9V4XUCAGDW+jgmki3w/AuLnN0EoFR3+lnXaWcoAAAAANz9CBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1DcBdzcA5zdBAAA7ObpcUMSv0pvsVic3QSgQrg5uwEoXUS3xc5uAgAAdmvd09ktAFCROEMBAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQ3ZzcApUtJnKa83AvObgYAACiD9XF1JVmc3Qyn8/X11YxnZzu7GShHnKG4CxAmAAC4GxEmJCkrK8vZTUA5I1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTKnWgiIyM1JQpU0osExYWpiVLllRIewAAAADYKtdAER0dLYvFUuhx9OjR8qzWxl//+lc9+OCDCggIUEBAgPr06aM9e/ZUWP0AAADAvazcz1BERUUpPT3d5tGoUaPyrtZq27ZteuKJJ7R
161YlJiYqNDRU/fr105kzZyqsDQAAAMC9qtwDhYeHh4KCgmwerq6ukqT4+Hh16dJFHh4eCg4O1qxZs5SXl1fsvjIyMjRo0CB5enqqUaNGWrNmTan1r1mzRuPHj1e7du0UHh6uv/71r8rPz9dXX33lsGMEAAAAqio3Z1V85swZPfzww4qOjtbq1at1+PBhxcTEqHr16lqwYEGR20RHR+vUqVPasmWL3N3dNWnSJGVkZJSp3uzsbF2/fl01a9YstkxOTo5ycnKsz7OysspUBwAAAFBVlHug2Lhxo7y9va3PBwwYoPXr12vZsmUKCQnR0qVLZbFYFB4errNnz2rmzJmaN2+eXFxsT5788MMP+vzzz7Vr1y517dpVkrRixQpFRESUqT2zZs1S/fr11adPn2LLLFq0SAsXLizTfgEAAICqqNwDRa9evbR8+XLrcy8vL0lSSkqKunXrJovFYl33wAMP6PLlyzp9+rRCQ0Nt9pOSkiI3Nzd16tTJuiw8PFz+/v52t+W1117Thx9+qG3btql69erFlps9e7amTZtmfZ6VlaWQkBC76wEAAACqinIPFF5eXmratGmh5YZh2ISJgmWSCi0vbZ093njjDb388suKi4tTmzZtSizr4eEhDw8PU/UAAAAAVYnT5qFo2bKlEhISrEFBkhISEuTj46P69esXKh8REaG8vDzt27fPuiw1NVUXL14sta7XX39dL7zwgr744gubMxwAAAAA7ozTAsX48eN16tQpTZw4UYcPH9ann36q+fPna9q0aYXun5CkFi1aKCoqSjExMdq9e7eSkpI0duxYeXp6lljPa6+9pjlz5uj9999XWFiYzp07p3Pnzuny5cvldWgAAABAleG0QFG/fn199tln2rNnj9q2batx48ZpzJgxmjNnTrHbxMbGKiQkRD179tSQIUP09NNPKzAwsMR6li1bptzcXP36179WcHCw9fHGG284+pAAAACAKsdi3HrNEYqUlZUlPz8/ZWZmytfXt8LrPxD/VIXXCQAA7sz6uCBnN6HSeP6FRc5uAkpwp591nXaGAgAAAMDdj0ABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANALFXcDNPcDZTQAAAGXGL/NLcspP7qNiuTm7AShdRLfFzm4CAAAoo9Y9nd0CoGJwhgIAAACAaQQKAAAAAKYRKAAAAACYRqAAAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaW7ObgBKl5I4TXm5F5zdDAAAgDLZuKOOrua4SLI4uylOZbFYtPD5l53djHJDoLgLECYAAMDd6GqOq7ObUCkYhuHsJpQrLnkCAAAAYBqBAgAAAIBpBAoAAAAAphEoAAAAAJhGoAAAAABgGoECAAAAgGkECgAAAACmVepAERkZqSlTppRYJiwsTEuWLKmQ9gAAAACwVa6BIjo6WhaLpdDj6NGj5VmtjQ0bNqhTp07y9/eXl5eX2rVrpw8++KDC6gcAAADuZeU+U3ZUVJRiY2NtltWpU6e8q7WqWbOmnnvuOYWHh8vd3V0bN27UU089pcDAQPXv37/C2gEAAADci8r9kicPDw8FBQXZPFxdb07DHh8fry5dusjDw0PBwcGaNWuW8vLyit1XRkaGBg0aJE9PTzVq1Ehr1qwptf7IyEj96le/UkREhJo0aaLJkyerTZs22rlzp8OOEQAAAKiqyv0MRXHOnDmjhx9+WNHR0Vq9erUOHz6smJgYVa9eXQsWLChym+joaJ06dUpbtmyRu7u7Jk2apIyMDLvrNAxDW7ZsUWpqql599dViy+Xk5CgnJ8f6PCsry+46AAAAgKqk3APFxo0b5e3tbX0+YMAArV+/XsuWLVNISIiWLl0qi8Wi8PBwnT17VjNnztS8efPk4mJ78uSHH37Q559/rl27dqlr166SpBUrVigiIqLUNmRmZqp
+/frKycmRq6urli1bpr59+xZbftGiRVq4cKHJIwYAAACqjnIPFL169dLy5cutz728vCRJKSkp6tatmywWi3XdAw88oMuXL+v06dMKDQ212U9KSorc3NzUqVMn67Lw8HD5+/uX2gYfHx8lJyfr8uXL+uqrrzRt2jQ1btxYkZGRRZafPXu2pk2bZn2elZWlkJAQew4XAAAAqFLKPVB4eXmpadOmhZYbhmETJgqWSSq0vLR1pXFxcbG2oV27dkpJSdGiRYuKDRQeHh7y8PAocz0AAABAVeO0eShatmyphIQEa1CQpISEBPn4+Kh+/fqFykdERCgvL0/79u2zLktNTdXFixfLXLdhGDb3SAAAAAAwx2mBYvz48Tp16pQmTpyow4cP69NPP9X8+fM1bdq0QvdPSFKLFi0UFRWlmJgY7d69W0lJSRo7dqw8PT1LrGfRokXavHmz/v3vf+vw4cNavHixVq9erd/+9rfldWgAAABAleG0X3mqX7++PvvsMz377LNq27atatasqTFjxmjOnDnFbhMbG6uxY8eqZ8+eqlu3rl588UXNnTu3xHquXLmi8ePH6/Tp0/L09FR4eLj+9re/adiwYY4+JAAAAKDKsRi3XnOEImVlZcnPz0+ZmZny9fWt8PoPxD9V4XUCAADcqfVxQc5uQqXx/AuLnN2EYt3pZ12nXfIEAAAA4O5HoAAAAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAageIu4OYe4OwmAAAAlJmnxw1JzFBgsVic3YRy5bSJ7WC/iG6Lnd0EAACAMmvd09ktQEXgDAUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwDQCBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA09yc3QAAAADgXta692Tl5xullju47Z0KaI3jcYYCAAAAKEf2hIm7GYECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKYRKAAAAACY5pRAERYWpiVLlpRYxmKx6JNPPqmQ9gAAAAAwp0yBIjo6Wo8++mih5du2bZPFYtHFixcd1KzyYU+QAQAAAGA/N2c34F5hGIby8vJ048YNZzcFAIB7RrVq1eTq6ursZgAoQbkEio8++kjz5s3T0aNHFRwcrIkTJ2r69OnFlj9y5IjGjBmjPXv2qHHjxnr77bcLlTlw4IAmT56sxMRE1ahRQ4899pgWL14sb29vSVJkZKTatWtncwbi0Ucflb+/v1auXKnIyEidOHFCU6dO1dSpUyXdDAGOkJubq/T0dGVnZztkfwAA4CaLxaIGDRpY/78HUPk4PFAkJSVp6NChWrBggYYNG6aEhASNHz9etWrVUnR0dKHy+fn5GjJkiGrXrq1du3YpKytLU6ZMsSmTnZ2tqKgo3X///dq7d68yMjI0duxYTZgwQStXrrSrXRs2bFDbtm319NNPKyYmpsSyOTk5ysnJsT7Pysoqtmx+fr7S0tLk6uqqevXqyd3dXRaLxa42AQCA4hmGoZ9++kmnT59Ws2bNOFMBVFJlDhQbN24s9C3BrZf5LF68WA899JDmzp0rSWrevLkOHTqk119/vchAERcXp5SUFB0/flwNGjSQJL388ssaMGCAtcyaNWt09epVrV69Wl5eXpKkpUuXatCgQXr11VdVt27dUttds2ZNubq6ysfHR0FBQSWWXbRokRYuXFjqPqWbZyfy8/MVEhKiGjVq2LUNAACwT506dXT8+HFdv36dQAFUUmX+ladevXopOTnZ5vHee+9Z16ekpOiBBx6w2eaBBx7QkSNHiry/ICUlRaGhodYwIUndunUrVKZt27bWMFGwz/z8fKWmppb1EEo1e/ZsZWZmWh+nTp0qdRsXF36BFwAAR+OsP1D5lfkMhZeXl5o2bWqz7PTp09Z/G4ZR6MVf0r0KRa0ravvi3lAKlru
4uBTa1/Xr14uttyQeHh7y8PAwtS0AAABQlTj8a/WWLVtq586dNssSEhLUvHnzIk9VtmzZUidPntTZs2etyxITEwuVSU5O1pUrV6zLvv76a7m4uKh58+aSbp4STU9Pt66/ceOGvv/+e5v9uLu78ytM5SgyMrLQ/S8lOX78uCwWi5KTk8utTbh73D5+KnK+Gua9wa2YKwkAysbhN2VPnz5dnTt31gsvvKBhw4YpMTFRS5cu1bJly4os36dPH7Vo0UIjR47Um2++qaysLD333HM2ZYYPH6758+dr1KhRWrBggX766SdNnDhRI0aMsN4/0bt3b02bNk2bNm1SkyZN9NZbbxWaFyMsLEzbt2/X448/Lg8PD9WuXdvRh28j99p53bh+qVzrKOBazUfu1WvZVba008ejRo2y+2b3W23YsEHVqlWzu3xISIjS09PL/e9wp6Kjo3Xx4kWnfni4ePGisrOvlF7QQWrU8JK/v79dZQcNGqSrV68qLi6u0LrExER1795dSUlJ6tChQ5nasHfvXpvLHB1hwYIF+uSTTwqF2PT0dAUEBDi0LkcLCwvTlClTyhTaHensj//RxcyKG4P+fl6qV7em3eWLe51u27ZNvXr10oULF+we05VJUb9gWJTyGB/21n0n7va/D4CbHB4oOnTooL///e+aN2+eXnjhBQUHB+v5558v8oZs6ealSh9//LHGjBmjLl26KCwsTO+8846ioqKsZWrUqKF//etfmjx5sjp37mzzs7EFRo8erf3792vkyJFyc3PT1KlT1atXL5u6nn/+ef3ud79TkyZNlJOT47CfjS1K7rXz+mHvbBn55i67KiuLSzU177zIrlBx65mcdevWad68eTb3onh6etqUv379ul1BoWZN+//zlyRXV9dSb5DHzTDxzttvKi8vr8LqdHNz06TJ0+36D37MmDEaMmSITpw4oYYNG9qse//999WuXbsyhwnp5lnHisI4LNnZH/+jR0a8qNzcihuD7u5u2vTBnDKFCgCAc5TpkqeVK1cW+S1tZGSkDMOwfvh47LHHdPDgQeXm5urEiROaMWOGTfnjx4/bfIvSvHlz7dixQzk5OUpNTVX//v1lGIbNrNytW7fWli1bdPXqVZ0/f15/+ctfbH5tqlq1alq2bJnOnz+vH3/8UbNmzdInn3xi8037/fffr/379+vatWvlGiYk6cb1SxUWJiTJyL9u99mQoKAg68PPz08Wi8X6/Nq1a/L399ff//53RUZGqnr16vrb3/6m8+fP64knnlCDBg1Uo0YNtW7dWh9++KHNfou6ZOXll1/W6NGj5ePjo9DQUP3lL3+xrr/9kqeCGde/+uorderUSTVq1FD37t0L3Xj/4osvKjAwUD4+Pho7dqxmzZqldu3aFXu8Fy5c0PDhw1WnTh15enqqWbNmio2Nta4/c+aMhg0bpoCAANWqVUuDBw/W8ePHJd38RnvVqlX69NNPZbFYZLFYtG3bNrv62VGys69UaJiQpLy8PLvPiAwcOFCBgYGFzmplZ2dr3bp1GjNmjF3j53a3X3Zy5MgR9ejRQ9WrV1fLli21efPmQtvMnDlTzZs3V40aNdS4cWPNnTvXei/VypUrtXDhQu3fv9/6tyxo8+2Xrxw4cEC9e/eWp6enatWqpaefflqXL1+2ro+Ojtajjz6qN954Q8HBwapVq5aeeeaZEu/b2r9/v3r16iUfHx/5+vqqY8eO2rdvn3V9QkKCevToIU9PT4WEhGjSpEnWyzxvnUenoO0V6WLmlQoNE5KUm5tXbmdESurrotgz9opy8uRJDR48WN7e3vL19dXQoUP1448/WtcXjKNbTZkyRZGRkdb18fHxevvtt61/94L3pluVND5KO9Zly5apWbNmql69uurWratf//rXZaq7pH1IN++BfO2119S4cWN5enqqbdu2+sc//iHp5v8BBV/8BQQEyGKxFPvlI4DKjZ8mQpFmzpypSZMmKSUlRf3799e1a9fUsWNHbdy4Ud9
//72efvppjRgxQrt37y5xP2+++aY6deqkb7/9VuPHj9fvf/97HT58uMRtnnvuOb355pvat2+f3NzcNHr0aOu6NWvW6KWXXtKrr76qpKQkhYaGavny5SXub+7cuTp06JA+//xzpaSkaPny5dbLrLKzs9WrVy95e3tr+/bt2rlzp7y9vRUVFaXc3FzNmDFDQ4cOVVRUlNLT05Wenq7u3bvb2YtVg5ubm0aOHKmVK1faBPX169crNzdXw4cPNz1+ChTMV+Pq6qpdu3bp3Xff1cyZMwuV8/Hx0cqVK3Xo0CG9/fbb+utf/6q33npLkjRs2DBNnz5drVq1sv4thw0bVmgfBfPeBAQEaO/evVq/fr3i4uI0YcIEm3Jbt27VsWPHtHXrVq1atUorV64s8VLB4cOHq0GDBtq7d6+SkpI0a9Ys65m/AwcOqH///hoyZIi+++47rVu3Tjt37rTWuWHDBjVo0EDPP/+8te0wp7S+vp29Y+92BV+K/ec//1F8fLw2b96sY8eOFTnmivP222+rW7duiomJsf7dQ0JCCpUrbnyUdqz79u3TpEmT9Pzzzys1NVVffPGFevToUaa6S9qHJM2ZM0exsbFavny5Dh48qKlTp+q3v/2t4uPjFRISoo8++kiSlJqaqvT09CIntgVQ+ZXLTNm4+02ZMkVDhgyxWXbrmaaJEyfqiy++0Pr169W1a9di9/Pwww9r/Pjxkm6GlLfeekvbtm1TeHh4sdu89NJL6tmzpyRp1qxZeuSRR3Tt2jVVr15d//M//6MxY8boqaeekiTNmzdPX375pc23x7c7efKk2rdvr06dOkm6+c13gbVr18rFxUXvvfee9Vu92NhY+fv7a9u2berXr588PT2Vk5PDZTElGD16tF5//XXr9dDSzcudhgwZooCAAAUEBJgaPwXsma9GuvnhpUBYWJimT5+udevW6Q9/+IM8PT3l7e0tNze3Ev+W9s57ExAQoKVLl8rV1VXh4eF65JFH9NVXXxU7cebJkyf17LPPWsd+s2bNrOtef/11Pfnkk9YzfM2aNdM777yjnj17avny5WWaR6cqK22eJKn0vq5evbpNeXvH3u3i4uL03XffKS0tzfpB/IMPPlCrVq20d+9ede7cudTj8fPzk7u7u2rUqFHi37248VHasZ48eVJeXl4aOHCgfHx81LBhQ7Vv375MdZe0jytXrmjx4sXasmWL9efgGzdurJ07d+rPf/6zevbsab1UNjAwkHsogLsYZyhQpIIP3wVu3Lihl156SW3atFGtWrXk7e2tL7/8UidPnixxP23atLH+u+DSqoyMDLu3CQ4OliTrNqmpqerSpYtN+duf3+73v/+91q5dq3bt2ukPf/iDEhISrOuSkpJ09OhR+fj4yNvbW97e3qpZs6auXbumY8eOlbhf/Fd4eLi6d++u999/X5J07Ngx7dixw3p2yez4KWDPfDWS9I9//EO/+MUvFBQUJG9vb82dO9fuOm6ty555b1q1amXzy3XBwcElju1p06Zp7Nix6tOnj1555RWb8ZWUlKSVK1dax6C3t7f69++v/Px8paWllan9VVlp8yRJZe9re8begAEDrPtq1aqVdbuQkBCbb/Vbtmwpf39/paSkOPKwi1Xasfbt21cNGzZU48aNNWLECK1Zs0bZ2dllqqOkfRw6dEjXrl1T3759bdqwevVq3l+BewxnKFCk239d580339Rbb72lJUuWqHXr1vLy8tKUKVOUm5tb4n5uv5nbYrEoPz/f7m0Kzhrcuk1Z5jmRbv5nf+LECW3atElxcXF66KGH9Mwzz+iNN95Qfn6+OnbsqDVr1hTariJvCr4XjBkzRhMmTNCf/vQnxcbGqmHDhnrooYckmR8/BeyZr2bXrl16/PHHtXDhQvXv319+fn5au3at3nzzzTIdhz3z3khlH9sLFizQk08+qU2bNunzzz/X/PnztXbtWv3qV79Sfn6+fve732nSpEmFtgsNDS1T+6uy0uZJklTmvrZ
n7L333nu6evWqpP+Oi+LG0a3LHTl/UlFKO1Z3d3d988032rZtm7788kvNmzdPCxYs0N69e+0+W+Dj41PsPgpeD5s2bVL9+vVttmOuJ+DeQqCAXXbs2KHBgwfrt7/9raSb/1EdOXJEERERFdqOFi1aaM+ePRoxYoR12a03thanTp06io6OVnR0tB588EE9++yzeuONN9ShQwetW7dOgYGB8vX1LXJb5i+xz9ChQzV58mT97//+r1atWqWYmBjrB6c7HT+3zldTr149SYXnq/n666/VsGFDm5+dPnHihE0Ze/6WLVu21KpVq3TlyhVrsL593huzmjdvrubNm2vq1Kl64oknFBsbq1/96lfq0KGDDh48WOjDcFnbjtLZ09e3smfs3f5h+dbtTp06ZT1LcejQIWVmZlrHfZ06dQrNl5ScnGwTVu39uxdVzp5jdXNzU58+fdSnTx/Nnz9f/v7+2rJli4YMGWJ33cXto2/fvvLw8NDJkyetl7EW1W6p8KVpAO4uXPIEuzRt2lSbN29WQkKCUlJS9Lvf/U7nzp2r8HZMnDhRK1as0KpVq3TkyBG9+OKL+u6770r81Zt58+bp008/1dGjR3Xw4EFt3LjR+h/68OHDVbt2bQ0ePFg7duxQWlqa4uPjNXnyZOs3m2FhYfruu++Umpqqn3/+2aHfIN5LvL29NWzYMP3xj3/U2bNnbX6t5U7Hz63z1ezfv187duwoNF9N06ZNdfLkSa1du1bHjh3TO++8o48//timTFhYmNLS0pScnKyff/5ZOTk5heoaPny4qlevrlGjRun777/X1q1bC817U1ZXr17VhAkTtG3bNp04cUJff/219u7dax2HM2fOVGJiop555hklJyfryJEj+uc//6mJEyfatH379u06c+aMfv75Z1PtgH19fSt7xl5x27Vp00bDhw/XN998oz179mjkyJHq2bOn9ZLS3r17a9++fVq9erWOHDmi+fPnFwoYYWFh2r17t44fP66ff/652LNgRY2P0o5148aNeuedd5ScnKwTJ05o9erVys/PV4sWLeyuu6R9+Pj4aMaMGZo6dapWrVqlY8eO6dtvv9Wf/vQnrVq1SpLUsGFDWSwWbdy4UT/99FOJ98MBqLwIFLDL3Llz1aFDB/Xv31+RkZEKCgoq9HOHFWH48OGaPXu2ZsyYoQ4dOigtLU3R0dGFbqS8lbu7u2bPnq02bdqoR48ecnV11dq1ayXdnONk+/btCg0N1ZAhQxQREaHRo0fr6tWr1jMWMTExatGihTp16qQ6dero66+/rpBjvRuNGTNGFy5cUJ8+fWwuH7nT8VMwX01OTo66dOmisWPH6qWXXrIpM3jwYE2dOlUTJkxQu3btlJCQoLlz59qUeeyxxxQVFaVevXqpTp06Rf50bcG8N//5z3/UuXNn/frXv9ZDDz2kpUuXlq0zbuHq6qrz589r5MiRat68uYYOHaoBAwZo4cKFkm7eNxQfH68jR47owQcfVPv27TV37lzrPUTSzXl0jh8/riZNmnA53h2wp69vZc/YK0rBTxEHBASoR48e6tOnjxo3bqx169ZZy/Tv319z587VH/7wB3Xu3FmXLl3SyJEjbfYzY8YMubq6qmXLlqpTp06x9wQVNT5KO1Z/f39t2LBBvXv3VkREhN599119+OGH1vtA7Km7tH288MILmjdvnhYtWqSIiAj1799f/+///T81atRI0s2zOwsXLtSsWbNUt27dYn9tC7jbubhU7M99VzSLUd4TMtwDsrKy5Ofnp8zMzEKXxVy7dk1paWlq1KiRzYfayjyx3b2mb9++CgoK0gcffODsppSLyj6xHe59TGwHZyru/1kAjlPSZ117cA9FOXGvXkvNOy+ye7K5O+VazadKhIns7Gy9++676t+/v1xdXfXhhx8qLi7O7omm7kb+/v6aNHm63RPNOUKNGl6ECVjVq1tTmz6YU24TzRXF38+LMAEAdwkCRTlyr15LqgIf8iuSxWLRZ599phd
ffFE5OTlq0aKFPvroI/Xp08fZTStX/v7+fMCHU9WrW5MP+ACAIhEocFfx9PRUXFycs5sBAACA/8NN2QAAAABMI1AAAAAAMI1A4SD8WBYAAI7H/69A5UeguEMFM5pmZ2c7uSUAANx7cnNzJd2czwVA5cRN2XfI1dVV/v7+ysjIkHRzUqySZm0GAAD2yc/P108//aQaNWrIzY2PLEBlxavTAYKCgiTJGioAAIBjuLi4KDQ0lC/rgEqMQOEAFotFwcHBCgwM1PXrFTMzNgAAVYG7u7tcXLhCG6jMCBQO5OrqyjWeAAAAqFKI/AAAAABMI1AAAAAAMI1AAQAAAMA07qGwQ8GkOllZWU5uCQAAAOBYBZ9xzU4kSaCww6VLlyRJISEhTm4JAAAAUD4uXbokPz+/Mm9nMZjTvlT5+fk6e/asfHx8nPI72FlZWQoJCdGpU6fk6+tb4fXfK+hHx6Af7xx96Bj0o2PQj45BPzoG/egYZe1HwzB06dIl1atXz9TPNHOGwg4uLi5q0KCBs5shX19fXlwOQD86Bv145+hDx6AfHYN+dAz60THoR8coSz+aOTNRgJuyAQAAAJhGoAAAAABgGoHiLuDh4aH58+fLw8PD2U25q9GPjkE/3jn60DHoR8egHx2DfnQM+tExKrofuSkbAAAAgGmcoQAAAABgGoECAAAAgGkECgAAAACmESgquWXLlqlRo0aqXr26OnbsqB07dji7SZXGokWL1LlzZ/n4+CgwMFCPPvqoUlNTbcpER0fLYrHYPO6//36bMjk5OZo4caJq164tLy8v/fKXv9Tp06cr8lCcasGCBYX6KCgoyLreMAwtWLBA9erVk6enpyIjI3Xw4EGbfVT1PpSksLCwQv1osVj0zDPPSGIsFmf79u0aNGiQ6tWrJ4vFok8++cRmvaPG34ULFzRixAj5+fnJz89PI0aM0MWLF8v56CpOSf14/fp1zZw5U61bt5aXl5fq1aunkSNH6uzZszb7iIyMLDRGH3/8cZsyVbkfJce9jqt6Pxb1XmmxWPT6669by1T18WjPZ5zK9P5IoKjE1q1bpylTpui5557Tt99+qwcffFADBgzQyZMnnd20SiE+Pl7PPPOMdu3apc2bNysvL0/9+vXTlStXbMpFRUUpPT3d+vjss89s1k+ZMkUff/yx1q5dq507d+ry5csaOHCgbty4UZGH41StWrWy6aMDBw5Y17322mtavHixli5dqr179yooKEh9+/bVpUuXrGXoQ2nv3r02fbh582ZJ0m9+8xtrGcZiYVeuXFHbtm21dOnSItc7avw9+eSTSk5O1hdffKEvvvhCycnJGjFiRLkfX0UpqR+zs7P1zTffaO7cufrmm2+0YcMG/fDDD/rlL39ZqGxMTIzNGP3zn/9ss74q92MBR7yOq3o/3tp/6enpev/992WxWPTYY4/ZlKvK49GezziV6v3RQKXVpUsXY9y4cTbLwsPDjVmzZjmpRZVbRkaGIcmIj4+3Lhs1apQxePDgYre5ePGiUa1aNWPt2rXWZWfOnDFcXFyML774ojybW2nMnz/faNu2bZHr8vPzjaCgIOOVV16xLrt27Zrh5+dnvPvuu4Zh0IfFmTx5stGkSRMjPz/fMAzGoj0kGR9//LH1uaPG36FDhwxJxq5du6xlEhMTDUnG4cOHy/moKt7t/ViUPXv2GJKMEydOWJf17NnTmDx5crHb0I+OeR3Tj4UNHjzY6N27t80yxqOt2z/jVLb3R85QVFK5ublKSkpSv379bJb369dPCQkJTmpV5ZaZmSlJqlmzps3ybdu2KTAwUM2bN1dMTIwyMjKs65KSknT9+nWbfq5Xr57uu+++KtXPR44cUb169dSoUSM9/vjj+ve//y1JSktL07lz52z6x8PDQz179rT2D31YWG5urv72t79p9OjRslgs1uWMxbJx1PhLTEyUn5+funbtai1z//33y8/Pr8r2bWZmpiwWi/z9/W2Wr1mzRrVr11arVq00Y8Y
Mm2866ceb7vR1TD/a+vHHH7Vp0yaNGTOm0DrG43/d/hmnsr0/upk/NJSnn3/+WTdu3FDdunVtltetW1fnzp1zUqsqL8MwNG3aNP3iF7/QfffdZ10+YMAA/eY3v1HDhg2VlpamuXPnqnfv3kpKSpKHh4fOnTsnd3d3BQQE2OyvKvVz165dtXr1ajVv3lw//vijXnzxRXXv3l0HDx609kFR4/DEiROSRB8W4ZNPPtHFixcVHR1tXcZYLDtHjb9z584pMDCw0P4DAwOrZN9eu3ZNs2bN0pNPPilfX1/r8uHDh6tRo0YKCgrS999/r9mzZ2v//v3Wy/foR8e8julHW6tWrZKPj4+GDBlis5zx+F9FfcapbO+PBIpK7tZvN6Wbg+r2ZZAmTJig7777Tjt37rRZPmzYMOu/77vvPnXq1EkNGzbUpk2bCr153aoq9fOAAQOs/27durW6deumJk2aaNWqVdabDc2Mw6rUh7dbsWKFBgwYoHr16lmXMRbNc8T4K6p8Vezb69ev6/HHH1d+fr6WLVtmsy4mJsb67/vuu0/NmjVTp06d9M0336hDhw6S6EdHvY6rej/e6v3339fw4cNVvXp1m+WMx/8q7jOOVHneH7nkqZKqXbu2XF1dC6XDjIyMQmm0qps4caL++c9/auvWrWrQoEGJZYODg9WwYUMdOXJEkhQUFKTc3FxduHDBplxV7mcvLy+1bt1aR44csf7aU0njkD60deLECcXFxWns2LEllmMsls5R4y8oKEg//vhjof3/9NNPVapvr1+/rqFDhyotLU2bN2+2OTtRlA4dOqhatWo2Y5R+tGXmdUw//teOHTuUmppa6vulVHXHY3GfcSrb+yOBopJyd3dXx44draf2CmzevFndu3d3UqsqF8MwNGHCBG3YsEFbtmxRo0aNSt3m/PnzOnXqlIKDgyVJHTt2VLVq1Wz6OT09Xd9//32V7eecnBylpKQoODjYerr51v7Jzc1VfHy8tX/oQ1uxsbEKDAzUI488UmI5xmLpHDX+unXrpszMTO3Zs8daZvfu3crMzKwyfVsQJo4cOaK4uDjVqlWr1G0OHjyo69evW8co/ViYmdcx/fhfK1asUMeOHdW2bdtSy1a18VjaZ5xK9/5o//3lqGhr1641qlWrZqxYscI4dOiQMWXKFMPLy8s4fvy4s5tWKfz+9783/Pz8jG3bthnp6enWR3Z2tmEYhnHp0iVj+vTpRkJCgpGWlmZs3brV6Natm1G/fn0jKyvLup9x48YZDRo0MOLi4oxvvvnG6N27t9G2bVsjLy/PWYdWoaZPn25s27bN+Pe//23s2rXLGDhwoOHj42MdZ6+88orh5+dnbNiwwThw4IDxxBNPGMHBwfRhEW7cuGGEhoYaM2fOtFnOWCzepUuXjG+//db49ttvDUnG4sWLjW+//db660OOGn9RUVFGmzZtjMTERCMxMdFo3bq1MXDgwAo/3vJSUj9ev37d+OUvf2k0aNDASE5Otnm/zMnJMQzDMI4ePWosXLjQ2Lt3r5GWlmZs2rTJCA8PN9q3b08//l8/OvJ1XJX7sUBmZqZRo0YNY/ny5YW2ZzyW/hnHMCrX+yOBopL705/+ZDRs2NBwd3c3OnToYPOTqFWdpCIfsbGxhmEYRnZ2ttGvXz+jTp06RrVq1YzQ0FBj1KhRxsmTJ232c/XqVWPChAlGzZo1DU9PT2PgwIGFytzLhg0bZgQHBxvVqlUz6tWrZwwZMsQ4ePCgdX1+fr4xf/58IygoyPDw8DB69OhhHDhwwGYfVb0PC/zrX/8yJBmpqak2yxmLxdu6dWuRr+NRo0YZhuG48Xf+/Hlj+PDhho+Pj+Hj42MMHz7cuHDhQgUdZfkrqR/T0tKKfb/cunWrYRiGcfLkSaNHjx5GzZo1DXd3d6NJkybGpEmTjPPnz9vUU5X70ZGv46rcjwX+/Oc/G56ensbFixcLbc94LP0zjmFUrvdHy/81GgAAAADKjHsoAAAAAJhGoAA
AAABgGoECAAAAgGkECgAAAACmESgAAAAAmEagAAAAAGAagQIAAACAaQQKAAAAAKYRKAAATrFgwQK1a9fO2c0AANwhZsoGADicxWIpcf2oUaO0dOlS5eTkqFatWhXUKgBAeSBQAAAc7ty5c9Z/r1u3TvPmzVNqaqp1maenp/z8/JzRNACAg3HJEwDA4YKCgqwPPz8/WSyWQstuv+QpOjpajz76qF5++WXVrVtX/v7+WrhwofLy8vTss8+qZs2aatCggd5//32bus6cOaNhw4YpICBAtWrV0uDBg3X8+PGKPWAAqMIIFACASmPLli06e/astm/frsWLF2vBggUaOHCgAgICtHv3bo0bN07jxo3TqVOnJEnZ2dnq1auXvL29tX37du3cuVPe3t6KiopSbm6uk48GAKoGAgUAoNKoWbOm3nnnHbVo0UKjR49WixYtlJ2drT/+8Y9q1qyZZs+eLXd3d3399deSpLVr18rFxUXvvfeeWrdurYiICMXGxurkyZPatm2bcw8GAKoIN2c3AACAAq1atZKLy3+/66pbt67uu+8+63NXV1fVqlVLGRkZkqSkpCQdPXpUPj4+Nvu5du2ajh07VjGNBoAqjkABAKg0qlWrZvPcYrEUuSw/P1+SlJ+fr44dO2rNmjWF9lWnTp3yaygAwIpAAQC4a3Xo0EHr1q1TYGCgfH19nd0cAKiSuIcCAHDXGj58uGrXrq3Bgwdrx44dSktLU3x8vCZPnqzTp087u3kAUCUQKAAAd60aNWpo+/btCg0N1ZAhQxQREaHRo0fr6tWrnLEAgArCxHYAAAAATOMMBQAAAADTCBQAAAAATCNQAAAAADCNQAEAAADANAIFAAAAANMIFAAAAABMI1AAAAAAMI1AAQAAAMA0AgUAAAAA0wgUAAAAAEwjUAAAAAAwjUABAAAAwLT/D9blFiP5mbuxAAAAAElFTkSuQmCC", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "cv_cmap = matplotlib.colormaps[\"cividis\"]\n", - "plt.figure(figsize=(8, 3))\n", - "\n", - "for i, (train_mask, valid_mask) in enumerate(cv_folds):\n", - " idx = np.array([np.nan] * time_horizon)\n", - " idx[np.arange(*train_mask)] = 1\n", - " idx[np.arange(*valid_mask)] = 0\n", - " plt.scatter(\n", - " range(time_horizon),\n", - " [i + 0.5] * time_horizon,\n", - " c=idx,\n", - " marker=\"_\",\n", - " capstyle=\"butt\",\n", - " s=1,\n", - " lw=20,\n", - " cmap=cv_cmap,\n", - " vmin=-1.5,\n", - " vmax=1.5,\n", - " )\n", - "\n", - "idx = np.array([np.nan] * time_horizon)\n", - "idx[np.arange(*holdout)] = -1\n", - "plt.scatter(\n", - " range(time_horizon),\n", - " [n_folds + 0.5] * time_horizon,\n", - " c=idx,\n", - " marker=\"_\",\n", - " capstyle=\"butt\",\n", - " s=1,\n", - " lw=20,\n", - " cmap=cv_cmap,\n", - " vmin=-1.5,\n", - " vmax=1.5,\n", - ")\n", - "\n", - "plt.xlabel(\"Time\")\n", - "plt.yticks(\n", - " ticks=np.arange(n_folds + 1) + 0.5,\n", - " labels=[f\"Fold {i}\" for i in range(n_folds)] + [\"Holdout\"],\n", - ")\n", - "plt.ylim([len(cv_folds) + 1.2, -0.2])\n", - "\n", - "norm = matplotlib.colors.Normalize(vmin=-1.5, vmax=1.5)\n", - "plt.legend(\n", - " [\n", - " Patch(color=cv_cmap(norm(1))),\n", - " Patch(color=cv_cmap(norm(0))),\n", - " Patch(color=cv_cmap(norm(-1))),\n", - " ],\n", - " [\"Training set\", \"Validation set\", \"Held-out test set\"],\n", - " ncol=3,\n", - " loc=\"best\",\n", - ")\n", - "plt.tight_layout()" - ] - }, - { - "cell_type": "markdown", - "id": "17b4c300-971d-419e-b167-19ff533860dd", - "metadata": {}, - "source": [ - "## Launch a Dask client on Kubernetes" - ] - }, - { - "cell_type": "markdown", - "id": "bd94ac3f-2843-4596-9f27-14864ef5ec8e", - "metadata": {}, - "source": [ - "Let us set up a Dask cluster using the `KubeCluster` class." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "deac0215-66fb-4c2b-a606-bbb6b6ec2260", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8c26e7b9e0a54dc8863e6733a70ebdbf", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
-   "source": [
-    "cluster = KubeCluster(\n",
-    "    name=\"rapids-dask\",\n",
-    "    image=rapids_image,\n",
-    "    worker_command=\"dask-cuda-worker\",\n",
-    "    n_workers=n_workers,\n",
-    "    resources={\"limits\": {\"nvidia.com/gpu\": \"1\"}},\n",
-    "    env={\"DISABLE_JUPYTER\": \"true\", \"EXTRA_PIP_PACKAGES\": \"optuna gcsfs\"},\n",
-    ")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 9,
-   "id": "b2d5c010-4cd3-4d58-9c88-84d4aa404234",
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "76e37407208149dcb0b97123d8a2a135",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/html": [
-       "
\n", - "
\n", - "
\n", - "
\n", - "

KubeCluster

\n", - "

rapids-dask

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", - " \n", - " Workers: 1\n", - "
\n", - " Total threads: 1\n", - " \n", - " Total memory: 117.93 GiB\n", - "
\n", - "\n", - "
\n", - " \n", - "

Scheduler Info

\n", - "
\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - "

Scheduler

\n", - "

Scheduler-8df8661b-8c41-4c39-ba31-07f76461af5e

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " Comm: tcp://10.36.3.19:8786\n", - " \n", - " Workers: 1\n", - "
\n", - " Dashboard: http://10.36.3.19:8787/status\n", - " \n", - " Total threads: 1\n", - "
\n", - " Started: Just now\n", - " \n", - " Total memory: 117.93 GiB\n", - "
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "

Workers

\n", - "
\n", - "\n", - " \n", - "
\n", - "
\n", - "
\n", - "
\n", - " \n", - "

Worker: rapids-dask-default-worker-9fc2234a8d

\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "\n", - "
\n", - " Comm: tcp://10.36.4.20:35995\n", - " \n", - " Total threads: 1\n", - "
\n", - " Dashboard: http://10.36.4.20:8788/status\n", - " \n", - " Memory: 117.93 GiB\n", - "
\n", - " Nanny: tcp://10.36.4.20:40817\n", - "
\n", - " Local directory: /tmp/dask-scratch-space/worker-mwjhcmjv\n", - "
\n", - " GPU: Tesla T4\n", - " \n", - " GPU memory: 15.00 GiB\n", - "
\n", - "
\n", - "
\n", - "
\n", - " \n", - "\n", - "
\n", - "
\n", - "\n", - "
\n", - "
\n", - "
" - ], - "text/plain": [ - "KubeCluster(rapids-dask, 'tcp://rapids-dask-scheduler.kubeflow-user-example-com:8786', workers=1, threads=1, memory=117.93 GiB)" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "cluster" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "727c1182-4c74-4d2d-a059-c359193d76ff", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "client = Client(cluster)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "633a2f56-9ad2-4e38-8e6e-be3bf52ba795", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "
\n", - "
\n", - "

Client

\n", - "

Client-54cb8819-2c1e-11ee-8c1e-9a09a5b5e674

\n", - " \n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - "
Connection method: Cluster objectCluster type: dask_kubernetes.KubeCluster
\n", - " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", - "
\n", - "\n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - "

Cluster Info

\n", - "
\n", - "
\n", - "
\n", - "
\n", - "

KubeCluster

\n", - "

rapids-dask

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " Dashboard: http://rapids-dask-scheduler.kubeflow-user-example-com:8787/status\n", - " \n", - " Workers: 3\n", - "
\n", - " Total threads: 3\n", - " \n", - " Total memory: 353.79 GiB\n", - "
\n", - "\n", - "
\n", - " \n", - "

Scheduler Info

\n", - "
\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - "

Scheduler

\n", - "

Scheduler-8df8661b-8c41-4c39-ba31-07f76461af5e

\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " Comm: tcp://10.36.3.19:8786\n", - " \n", - " Workers: 3\n", - "
\n", - " Dashboard: http://10.36.3.19:8787/status\n", - " \n", - " Total threads: 3\n", - "
\n", - " Started: Just now\n", - " \n", - " Total memory: 353.79 GiB\n", - "
\n", - "
\n", - "
\n", - "\n", - "
\n", - " \n", - "

Workers

\n", - "
\n", - "\n", - " \n", - "
\n", - "
\n", - "
\n", - "
\n", - " \n", - "

Worker: rapids-dask-default-worker-3588dacd46

\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "\n", - "
\n", - " Comm: tcp://10.36.1.16:41549\n", - " \n", - " Total threads: 1\n", - "
\n", - " Dashboard: http://10.36.1.16:8788/status\n", - " \n", - " Memory: 117.93 GiB\n", - "
\n", - " Nanny: tcp://10.36.1.16:43755\n", - "
\n", - " Local directory: /tmp/dask-scratch-space/worker-j2wa7czf\n", - "
\n", - " GPU: Tesla T4\n", - " \n", - " GPU memory: 15.00 GiB\n", - "
\n", - "
\n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "
\n", - "
\n", - " \n", - "

Worker: rapids-dask-default-worker-5d46c38fcf

\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "\n", - "
\n", - " Comm: tcp://10.36.2.23:39471\n", - " \n", - " Total threads: 1\n", - "
\n", - " Dashboard: http://10.36.2.23:8788/status\n", - " \n", - " Memory: 117.93 GiB\n", - "
\n", - " Nanny: tcp://10.36.2.23:44423\n", - "
\n", - " Local directory: /tmp/dask-scratch-space/worker-dsfuecsa\n", - "
\n", - " GPU: Tesla T4\n", - " \n", - " GPU memory: 15.00 GiB\n", - "
\n", - "
\n", - "
\n", - "
\n", - " \n", - "
\n", - "
\n", - "
\n", - "
\n", - " \n", - "

Worker: rapids-dask-default-worker-9fc2234a8d

\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "\n", - "
\n", - " Comm: tcp://10.36.4.20:35995\n", - " \n", - " Total threads: 1\n", - "
\n", - " Dashboard: http://10.36.4.20:8788/status\n", - " \n", - " Memory: 117.93 GiB\n", - "
\n", - " Nanny: tcp://10.36.4.20:40817\n", - "
\n", - " Local directory: /tmp/dask-scratch-space/worker-mwjhcmjv\n", - "
\n", - " GPU: Tesla T4\n", - " \n", - " GPU memory: 15.00 GiB\n", - "
\n", - "
\n", - "
\n", - "
\n", - " \n", - "\n", - "
\n", - "
\n", - "\n", - "
\n", - "
\n", - "
\n", - "
\n", - " \n", - "\n", - "
\n", - "
" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client" - ] - }, - { - "cell_type": "markdown", - "id": "5e0aa64d-1c5a-4a9b-886b-07e846111b55", - "metadata": {}, - "source": [ - "## Define the custom evaluation metric" - ] - }, - { - "cell_type": "markdown", - "id": "32d8c67d-3187-4483-bdc9-27b5b7b0cda0", - "metadata": {}, - "source": [ - "The M5 forecasting competition defines a custom metric called WRMSSE as follows:\n", - "\n", - "$$\n", - "WRMSSE = \\sum w_i \\cdot RMSSE_i\n", - "$$\n", - "\n", - "i.e. WRMSEE is a weighted sum of RMSSE for all product items $i$. RMSSE is in turn defined to be\n", - "\n", - "$$\n", - "RMSSE = \\sqrt{\\frac{1/h \\cdot \\sum_t{\\left(Y_t - \\hat{Y}_t\\right)}^2}{1/(n-1)\\sum_t{(Y_t - Y_{t-1})}^2}}\n", - "$$\n", - "\n", - "where the squared error of the prediction (forecast) is normalized by the speed at which the sales amount changes per unit in the training data.\n", - "\n", - "Here is the implementation of the WRMSSE using cuDF. We use the product weights $w_i$ as computed in the first preprocessing notebook." 
- ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "43f0e74d-734b-4135-8313-b9ed277a3341", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def wrmsse(product_weights, df, pred_sales, train_mask, valid_mask):\n", - " \"\"\"Compute WRMSSE metric\"\"\"\n", - " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", - " df_valid = df[(df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])]\n", - "\n", - " # Compute denominator: 1/(n-1) * sum( (y(t) - y(t-1))**2 )\n", - " diff = (\n", - " df_train.sort_values([\"item_id\", \"day_id\"])\n", - " .groupby([\"item_id\"])[[\"sales\"]]\n", - " .diff(1)\n", - " )\n", - " x = (\n", - " df_train[[\"item_id\", \"day_id\"]]\n", - " .join(diff, how=\"left\")\n", - " .rename(columns={\"sales\": \"diff\"})\n", - " .sort_values([\"item_id\", \"day_id\"])\n", - " )\n", - " x[\"diff\"] = x[\"diff\"] ** 2\n", - " xx = x.groupby([\"item_id\"])[[\"diff\"]].agg([\"sum\", \"count\"]).sort_index()\n", - " xx.columns = xx.columns.map(\"_\".join)\n", - " xx[\"denominator\"] = xx[\"diff_sum\"] / xx[\"diff_count\"]\n", - " t = xx.reset_index()\n", - "\n", - " # Compute numerator: 1/h * sum( (y(t) - y_pred(t))**2 )\n", - " X_valid = df_valid.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", - " if \"dept_id\" in X_valid.columns:\n", - " X_valid = X_valid.drop(columns=[\"dept_id\"])\n", - " df_pred = cudf.DataFrame(\n", - " {\n", - " \"item_id\": df_valid[\"item_id\"].copy(),\n", - " \"pred_sales\": pred_sales,\n", - " \"sales\": df_valid[\"sales\"].copy(),\n", - " }\n", - " )\n", - " df_pred[\"diff\"] = (df_pred[\"sales\"] - df_pred[\"pred_sales\"]) ** 2\n", - " yy = df_pred.groupby([\"item_id\"])[[\"diff\"]].agg([\"sum\", \"count\"]).sort_index()\n", - " yy.columns = yy.columns.map(\"_\".join)\n", - " yy[\"numerator\"] = yy[\"diff_sum\"] / yy[\"diff_count\"]\n", - "\n", - " zz = yy[[\"numerator\"]].join(xx[[\"denominator\"]], how=\"left\")\n", - " 
zz = zz.join(product_weights, how=\"left\").sort_index()\n", - " # Filter out zero denominator.\n", - " # This can occur if the product was never on sale during the period in the training set\n", - " zz = zz[zz[\"denominator\"] != 0]\n", - " zz[\"rmsse\"] = np.sqrt(zz[\"numerator\"] / zz[\"denominator\"])\n", - " t = zz[\"rmsse\"].multiply(zz[\"weights\"])\n", - " return zz[\"rmsse\"].multiply(zz[\"weights\"]).sum()" - ] - }, - { - "cell_type": "markdown", - "id": "b5e83747-b181-49f3-b302-bdffd1e2d847", - "metadata": {}, - "source": [ - "## Define the training and hyperparameter search pipeline using Optuna" - ] - }, - { - "cell_type": "markdown", - "id": "6145f299-9c1d-4391-af2a-63bf9afc8fb8", - "metadata": {}, - "source": [ - "Optuna lets us define the training procedure iteratively, i.e. as if we were to write an ordinary function to train a single model. Instead of a fixed hyperparameter combination, the function now takes in a `trial` object which yields different hyperparameter combinations.\n", - "\n", - "In this example, we partition the training data according to the store and then fit a separate XGBoost model per data segment." 
- ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "38c7c1ae-1888-4ef9-8a57-0eabd845fcb4", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def objective(trial):\n", - " fs = gcsfs.GCSFileSystem()\n", - " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", - " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", - " params = {\n", - " \"n_estimators\": 100,\n", - " \"verbosity\": 0,\n", - " \"learning_rate\": 0.01,\n", - " \"objective\": \"reg:tweedie\",\n", - " \"tree_method\": \"gpu_hist\",\n", - " \"grow_policy\": \"depthwise\",\n", - " \"predictor\": \"gpu_predictor\",\n", - " \"enable_categorical\": True,\n", - " \"lambda\": trial.suggest_float(\"lambda\", 1e-8, 100.0, log=True),\n", - " \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 100.0, log=True),\n", - " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n", - " \"max_depth\": trial.suggest_int(\"max_depth\", 2, 6, step=1),\n", - " \"min_child_weight\": trial.suggest_float(\n", - " \"min_child_weight\", 1e-8, 100, log=True\n", - " ),\n", - " \"gamma\": trial.suggest_float(\"gamma\", 1e-8, 1.0, log=True),\n", - " \"tweedie_variance_power\": trial.suggest_float(\"tweedie_variance_power\", 1, 2),\n", - " }\n", - " scores = [[] for store in STORES]\n", - "\n", - " for store_id, store in enumerate(STORES):\n", - " print(f\"Processing store {store}...\")\n", - " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", - " df = cudf.DataFrame(pd.read_pickle(f))\n", - " for train_mask, valid_mask in cv_folds:\n", - " df_train = df[\n", - " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", - " ]\n", - " df_valid = df[\n", - " (df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])\n", - " ]\n", - "\n", - " X_train, y_train = (\n", - " df_train.drop(\n", - " columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]\n", - " ),\n", - " df_train[\"sales\"],\n", - " 
)\n", - " X_valid = df_valid.drop(\n", - " columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]\n", - " )\n", - "\n", - " clf = xgb.XGBRegressor(**params)\n", - " clf.fit(X_train, y_train)\n", - " pred_sales = clf.predict(X_valid)\n", - " scores[store_id].append(\n", - " wrmsse(product_weights, df, pred_sales, train_mask, valid_mask)\n", - " )\n", - " del df_train, df_valid, X_train, y_train, clf\n", - " gc.collect()\n", - " del df\n", - " gc.collect()\n", - "\n", - " # We can sum WRMSSE scores over data segments because data segments contain disjoint sets of time series\n", - " return np.array(scores).sum(axis=0).mean()" - ] - }, - { - "cell_type": "markdown", - "id": "78b96b48-f94b-47c1-ac27-cfab20b8561c", - "metadata": {}, - "source": [ - "Using the Dask cluster client, we execute multiple training jobs in parallel. Optuna keeps track of the progress in the hyperparameter search using in-memory Dask storage." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "09af0be0-4a72-4dc2-bd60-ff1d8152385b", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipykernel_3102/456600745.py:7: ExperimentalWarning: DaskStorage is experimental (supported from v3.1.0). 
The interface can change in the future.\n", - " dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Testing hyperparameter combinations 0..3\n", - "Best cross-validation metric: 9.689218658589553, Time elapsed = 491.4758385729983\n", - "Testing hyperparameter combinations 3..6\n", - "Best cross-validation metric: 9.689218658589553, Time elapsed = 1047.8801612580028\n", - "Testing hyperparameter combinations 6..9\n", - "Best cross-validation metric: 9.689218658589553, Time elapsed = 1650.7563961980013\n", - "Total time elapsed = 1650.7610972189977\n" - ] - } - ], - "source": [ - "##### Number of hyperparameter combinations to try in parallel\n", - "n_trials = 9 # Using a small n_trials so that the demo can finish quickly\n", - "# n_trials = 100\n", - "\n", - "# Optimize in parallel on your Dask cluster\n", - "backend_storage = optuna.storages.InMemoryStorage()\n", - "dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n", - "study = optuna.create_study(\n", - " direction=\"minimize\",\n", - " sampler=optuna.samplers.RandomSampler(seed=0),\n", - " storage=dask_storage,\n", - ")\n", - "futures = []\n", - "for i in range(0, n_trials, n_workers):\n", - " iter_range = (i, min([i + n_workers, n_trials]))\n", - " futures.append(\n", - " {\n", - " \"range\": iter_range,\n", - " \"futures\": [\n", - " client.submit(study.optimize, objective, n_trials=1, pure=False)\n", - " for _ in range(*iter_range)\n", - " ],\n", - " }\n", - " )\n", - "\n", - "tstart = time.perf_counter()\n", - "for partition in futures:\n", - " iter_range = partition[\"range\"]\n", - " print(f\"Testing hyperparameter combinations {iter_range[0]}..{iter_range[1]}\")\n", - " _ = wait(partition[\"futures\"])\n", - " for fut in partition[\"futures\"]:\n", - " _ = fut.result() # Ensure that the training job was successful\n", - " tnow = time.perf_counter()\n", - " 
print(\n", - " f\"Best cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\"\n", - " )\n", - "tend = time.perf_counter()\n", - "print(f\"Total time elapsed = {tend - tstart}\")" - ] - }, - { - "cell_type": "markdown", - "id": "01e6da28-42e6-4e7e-9fa9-340539254e5e", - "metadata": {}, - "source": [ - "Once the hyperparameter search is complete, we fetch the optimal hyperparameter combination using the attributes of the `study` object." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "3d11f3ac-76f5-48d2-bbb9-109d38b1f261", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'lambda': 0.003077053443211648,\n", - " 'alpha': 0.14187101103672142,\n", - " 'colsample_bytree': 0.682210700857315,\n", - " 'max_depth': 4,\n", - " 'min_child_weight': 0.00017240426024865184,\n", - " 'gamma': 0.0014694435419424668,\n", - " 'tweedie_variance_power': 1.4375872112626924}" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "study.best_params" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "ac74f348-c1cd-4dbf-918b-52292d874a16", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "FrozenTrial(number=0, state=TrialState.COMPLETE, values=[9.689218658589553], datetime_start=datetime.datetime(2023, 7, 27, 1, 39, 4, 604443), datetime_complete=datetime.datetime(2023, 7, 27, 1, 47, 9, 804887), params={'lambda': 0.003077053443211648, 'alpha': 0.14187101103672142, 'colsample_bytree': 0.682210700857315, 'max_depth': 4, 'min_child_weight': 0.00017240426024865184, 'gamma': 0.0014694435419424668, 'tweedie_variance_power': 1.4375872112626924}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'lambda': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'alpha': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'colsample_bytree': FloatDistribution(high=1.0, 
log=False, low=0.2, step=None), 'max_depth': IntDistribution(high=6, log=False, low=2, step=1), 'min_child_weight': FloatDistribution(high=100.0, log=True, low=1e-08, step=None), 'gamma': FloatDistribution(high=1.0, log=True, low=1e-08, step=None), 'tweedie_variance_power': FloatDistribution(high=2.0, log=False, low=1.0, step=None)}, trial_id=0, value=None)" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "study.best_trial" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "dbc8d725-5aff-4339-bc45-d43f80aab454", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'lambda': 0.003077053443211648,\n", - " 'alpha': 0.14187101103672142,\n", - " 'colsample_bytree': 0.682210700857315,\n", - " 'max_depth': 4,\n", - " 'min_child_weight': 0.00017240426024865184,\n", - " 'gamma': 0.0014694435419424668,\n", - " 'tweedie_variance_power': 1.4375872112626924}" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Make a deep copy to preserve the dictionary after deleting the Dask cluster\n", - "best_params = copy.deepcopy(study.best_params)\n", - "best_params" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "aa630ad6-e011-4652-b51c-1f2b8ce6564f", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "fs = gcsfs.GCSFileSystem()\n", - "with fs.open(f\"{bucket_name}/params.json\", \"w\") as f:\n", - " json.dump(best_params, f)" - ] - }, - { - "cell_type": "markdown", - "id": "b5d02e80-c4e8-4445-bd0e-d05b9009c07f", - "metadata": {}, - "source": [ - "## Train the final XGBoost model and evaluate" - ] - }, - { - "cell_type": "markdown", - "id": "ca9a2b78-8e08-47d9-9136-76faa8d21761", - "metadata": {}, - "source": [ - "Using the optimal hyperparameters found in the search, fit a new model using the whole training data. 
As in the previous section, we fit a separate XGBoost model per data segment." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "3b665d27-8148-4959-ba6d-bdd54b18a311", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "fs = gcsfs.GCSFileSystem()\n", - "with fs.open(f\"{bucket_name}/params.json\", \"r\") as f:\n", - " best_params = json.load(f)\n", - "with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", - " product_weights = cudf.DataFrame(pd.read_pickle(f))" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "17c25588-b1e1-4a53-afe1-478bab418793", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def final_train(best_params):\n", - " fs = gcsfs.GCSFileSystem()\n", - " params = {\n", - " \"n_estimators\": 100,\n", - " \"verbosity\": 0,\n", - " \"learning_rate\": 0.01,\n", - " \"objective\": \"reg:tweedie\",\n", - " \"tree_method\": \"gpu_hist\",\n", - " \"grow_policy\": \"depthwise\",\n", - " \"predictor\": \"gpu_predictor\",\n", - " \"enable_categorical\": True,\n", - " }\n", - " params.update(best_params)\n", - " model = {}\n", - " train_mask = [0, 1914]\n", - "\n", - " for store in STORES:\n", - " print(f\"Processing store {store}...\")\n", - " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", - " df = cudf.DataFrame(pd.read_pickle(f))\n", - "\n", - " df_train = df[(df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])]\n", - " X_train, y_train = (\n", - " df_train.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", - " df_train[\"sales\"],\n", - " )\n", - "\n", - " clf = xgb.XGBRegressor(**params)\n", - " clf.fit(X_train, y_train)\n", - " model[store] = clf\n", - " del df\n", - " gc.collect()\n", - "\n", - " return model" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "3170cd9d-cfb3-4de4-9eb8-80f0e96438c3", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", 
- "output_type": "stream", - "text": [ - "Processing store CA_1...\n", - "Processing store CA_2...\n", - "Processing store CA_3...\n", - "Processing store CA_4...\n", - "Processing store TX_1...\n", - "Processing store TX_2...\n", - "Processing store TX_3...\n", - "Processing store WI_1...\n", - "Processing store WI_2...\n", - "Processing store WI_3...\n" - ] - } - ], - "source": [ - "model = final_train(best_params)" - ] - }, - { - "cell_type": "markdown", - "id": "fa83a711-879a-44be-bfc6-70eaf0691934", - "metadata": {}, - "source": [ - "Let's now evaluate the final model using the held-out test set:" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "7426afcd-5e22-4901-84fe-6155b172f8c4", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WRMSSE metric on the held-out test set: 10.495262182826213\n" - ] - } - ], - "source": [ - "test_wrmsse = 0\n", - "for store in STORES:\n", - " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", - " df = cudf.DataFrame(pd.read_pickle(f))\n", - " df_test = df[(df[\"day_id\"] >= holdout[0]) & (df[\"day_id\"] < holdout[1])]\n", - " X_test = df_test.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"])\n", - " pred_sales = model[store].predict(X_test)\n", - " test_wrmsse += wrmsse(\n", - " product_weights, df, pred_sales, train_mask=[0, 1914], valid_mask=holdout\n", - " )\n", - "print(f\"WRMSSE metric on the held-out test set: {test_wrmsse}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "e6c804ce-3ac2-475e-a86a-3e6acf2f49ce", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Save the model to the Cloud Storage\n", - "with fs.open(f\"{bucket_name}/final_model.pkl\", \"wb\") as f:\n", - " pickle.dump(model, f)" - ] - }, - { - "cell_type": "markdown", - "id": "bb62e220-d69a-42f4-90ee-acee6ef6144a", - "metadata": {}, - "source": [ - "## Create an ensemble 
model using a different strategy for segmenting sales data" - ] - }, - { - "cell_type": "markdown", - "id": "d5714912-df6b-4686-83a0-2d6424e4d522", - "metadata": {}, - "source": [ - "It is common to create an **ensemble model** where multiple machine learning methods are used to obtain better predictive performance. Prediction is made from an ensemble model by averaging the prediction output of the constituent models.\n", - "\n", - "In this example, we will create a second model by segmenting the sales data in a different way. Instead of splitting by stores, we will split the data by both stores and product categories." - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "8461f1ca-4521-44f9-8a2b-0666b684b72d", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def objective_alt(trial):\n", - " fs = gcsfs.GCSFileSystem()\n", - " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", - " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", - " params = {\n", - " \"n_estimators\": 100,\n", - " \"verbosity\": 0,\n", - " \"learning_rate\": 0.01,\n", - " \"objective\": \"reg:tweedie\",\n", - " \"tree_method\": \"gpu_hist\",\n", - " \"grow_policy\": \"depthwise\",\n", - " \"predictor\": \"gpu_predictor\",\n", - " \"enable_categorical\": True,\n", - " \"lambda\": trial.suggest_float(\"lambda\", 1e-8, 100.0, log=True),\n", - " \"alpha\": trial.suggest_float(\"alpha\", 1e-8, 100.0, log=True),\n", - " \"colsample_bytree\": trial.suggest_float(\"colsample_bytree\", 0.2, 1.0),\n", - " \"max_depth\": trial.suggest_int(\"max_depth\", 2, 6, step=1),\n", - " \"min_child_weight\": trial.suggest_float(\n", - " \"min_child_weight\", 1e-8, 100, log=True\n", - " ),\n", - " \"gamma\": trial.suggest_float(\"gamma\", 1e-8, 1.0, log=True),\n", - " \"tweedie_variance_power\": trial.suggest_float(\"tweedie_variance_power\", 1, 2),\n", - " }\n", - " scores = [[] for i in range(len(STORES) * len(DEPTS))]\n", - "\n", - " for store_id, store in 
enumerate(STORES):\n", - " for dept_id, dept in enumerate(DEPTS):\n", - " print(f\"Processing store {store}, department {dept}...\")\n", - " with fs.open(\n", - " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", - " ) as f:\n", - " df = cudf.DataFrame(pd.read_pickle(f))\n", - " for train_mask, valid_mask in cv_folds:\n", - " df_train = df[\n", - " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", - " ]\n", - " df_valid = df[\n", - " (df[\"day_id\"] >= valid_mask[0]) & (df[\"day_id\"] < valid_mask[1])\n", - " ]\n", - "\n", - " X_train, y_train = (\n", - " df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", - " df_train[\"sales\"],\n", - " )\n", - " X_valid = df_valid.drop(\n", - " columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]\n", - " )\n", - "\n", - " clf = xgb.XGBRegressor(**params)\n", - " clf.fit(X_train, y_train)\n", - " sales_pred = clf.predict(X_valid)\n", - " scores[store_id * len(DEPTS) + dept_id].append(\n", - " wrmsse(product_weights, df, sales_pred, train_mask, valid_mask)\n", - " )\n", - " del df_train, df_valid, X_train, y_train, clf\n", - " gc.collect()\n", - " del df\n", - " gc.collect()\n", - "\n", - " # We can sum WRMSSE scores over data segments because data segments contain disjoint sets of time series\n", - " return np.array(scores).sum(axis=0).mean()" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "3e5c34ba-9709-4c71-bfcc-dca214891a2f", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipykernel_3102/383703293.py:7: ExperimentalWarning: DaskStorage is experimental (supported from v3.1.0). 
The interface can change in the future.\n", - " dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Testing hyperparameter combinations 0..3\n", - "Best cross-validation metric: 9.657402162051978, Time elapsed = 663.513354638002\n", - "Testing hyperparameter combinations 3..6\n", - "Best cross-validation metric: 9.657402162051978, Time elapsed = 1379.8620550880005\n", - "Testing hyperparameter combinations 6..9\n", - "Best cross-validation metric: 9.657402162051978, Time elapsed = 2183.6284268570016\n", - "Total time elapsed = 2183.632464492999\n" - ] - } - ], - "source": [ - "##### Number of hyperparameter combinations to try in parallel\n", - "n_trials = 9 # Using a small n_trials so that the demo can finish quickly\n", - "# n_trials = 100\n", - "\n", - "# Optimize in parallel on your Dask cluster\n", - "backend_storage = optuna.storages.InMemoryStorage()\n", - "dask_storage = optuna.integration.DaskStorage(storage=backend_storage, client=client)\n", - "study = optuna.create_study(\n", - " direction=\"minimize\",\n", - " sampler=optuna.samplers.RandomSampler(seed=0),\n", - " storage=dask_storage,\n", - ")\n", - "futures = []\n", - "for i in range(0, n_trials, n_workers):\n", - " iter_range = (i, min([i + n_workers, n_trials]))\n", - " futures.append(\n", - " {\n", - " \"range\": iter_range,\n", - " \"futures\": [\n", - " client.submit(study.optimize, objective_alt, n_trials=1, pure=False)\n", - " for _ in range(*iter_range)\n", - " ],\n", - " }\n", - " )\n", - "\n", - "tstart = time.perf_counter()\n", - "for partition in futures:\n", - " iter_range = partition[\"range\"]\n", - " print(f\"Testing hyperparameter combinations {iter_range[0]}..{iter_range[1]}\")\n", - " _ = wait(partition[\"futures\"])\n", - " for fut in partition[\"futures\"]:\n", - " _ = fut.result() # Ensure that the training job was successful\n", - " tnow = time.perf_counter()\n", - " 
print(\n", - " f\"Best cross-validation metric: {study.best_value}, Time elapsed = {tnow - tstart}\"\n", - " )\n", - "tend = time.perf_counter()\n", - "print(f\"Total time elapsed = {tend - tstart}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "0ce96895-8ef8-4897-bf52-b6bac69188c2", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'lambda': 0.003077053443211648,\n", - " 'alpha': 0.14187101103672142,\n", - " 'colsample_bytree': 0.682210700857315,\n", - " 'max_depth': 4,\n", - " 'min_child_weight': 0.00017240426024865184,\n", - " 'gamma': 0.0014694435419424668,\n", - " 'tweedie_variance_power': 1.4375872112626924}" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Make a deep copy to preserve the dictionary after deleting the Dask cluster\n", - "best_params_alt = copy.deepcopy(study.best_params)\n", - "best_params_alt" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "9841bf88-24ff-4fda-991f-108291866a43", - "metadata": {}, - "outputs": [], - "source": [ - "fs = gcsfs.GCSFileSystem()\n", - "with fs.open(f\"{bucket_name}/params_alt.json\", \"w\") as f:\n", - " json.dump(best_params_alt, f)" - ] - }, - { - "cell_type": "markdown", - "id": "e04f712f-b7d6-4063-979d-22358da9fbc0", - "metadata": {}, - "source": [ - "Using the optimal hyperparameters found in the search, fit a new model using the whole training data." 
- ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "44bbde76-93e6-4c69-bb12-d71e93e775db", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def final_train_alt(best_params):\n", - " fs = gcsfs.GCSFileSystem()\n", - " with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", - " product_weights = cudf.DataFrame(pd.read_pickle(f))\n", - " params = {\n", - " \"n_estimators\": 100,\n", - " \"verbosity\": 0,\n", - " \"learning_rate\": 0.01,\n", - " \"objective\": \"reg:tweedie\",\n", - " \"tree_method\": \"gpu_hist\",\n", - " \"grow_policy\": \"depthwise\",\n", - " \"predictor\": \"gpu_predictor\",\n", - " \"enable_categorical\": True,\n", - " }\n", - " params.update(best_params)\n", - " model = {}\n", - " train_mask = [0, 1914]\n", - "\n", - " for store_id, store in enumerate(STORES):\n", - " for dept_id, dept in enumerate(DEPTS):\n", - " print(f\"Processing store {store}, department {dept}...\")\n", - " with fs.open(\n", - " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", - " ) as f:\n", - " df = cudf.DataFrame(pd.read_pickle(f))\n", - " for train_mask, valid_mask in cv_folds:\n", - " df_train = df[\n", - " (df[\"day_id\"] >= train_mask[0]) & (df[\"day_id\"] < train_mask[1])\n", - " ]\n", - " X_train, y_train = (\n", - " df_train.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"]),\n", - " df_train[\"sales\"],\n", - " )\n", - "\n", - " clf = xgb.XGBRegressor(**params)\n", - " clf.fit(X_train, y_train)\n", - " model[(store, dept)] = clf\n", - " del df\n", - " gc.collect()\n", - "\n", - " return model" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "0af52a5c-4d7c-454c-90a9-c47f570d6457", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "fs = gcsfs.GCSFileSystem()\n", - "with fs.open(f\"{bucket_name}/params_alt.json\", \"r\") as f:\n", - " best_params_alt = json.load(f)\n", - "with fs.open(f\"{bucket_name}/product_weights.pkl\", \"rb\") as f:\n", 
- " product_weights = cudf.DataFrame(pd.read_pickle(f))" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "fdd3b3a4-9e5a-4ebe-b138-57774c8cacf9", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processing store CA_1, department HOBBIES_1...\n", - "Processing store CA_1, department HOBBIES_2...\n", - "Processing store CA_1, department HOUSEHOLD_1...\n", - "Processing store CA_1, department HOUSEHOLD_2...\n", - "Processing store CA_1, department FOODS_1...\n", - "Processing store CA_1, department FOODS_2...\n", - "Processing store CA_1, department FOODS_3...\n", - "Processing store CA_2, department HOBBIES_1...\n", - "Processing store CA_2, department HOBBIES_2...\n", - "Processing store CA_2, department HOUSEHOLD_1...\n", - "Processing store CA_2, department HOUSEHOLD_2...\n", - "Processing store CA_2, department FOODS_1...\n", - "Processing store CA_2, department FOODS_2...\n", - "Processing store CA_2, department FOODS_3...\n", - "Processing store CA_3, department HOBBIES_1...\n", - "Processing store CA_3, department HOBBIES_2...\n", - "Processing store CA_3, department HOUSEHOLD_1...\n", - "Processing store CA_3, department HOUSEHOLD_2...\n", - "Processing store CA_3, department FOODS_1...\n", - "Processing store CA_3, department FOODS_2...\n", - "Processing store CA_3, department FOODS_3...\n", - "Processing store CA_4, department HOBBIES_1...\n", - "Processing store CA_4, department HOBBIES_2...\n", - "Processing store CA_4, department HOUSEHOLD_1...\n", - "Processing store CA_4, department HOUSEHOLD_2...\n", - "Processing store CA_4, department FOODS_1...\n", - "Processing store CA_4, department FOODS_2...\n", - "Processing store CA_4, department FOODS_3...\n", - "Processing store TX_1, department HOBBIES_1...\n", - "Processing store TX_1, department HOBBIES_2...\n", - "Processing store TX_1, department HOUSEHOLD_1...\n", - "Processing store TX_1, 
department HOUSEHOLD_2...\n", - "Processing store TX_1, department FOODS_1...\n", - "Processing store TX_1, department FOODS_2...\n", - "Processing store TX_1, department FOODS_3...\n", - "Processing store TX_2, department HOBBIES_1...\n", - "Processing store TX_2, department HOBBIES_2...\n", - "Processing store TX_2, department HOUSEHOLD_1...\n", - "Processing store TX_2, department HOUSEHOLD_2...\n", - "Processing store TX_2, department FOODS_1...\n", - "Processing store TX_2, department FOODS_2...\n", - "Processing store TX_2, department FOODS_3...\n", - "Processing store TX_3, department HOBBIES_1...\n", - "Processing store TX_3, department HOBBIES_2...\n", - "Processing store TX_3, department HOUSEHOLD_1...\n", - "Processing store TX_3, department HOUSEHOLD_2...\n", - "Processing store TX_3, department FOODS_1...\n", - "Processing store TX_3, department FOODS_2...\n", - "Processing store TX_3, department FOODS_3...\n", - "Processing store WI_1, department HOBBIES_1...\n", - "Processing store WI_1, department HOBBIES_2...\n", - "Processing store WI_1, department HOUSEHOLD_1...\n", - "Processing store WI_1, department HOUSEHOLD_2...\n", - "Processing store WI_1, department FOODS_1...\n", - "Processing store WI_1, department FOODS_2...\n", - "Processing store WI_1, department FOODS_3...\n", - "Processing store WI_2, department HOBBIES_1...\n", - "Processing store WI_2, department HOBBIES_2...\n", - "Processing store WI_2, department HOUSEHOLD_1...\n", - "Processing store WI_2, department HOUSEHOLD_2...\n", - "Processing store WI_2, department FOODS_1...\n", - "Processing store WI_2, department FOODS_2...\n", - "Processing store WI_2, department FOODS_3...\n", - "Processing store WI_3, department HOBBIES_1...\n", - "Processing store WI_3, department HOBBIES_2...\n", - "Processing store WI_3, department HOUSEHOLD_1...\n", - "Processing store WI_3, department HOUSEHOLD_2...\n", - "Processing store WI_3, department FOODS_1...\n", - "Processing store WI_3, department 
FOODS_2...\n", - "Processing store WI_3, department FOODS_3...\n" - ] - } - ], - "source": [ - "model_alt = final_train_alt(best_params_alt)" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "id": "b00cf96a-8589-4ee6-8c29-51a5aa81af2c", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Save the model to the Cloud Storage\n", - "with fs.open(f\"{bucket_name}/final_model_alt.pkl\", \"wb\") as f:\n", - " pickle.dump(model_alt, f)" - ] - }, - { - "cell_type": "markdown", - "id": "49fe134a-ebd8-4322-887f-ffcb97a380a2", - "metadata": {}, - "source": [ - "Now consider an ensemble consisting of the two models `model` and `model_alt`. We evaluate the ensemble by computing the WRMSSE metric for the average of the predictions of the two models." - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "id": "99495acc-0355-40cf-935c-91658a6909ea", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Processing store CA_1...\n", - "Processing store CA_2...\n", - "Processing store CA_3...\n", - "Processing store CA_4...\n", - "Processing store TX_1...\n", - "Processing store TX_2...\n", - "Processing store TX_3...\n", - "Processing store WI_1...\n", - "Processing store WI_2...\n", - "Processing store WI_3...\n", - "WRMSSE metric on the held-out test set: 11.055364531163706\n" - ] - } - ], - "source": [ - "test_wrmsse = 0\n", - "for store in STORES:\n", - " print(f\"Processing store {store}...\")\n", - " # Prediction from Model 1\n", - " with fs.open(f\"{bucket_name}/combined_df_store_{store}.pkl\", \"rb\") as f:\n", - " df = cudf.DataFrame(pd.read_pickle(f))\n", - " df_test = df[(df[\"day_id\"] >= holdout[0]) & (df[\"day_id\"] < holdout[1])]\n", - " X_test = df_test.drop(columns=[\"item_id\", \"dept_id\", \"cat_id\", \"day_id\", \"sales\"])\n", - " df_test[\"pred1\"] = model[store].predict(X_test)\n", - "\n", - " # Prediction from Model 2\n", - " df_test[\"pred2\"] = [np.nan] * 
len(df_test)\n", - " df_test[\"pred2\"] = df_test[\"pred2\"].astype(\"float32\")\n", - " for dept in DEPTS:\n", - " with fs.open(\n", - " f\"{bucket_name}/combined_df_store_{store}_dept_{dept}.pkl\", \"rb\"\n", - " ) as f:\n", - " df2 = cudf.DataFrame(pd.read_pickle(f))\n", - " df2_test = df2[(df2[\"day_id\"] >= holdout[0]) & (df2[\"day_id\"] < holdout[1])]\n", - " X_test = df2_test.drop(columns=[\"item_id\", \"cat_id\", \"day_id\", \"sales\"])\n", - " assert np.sum(df_test[\"dept_id\"] == dept) == len(X_test)\n", - " df_test[\"pred2\"][df_test[\"dept_id\"] == dept] = model_alt[(store, dept)].predict(\n", - " X_test\n", - " )\n", - "\n", - " # Average prediction\n", - " df_test[\"avg_pred\"] = (df_test[\"pred1\"] + df_test[\"pred2\"]) / 2.0\n", - "\n", - " test_wrmsse += wrmsse(\n", - " product_weights,\n", - " df,\n", - " df_test[\"avg_pred\"],\n", - " train_mask=[0, 1914],\n", - " valid_mask=holdout,\n", - " )\n", - "print(f\"WRMSSE metric on the held-out test set: {test_wrmsse}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fda530fd-9d8e-4200-bc08-20432f4ffa53", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# Close the Dask cluster to clean up\n", - "cluster.close()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ae38a58b-5c5e-41b7-adcb-06cbfd81bbfc", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 4ee82d9b60cb599fe0926833db26e5fe174ad3f2 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 28 Sep 2023 17:22:02 -0700 Subject: [PATCH 08/11] Rename 
notebook --- source/examples/index.md | 2 +- .../{start_here.ipynb => notebook.ipynb} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename source/examples/time-series-forecasting-with-hpo/{start_here.ipynb => notebook.ipynb} (100%) diff --git a/source/examples/index.md b/source/examples/index.md index 9b250a5a..450b3955 100644 --- a/source/examples/index.md +++ b/source/examples/index.md @@ -14,6 +14,6 @@ rapids-ec2-mnmg/notebook rapids-autoscaling-multi-tenant-kubernetes/notebook xgboost-randomforest-gpu-hpo-dask/notebook rapids-azureml-hpo/notebook -time-series-forecasting-with-hpo/start_here +time-series-forecasting-with-hpo/notebook xgboost-rf-gpu-cpu-benchmark/notebook ``` diff --git a/source/examples/time-series-forecasting-with-hpo/start_here.ipynb b/source/examples/time-series-forecasting-with-hpo/notebook.ipynb similarity index 100% rename from source/examples/time-series-forecasting-with-hpo/start_here.ipynb rename to source/examples/time-series-forecasting-with-hpo/notebook.ipynb From e550651ec10c58ce1823bcd4c32f43430eaf17ad Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 28 Sep 2023 17:31:41 -0700 Subject: [PATCH 09/11] Small fixes --- .../time-series-forecasting-with-hpo/notebook.ipynb | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/source/examples/time-series-forecasting-with-hpo/notebook.ipynb b/source/examples/time-series-forecasting-with-hpo/notebook.ipynb index 6afba35b..7d1b2307 100644 --- a/source/examples/time-series-forecasting-with-hpo/notebook.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/notebook.ipynb @@ -37,16 +37,18 @@ "id": "fde94ab8-123c-4a72-95db-86a2098247bb", "metadata": {}, "source": [ - "To run the example, you will need a working Google Kubernetes Engine (GKE) cluster with access to NVIDIA GPUs. 
Use the following resources to set up a cluster:\n", + "To run the example, you will need a working Google Kubernetes Engine (GKE) cluster with access to NVIDIA GPUs.\n", "\n", "````{docref} /cloud/gcp/gke\n", "Set up a Google Kubernetes Engine (GKE) cluster with access to NVIDIA GPUs. Follow instructions in [Google Kubernetes Engine](../../cloud/gcp/gke).\n", "````\n", "\n", - "To ensure that the example runs smoothly, ensure that you have ample memory in your GPUs. This notebook has been tested with NVIDIA A100.\n", + "1. To ensure that the example runs smoothly, ensure that you have ample memory in your GPUs. This notebook has been tested with NVIDIA A100.\n", "\n", - "* [Install the Dask-Kubernetes operator](https://kubernetes.dask.org/en/latest/operator_installation.html)\n", - "* [Install Kubeflow](https://www.kubeflow.org/docs/started/installing-kubeflow/)\n", + "2. Set up Dask-Kubernetes integration by following instructions in the following guides:\n", + "\n", + " * [Install the Dask-Kubernetes operator](https://kubernetes.dask.org/en/latest/operator_installation.html)\n", + " * [Install Kubeflow](https://www.kubeflow.org/docs/started/installing-kubeflow/)\n", "\n", "Kubeflow is not strictly necessary, but we highly recommend it, as Kubeflow gives you a nice notebook environment to run this notebook within the k8s cluster. (You may choose any method; we tested this example after installing Kubeflow from manifests.) When creating the notebook environment, use the following configuration:\n", "\n", @@ -54,7 +56,7 @@ "* 1 NVIDIA GPU\n", "* 40 GiB disk volume\n", "\n", - "After uploading all the notebooks in the example, run this notebook (`time_series_forecasting.ipynb`) in the notebook environment.\n", + "After uploading all the notebooks in the example, run this notebook (`notebook.ipynb`) in the notebook environment.\n", "\n", "Note: We will use the worker pods to speed up the training stage. The preprocessing steps will run solely on the scheduler node." 
] From 67fa69c11389ba1e7c55081fb75b3d03ddfeb842 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 28 Sep 2023 17:42:14 -0700 Subject: [PATCH 10/11] Add tags --- .../notebook.ipynb | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/source/examples/time-series-forecasting-with-hpo/notebook.ipynb b/source/examples/time-series-forecasting-with-hpo/notebook.ipynb index 7d1b2307..35a83aad 100644 --- a/source/examples/time-series-forecasting-with-hpo/notebook.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/notebook.ipynb @@ -3,7 +3,20 @@ { "cell_type": "markdown", "id": "671dd603-6b51-46b2-98b3-2b05c7c92c38", - "metadata": {}, + "metadata": { + "tags": [ + "platform/kubernetes", + "cloud/gcp/gke", + "tools/dask-operator", + "workflow/hpo", + "workflow/xgboost", + "library/dask", + "library/dask-cuda", + "library/xgboost", + "library/optuna", + "data-storage/gcs" + ] + }, "source": [ "# Perform time series forecasting on Google Kubernetes Engine with NVIDIA GPUs" ] From 88ec68e94ef9e2aa45a9ee73e0a2c7845cc905cf Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 10 Oct 2023 10:06:28 -0700 Subject: [PATCH 11/11] Remove DISABLE_JUPYTER --- source/examples/time-series-forecasting-with-hpo/notebook.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/examples/time-series-forecasting-with-hpo/notebook.ipynb b/source/examples/time-series-forecasting-with-hpo/notebook.ipynb index 35a83aad..129de3f9 100644 --- a/source/examples/time-series-forecasting-with-hpo/notebook.ipynb +++ b/source/examples/time-series-forecasting-with-hpo/notebook.ipynb @@ -6684,7 +6684,7 @@ " worker_command=\"dask-cuda-worker\",\n", " n_workers=n_workers,\n", " resources={\"limits\": {\"nvidia.com/gpu\": \"1\"}},\n", - " env={\"DISABLE_JUPYTER\": \"true\", \"EXTRA_PIP_PACKAGES\": \"optuna gcsfs\"},\n", + " env={\"EXTRA_PIP_PACKAGES\": \"optuna gcsfs\"},\n", ")" ] },