⚗️ Benchmark experiments loading ERA5 Zarr data using kvikIO #4

Merged · 1 commit · Oct 10, 2023
182 changes: 182 additions & 0 deletions 1_benchmark_kvikIOzarr.py
@@ -0,0 +1,182 @@
"""
Experiments on WeatherBench2, loading data with the kvikIO or Zarr engine.

python 1_benchmark_kvikIOzarr.py

References:
- https://weatherbench2.readthedocs.io/en/latest/index.html
"""
import time

import cupy
import lightning as L
import torch
import torchdata
import torchdata.dataloader2
import tqdm
import xarray as xr
import zen3geo


# %%
def sel_datavars(
dataset: xr.Dataset,
data_vars: list = ["geopotential", "u_component_of_wind", "v_component_of_wind"],
) -> xr.Dataset:
"""
Select specific data variables from an xarray.Dataset object.
"""
return dataset.get(key=data_vars) # .sel(level=500)


def xarray_to_tensor_collate_fn(
    samples: torchdata.datapipes.DataChunk,
) -> tuple[torch.Tensor, torch.Tensor, list[dict]]:
    """
    Converts individual xarray.Dataset objects to torch.Tensor (float32 dtype)
    and stacks them all into a single torch.Tensor. Also outputs a metadata
    list of dictionaries containing the timestamps of each xarray.Dataset's
    data variable.
"""
tensor_t0: torch.Tensor = torch.stack(
tensors=[
torch.as_tensor(data=sample.isel(time=0).to_array().data, device="cuda")
for sample in samples
]
)
tensor_t1: torch.Tensor = torch.stack(
tensors=[
torch.as_tensor(data=sample.isel(time=1).to_array().data, device="cuda")
for sample in samples
]
)
metadata: list[dict] = [
{
"unixtime0": torch.as_tensor(
data=sample.time.isel(time=0).astype(dtype="int64").data, device="cuda"
),
"unixtime1": torch.as_tensor(
data=sample.time.isel(time=1).astype(dtype="int64").data, device="cuda"
),
}
for sample in samples
]

return tensor_t0, tensor_t1, metadata


# %%
class WeatherBench2DataModule(L.LightningDataModule):
"""
LightningDataModule to load WeatherBench2 data from Zarr.
"""

def __init__(
self,
zarr_store: str = "2020-full_37-6h-0p25deg-chunk-1_zuv500.zarr",
# zarr_store: str = "gs://weatherbench2/datasets/era5/1959-2022-full_37-6h-0p25deg-chunk-1.zarr-v2",
engine: str = "kvikio",
batch_size: int = 32,
):
"""
Go from a Zarr datacube to 6-hourly time-slice chips!

Also does mini-batching.

Parameters
----------
        zarr_store : str
            A path or URL to a Zarr store to read from, e.g. ``store1.zarr``.
See list of available WeatherBench2 Zarr stores at
https://weatherbench2.readthedocs.io/en/latest/data-guide.html

engine : str
The engine to use in `xr.open_dataset`. E.g. ``kvikio``, ``zarr``.
Default is ``kvikio``.

batch_size : int
Size of each mini-batch. Default is 32.

        """
super().__init__()
self.zarr_store: str = zarr_store
self.engine: str = engine
self.batch_size: int = batch_size

print(f"Loading data using {self.engine} engine")

    def setup(self, stage: str | None = None) -> None:
        """
        Data operations to perform on every GPU.
        Split data into training and test sets, etc.

        Sets up ``self.datapipe_train``, a torch DataPipe object to iterate
        over the training set.
        """
# Step 0 - Iterate through all the Zarr stores
dp_source: torchdata.datapipes.iter.IterDataPipe = (
torchdata.datapipes.iter.IterableWrapper(iterable=[self.zarr_store])
)

# Step 1 - Create WeatherBench2 xarray.Dataset chips (full metadata)
dp_weather_chips: torchdata.datapipes.iter.IterDataPipe = (
# Step 1.0 - Open each Zarr store using xarray
dp_source.read_from_xpystac(
engine=self.engine, chunks=None, consolidated=False
)
# Step 1.1 - Select desired data variables at 500hPa
.map(fn=sel_datavars)
# Step 1.2 - Slice datacube along time-dimension into 12 hour chunks (2x 6-hourly)
.slice_with_xbatcher(
input_dims={"latitude": 721, "longitude": 1440, "time": 2},
preload_batch=False,

Review comment from @weiji14, Oct 5, 2023:

Seems like the kvikio engine (11s) is slightly faster than zarr (14s) when `preload_batch=False`, whereas zarr (10s) is slightly faster than kvikio (11s) when the default `preload_batch=True` is set. Maybe the loading from dask.Array objects is not as optimized for kvikIO as for Zarr yet?

For benchmark purposes though, it's probably best to disable this `preload_batch` setting, since it acts somewhat like a cache and we want to look at raw IO speed. And yes, the timings are probably not significantly different, so I'll run it over more epochs as mentioned at #4 (comment) to get a better average time.

)
)

# Step 2 - Train/validation split each chip based on geography
# TODO

# Step 3 - Batch and split ERA5 chips into Machine Learning format
self.datapipe_train = (
# Step 3.1 - Create mini-batches (default is 32)
dp_weather_chips.batch(batch_size=self.batch_size)
# Step 3.2 - Convert xarray.Dataset to torch.Tensor and stack
.collate(collate_fn=xarray_to_tensor_collate_fn)
)

def train_dataloader(self) -> torchdata.dataloader2.DataLoader2:
"""
Loads the data used in the training loop.
"""
return torchdata.dataloader2.DataLoader2(datapipe=self.datapipe_train)


# %%
if __name__ == "__main__":
# Optimize torch performance
torch.set_float32_matmul_precision(precision="medium")

# Setup data
datamodule: L.LightningDataModule = WeatherBench2DataModule(engine="kvikio")

Review comment from @weiji14 (author):

Currently changing the engine manually here (either kvikio or zarr). Should make a CLI flag to set the engine here.

datamodule.setup()
train_dataloader = datamodule.train_dataloader()

# Start timing
tic = time.perf_counter()

# Training loop
for i, batch in tqdm.tqdm(iterable=enumerate(train_dataloader), total=23):
input, target, metadata = batch
# Compute Mean Squared Error loss between t=0 and t=1, just for fun
        loss: torch.Tensor = torch.nn.functional.mse_loss(input=input, target=target)
print(f"Batch {i}, MSE Loss: {loss}")

Review comment from @weiji14 on lines +173 to +178:

TODO: train on more than just 1 epoch (maybe 100?) to get a nicer average result comparing between zarr and kvikio.


# Stop timing
toc = time.perf_counter()
print(f"Total: {toc - tic:0.4f} seconds")
9 changes: 9 additions & 0 deletions README.md
@@ -113,6 +113,15 @@
This will save a one-year subset of the WeatherBench2 ERA5 dataset to your
local disk. It will include data at pressure level 500hPa, with the variables
'geopotential', 'u_component_of_wind', and 'v_component_of_wind' only.

To run the benchmark experiment loading with the kvikIO engine, run:

python 1_benchmark_kvikIOzarr.py

This will print out a progress bar showing the ERA5 data being loaded in
mini-batches (simulating a neural network training loop), followed by the
total time taken to finish. One 'epoch' should take under 15 seconds on an
Ampere-generation NVIDIA GPU (e.g. RTX A2000).
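The per-batch 'MSE Loss' printed by the script is just a sanity check comparing the t=0 and t=1 slices; the computation is a plain mean squared error, e.g. (NumPy stand-in with illustrative array shapes, not the real ERA5 dimensions):

```python
import numpy as np


def mse(x: np.ndarray, y: np.ndarray) -> float:
    """Mean squared error between two equally-shaped arrays."""
    return float(np.mean((x - y) ** 2))


t0 = np.zeros((2, 3))  # stand-in for the t=0 mini-batch tensor
t1 = np.ones((2, 3))   # stand-in for the t=1 mini-batch tensor
print(mse(t0, t1))     # every element differs by 1, so the MSE is 1.0
```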

# References

## Links