-
-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
⚗️ Benchmark experiments loading ERA5 Zarr data using kvikIO #4
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
""" | ||
Experiments on WeatherBench2, loading with kvikIO or Zarr engine. | ||
|
||
python 1_benchmark_kvikIOzarr.py | ||
|
||
References: | ||
- https://weatherbench2.readthedocs.io/en/latest/index.html | ||
""" | ||
import time | ||
|
||
import cupy | ||
import lightning as L | ||
import torch | ||
import torchdata | ||
import torchdata.dataloader2 | ||
import tqdm | ||
import xarray as xr | ||
import zen3geo | ||
|
||
|
||
# %% | ||
def sel_datavars(
    dataset: xr.Dataset,
    data_vars: list | None = None,
) -> xr.Dataset:
    """
    Select specific data variables from an xarray.Dataset object.

    Parameters
    ----------
    dataset : xr.Dataset
        The dataset to subset.
    data_vars : list, optional
        Names of the data variables to keep. Defaults to
        ``["geopotential", "u_component_of_wind", "v_component_of_wind"]``.

    Returns
    -------
    xr.Dataset
        A dataset containing only the requested data variables.
    """
    # Use a None sentinel instead of a mutable list default, which would be
    # shared (and mutable) across every call to this function.
    if data_vars is None:
        data_vars = ["geopotential", "u_component_of_wind", "v_component_of_wind"]
    return dataset.get(key=data_vars)  # .sel(level=500)
|
||
|
||
def xarray_to_tensor_collate_fn(
    samples: torchdata.datapipes.DataChunk,
    device: str = "cuda",
) -> tuple[torch.Tensor, torch.Tensor, list[dict]]:
    """
    Convert individual xarray.Dataset objects to torch.Tensor, and stack
    them all into a single torch.Tensor per time-step.

    Parameters
    ----------
    samples : torchdata.datapipes.DataChunk
        Iterable of xarray.Dataset chips, each with (at least) two steps
        along the ``time`` dimension.
    device : str
        Device to place the output tensors on. Default is ``cuda``
        (backward compatible with the previous hard-coded behaviour).

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, list[dict]]
        Stacked tensors for time-step 0 and time-step 1, plus a metadata
        list of dicts holding each sample's two timestamps as int64
        unix-time tensors under keys ``unixtime0`` and ``unixtime1``.
    """

    def _stack(time_index: int) -> torch.Tensor:
        # Stack the t=time_index slice of every sample into one batch tensor.
        return torch.stack(
            tensors=[
                torch.as_tensor(
                    data=sample.isel(time=time_index).to_array().data, device=device
                )
                for sample in samples
            ]
        )

    def _timestamp(sample, time_index: int) -> torch.Tensor:
        # Cast timestamps to int64 unix time so they can live in a tensor.
        return torch.as_tensor(
            data=sample.time.isel(time=time_index).astype(dtype="int64").data,
            device=device,
        )

    tensor_t0: torch.Tensor = _stack(time_index=0)
    tensor_t1: torch.Tensor = _stack(time_index=1)
    metadata: list[dict] = [
        {"unixtime0": _timestamp(sample, 0), "unixtime1": _timestamp(sample, 1)}
        for sample in samples
    ]

    return tensor_t0, tensor_t1, metadata
|
||
|
||
# %% | ||
class WeatherBench2DataModule(L.LightningDataModule):
    """
    LightningDataModule to load WeatherBench2 data from Zarr.
    """

    def __init__(
        self,
        zarr_store: str = "2020-full_37-6h-0p25deg-chunk-1_zuv500.zarr",
        # zarr_store: str = "gs://weatherbench2/datasets/era5/1959-2022-full_37-6h-0p25deg-chunk-1.zarr-v2",
        engine: str = "kvikio",
        batch_size: int = 32,
    ):
        """
        Go from a Zarr datacube to 6-hourly time-slice chips!

        Also does mini-batching.

        Parameters
        ----------
        zarr_store : str
            A path or URL to a Zarr store to read from. E.g. ``store1.zarr``.
            See list of available WeatherBench2 Zarr stores at
            https://weatherbench2.readthedocs.io/en/latest/data-guide.html

        engine : str
            The engine to use in `xr.open_dataset`. E.g. ``kvikio``, ``zarr``.
            Default is ``kvikio``.

        batch_size : int
            Size of each mini-batch. Default is 32.
        """
        super().__init__()
        self.zarr_store: str = zarr_store
        self.engine: str = engine
        self.batch_size: int = batch_size

        print(f"Loading data using {self.engine} engine")

    def setup(self, stage: str | None = None) -> None:
        """
        Data operations to perform on every GPU.
        Split data into training and test sets, etc.

        Builds ``self.datapipe_train``, an IterDataPipe over the training
        set that ``train_dataloader`` wraps in a DataLoader2.

        Parameters
        ----------
        stage : str, optional
            Lightning stage hint (e.g. ``fit``); currently unused.
        """
        # Step 0 - Iterate through all the Zarr stores
        dp_source: torchdata.datapipes.iter.IterDataPipe = (
            torchdata.datapipes.iter.IterableWrapper(iterable=[self.zarr_store])
        )

        # Step 1 - Create WeatherBench2 xarray.Dataset chips (full metadata)
        dp_weather_chips: torchdata.datapipes.iter.IterDataPipe = (
            # Step 1.0 - Open each Zarr store using xarray
            dp_source.read_from_xpystac(
                engine=self.engine, chunks=None, consolidated=False
            )
            # Step 1.1 - Select desired data variables at 500hPa
            .map(fn=sel_datavars)
            # Step 1.2 - Slice datacube along time-dimension into 12 hour chunks (2x 6-hourly)
            .slice_with_xbatcher(
                input_dims={"latitude": 721, "longitude": 1440, "time": 2},
                preload_batch=False,
            )
        )

        # Step 2 - Train/validation split each chip based on geography
        # TODO

        # Step 3 - Batch and split ERA5 chips into Machine Learning format
        self.datapipe_train = (
            # Step 3.1 - Create mini-batches (default is 32)
            dp_weather_chips.batch(batch_size=self.batch_size)
            # Step 3.2 - Convert xarray.Dataset to torch.Tensor and stack
            .collate(collate_fn=xarray_to_tensor_collate_fn)
        )

    def train_dataloader(self) -> torchdata.dataloader2.DataLoader2:
        """
        Loads the data used in the training loop.
        """
        return torchdata.dataloader2.DataLoader2(datapipe=self.datapipe_train)
|
||
|
||
# %% | ||
if __name__ == "__main__":
    # Optimize torch performance
    torch.set_float32_matmul_precision(precision="medium")

    # Setup data
    # NOTE(review): the engine is switched manually here ("kvikio" or "zarr")
    # when benchmarking; consider exposing it as a command-line argument.
    datamodule: L.LightningDataModule = WeatherBench2DataModule(engine="kvikio")
    datamodule.setup()
    train_dataloader = datamodule.train_dataloader()

    # Start timing
    tic = time.perf_counter()

    # Training loop
    # TODO train on more than just 1 epoch (maybe 100?) to get a nicer
    # averaged result when comparing the kvikio and zarr engines.
    for i, batch in tqdm.tqdm(iterable=enumerate(train_dataloader), total=23):
        # Avoid shadowing the builtin `input`.
        batch_t0, batch_t1, metadata = batch
        # Compute Mean Squared Error loss between t=0 and t=1, just for fun
        loss: torch.Tensor = torch.nn.functional.mse_loss(
            input=batch_t0, target=batch_t1
        )
        print(f"Batch {i}, MSE Loss: {loss}")

    # Stop timing
    toc = time.perf_counter()
    print(f"Total: {toc - tic:0.4f} seconds")
Review comment:

Seems like the `kvikio` engine (11s) is slightly faster than `zarr` (14s) when `preload_batch=False`, whereas `zarr` (10s) is slightly faster than `kvikio` (11s) when the default `preload_batch=True` is set. Maybe because loading from dask.Array objects is not yet as optimized for kvikIO as it is for Zarr?

For benchmark purposes though, it's probably best to disable the `preload_batch` setting, since it acts somewhat like a cache, and we want to measure raw IO speed. And yes, the timings are probably not significantly different, so I'll run it over more epochs, as mentioned at #4 (comment), to get a better average time.