
Parameter persistence with sharding support #338

Open
jonatanklosko opened this issue Feb 15, 2024 · 4 comments
Labels
kind:feature New feature or request note:discussion Details up for discussion

Comments

@jonatanklosko
Member

jonatanklosko commented Feb 15, 2024

Currently, whenever we load a model, we need to convert the parameters' layout from whatever PyTorch uses to whatever Axon uses (mostly transposing dense and conv layers). For smaller models this is quick, however for large models it: (a) introduces loading overhead; (b) consumes a lot of memory (this prevents us from loading params directly onto the GPU, which would make sense in a single-GPU use case) (fixed in #344).
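For reference, a toy illustration of the kind of conversion involved (the exact shapes are an assumption for illustration, not Bumblebee's loading code): PyTorch stores a dense (Linear) kernel as `{out_features, in_features}`, while Axon expects `{in_features, out_features}`, so each such parameter gets transposed, materializing a converted copy.

```elixir
# Toy illustration of the per-parameter conversion cost (assumed shapes):
# a dense kernel in PyTorch's {out, in} layout transposed to Axon's {in, out}.
pytorch_kernel = Nx.iota({4, 3}, type: :f32)
axon_kernel = Nx.transpose(pytorch_kernel)
# axon_kernel has shape {3, 4} and is a new tensor, hence the extra memory
```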

Ideally we would have an easy way to persist the loaded parameters into multiple files (in case of large parameters). With that, the user could call Bumblebee.load_model/2 once, persist the parameters into a file, and then in production load the parameters directly without the conversion overhead (possibly straight onto the GPU), as in the sketch below.
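A minimal sketch of that workflow, assuming the converted params are persisted with `Nx.serialize/2` (term-to-binary based) until a dedicated API exists; the repository name and file path are just examples:

```elixir
# One-off: load with Bumblebee (this performs the PyTorch -> Axon layout
# conversion) and persist the already-converted parameters to disk.
{:ok, %{params: params}} = Bumblebee.load_model({:hf, "openai-community/gpt2"})
File.write!("params.nx", Nx.serialize(params))

# In production: read the converted parameters back directly, skipping
# the conversion overhead entirely.
params = Nx.deserialize(File.read!("params.nx"))
```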

This probably belongs in Axon directly, but we may as well track it here given the use case. I also wonder if we should be using Safetensors rather than term-to-binary for better portability. One issue with Safetensors is that it only supports a flat map of tensors, while Axon parameters can be any Nx.Container (e.g. LSTM uses tuples), so unless we make Axon parameters more strict we can't really do it.

This also depends on elixir-nx/axon#553, which changes params into a struct, and we likely want to persist the whole struct.

@jonatanklosko jonatanklosko added kind:feature New feature or request note:discussion Details up for discussion labels Feb 15, 2024
@josevalim
Contributor

The flat parameters should not really be a problem, should they? You could convert a nested map of keys “foo” and “bar” into a special flattened key, such as “foo——bar”, no?
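A rough sketch of that idea (the `flatten_params` helper and the `"--"` separator are hypothetical, not an existing API):

```elixir
defmodule ParamsFlattening do
  # Recursively flattens nested string-keyed maps of tensors into a single
  # flat map, joining keys with a separator, which fits Safetensors' model.
  def flatten_params(params, prefix \\ nil) do
    params
    |> Enum.flat_map(fn {key, value} ->
      key = if prefix, do: prefix <> "--" <> key, else: key

      case value do
        %Nx.Tensor{} -> [{key, value}]
        %{} -> flatten_params(value, key)
      end
    end)
    |> Map.new()
  end
end

# %{"foo" => %{"bar" => tensor}} becomes %{"foo--bar" => tensor}
ParamsFlattening.flatten_params(%{"foo" => %{"bar" => Nx.iota({2})}})
```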

@jonatanklosko
Member Author

@josevalim the nested map is not a problem; the problem is other Nx.Container values (currently tuples), so it may make sense to restrict Axon parameters to tensors.
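For example (the parameter names below are illustrative, not Axon's exact layout), the LSTM kernels are a tuple, which is a valid Nx.Container but has no string keys that a flattening scheme could use:

```elixir
t = fn -> Nx.iota({2, 2}, type: :f32) end

# A tuple container like this cannot be mapped onto Safetensors' flat,
# string-keyed layout without inventing a naming convention for its entries.
lstm_params = %{
  "input_kernel" => {t.(), t.(), t.(), t.()},
  "hidden_kernel" => {t.(), t.(), t.(), t.()}
}
```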

@jonatanklosko
Member Author

Sidenote: sharding is a nice-to-have, but with elixir-nx/safetensors#8 we should be able to write all parameters into a single file efficiently.

@jonatanklosko
Member Author

With #344 the main motivation (excessive memory usage) is addressed, so this is less of a priority. It would still reduce the time overhead of transforming the params. Either way, we should have a good way of persisting large parameters (again, preferably in Axon).
