diff --git a/mlx/c/array.cpp b/mlx/c/array.cpp index 4debd71..f12281c 100644 --- a/mlx/c/array.cpp +++ b/mlx/c/array.cpp @@ -79,6 +79,10 @@ extern "C" mlx_array mlx_array_from_data( } } +extern "C" size_t mlx_array_dtype_size(mlx_array_dtype dtype) { + return MLX_CPP_ARRAY_DTYPE(dtype).size(); +} + extern "C" size_t mlx_array_itemsize(mlx_array arr) { return MLX_CPP_ARRAY(arr).itemsize(); } diff --git a/mlx/c/array.h b/mlx/c/array.h index 3b27df9..e30def1 100644 --- a/mlx/c/array.h +++ b/mlx/c/array.h @@ -74,6 +74,11 @@ mlx_array mlx_array_from_data( int dim, mlx_array_dtype dtype); +/** + * The size of one element of the given datatype in bytes. + */ +size_t mlx_array_dtype_size(mlx_array_dtype dtype); + /** * The size of the array's datatype in bytes. */