[Feature] TensorDict.consolidate #814
tensordict/base.py
```python
offset = v.storage_offset()
stride = v.stride()
if offset or stride != 1:
    content = v.clone().untyped_storage()
```
It would be nice to be able to get a view of the storage directly. Do we have that?
cc @albanD @mikaylagawarecki
Well, a "view of a storage" has a name: it's a "Tensor" :D
If you need a contiguous chunk of memory containing these values, you will indeed need a copy, and `clone()` is the way to go!
Note that you might want to pass `memory_format=torch.contiguous_format` if you want a row-major storage. Otherwise you might get a channels-last layout here, since the default is to preserve the memory format.
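For reference, `v.stride()` returns a tuple, so the check `stride != 1` in the snippet above is always true and every tensor ends up cloned. Below is a minimal sketch of a stricter check that follows the advice above: reuse the storage only when the tensor covers it exactly, otherwise make a row-major copy with `clone(memory_format=torch.contiguous_format)`. `row_major_storage` is a hypothetical helper, not part of tensordict or PyTorch, and the sketch assumes PyTorch 2.x (`untyped_storage()`).

```python
import torch


def row_major_storage(t: torch.Tensor) -> torch.UntypedStorage:
    """Return an untyped storage holding exactly t's values in row-major order."""
    # Reuse the existing storage only if t starts at offset 0, is row-major
    # contiguous, and spans the whole storage (no extra bytes from a base tensor).
    covers_storage = (
        t.storage_offset() == 0
        and t.is_contiguous()
        and t.untyped_storage().nbytes() == t.numel() * t.element_size()
    )
    if covers_storage:
        return t.untyped_storage()
    # clone() defaults to memory_format=torch.preserve_format, which would keep
    # e.g. a channels-last layout; request a row-major copy explicitly.
    return t.clone(memory_format=torch.contiguous_format).untyped_storage()


x = torch.randn(2, 3, 4, 5).contiguous(memory_format=torch.channels_last)
print(x.clone().is_contiguous())  # False: the channels-last layout is preserved
print(x.clone(memory_format=torch.contiguous_format).is_contiguous())  # True
```

A slice such as `base[2:5]` fails the check (non-zero storage offset), so it gets a fresh storage of exactly `3 * element_size` bytes instead of sharing the larger one.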
TODO:
```python
if total_key[-1].startswith("<NJT_OFFSETS"):
    offsets = flat_key_values[total_key]
    key = key.replace("<NJT_OFFSETS>", "")
    value = torch.nested.nested_tensor_from_jagged(nested_values, offsets)
```
nit: eventually it'll likely be useful to support NJT's `lengths` as well, which is an optional set of metadata present for non-contiguous NJTs.
Do you have an example of such a thing?
sure thing, here's a test case demonstrating construction with lengths: https://github.com/pytorch/pytorch/blob/d52684e9a8b95f8b0d06f0bf08e6a39846cb3ae6/test/test_nestedtensor.py#L4242
When both `offsets` (shape `B + 1`, including the end offset) and `lengths` (shape `B`) are specified, the idea is that `offsets` points to the beginning of each batch item sequence and `lengths` indicates the length of each. This allows for representing NJTs that are "non-contiguous with holes". If only `offsets` is specified, the assumption is that the entire span between consecutive offsets belongs to each batch item, which of course doesn't allow for such holes.
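To make the offsets/lengths semantics concrete, here is a sketch that mirrors them with plain tensors rather than NJTs, so it does not depend on the nested-tensor API. `unbind_jagged` is a hypothetical helper: without `lengths`, item `i` spans `values[offsets[i]:offsets[i + 1]]`; with `lengths`, it is `values[offsets[i]:offsets[i] + lengths[i]]`, so the rest of each slot is a "hole".

```python
import torch


def unbind_jagged(values, offsets, lengths=None):
    """Split a flat values buffer into per-item tensors (views, no copy).

    offsets has shape (B + 1,), including the end offset; lengths, if given,
    has shape (B,).
    """
    if lengths is None:
        # Offsets only: each item covers the whole gap to the next offset.
        lengths = offsets.diff()
    return [
        values.narrow(0, int(start), int(length))
        for start, length in zip(offsets[:-1], lengths)
    ]


values = torch.arange(10.0)
offsets = torch.tensor([0, 3, 6, 10])
lengths = torch.tensor([2, 3, 1])

contiguous_items = unbind_jagged(values, offsets)  # items of length 3, 3, 4
holey_items = unbind_jagged(values, offsets, lengths)  # items of length 2, 3, 1
```

With `lengths`, `values[2]`, `values[7:10]` belong to no item, which is exactly the kind of hole that offsets alone cannot express.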
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_plain_set_nested | 43.4220μs | 16.9290μs | 59.0703 KOps/s | 65.0464 KOps/s | |
test_plain_set_stack_nested | 0.2190ms | 18.7837μs | 53.2375 KOps/s | 63.3650 KOps/s | |
test_plain_set_nested_inplace | 72.9870μs | 19.3373μs | 51.7134 KOps/s | 56.2881 KOps/s | |
test_plain_set_stack_nested_inplace | 65.7740μs | 19.3632μs | 51.6444 KOps/s | 56.5907 KOps/s | |
test_items | 18.9560μs | 2.7379μs | 365.2402 KOps/s | 391.7611 KOps/s | |
test_items_nested | 0.6251ms | 0.2774ms | 3.6049 KOps/s | 3.6500 KOps/s | |
test_items_nested_locked | 0.4627ms | 0.2761ms | 3.6221 KOps/s | 3.6220 KOps/s | |
test_items_nested_leaf | 0.5305ms | 78.1901μs | 12.7893 KOps/s | 12.6702 KOps/s | |
test_items_stack_nested | 0.6662ms | 0.2820ms | 3.5460 KOps/s | 3.6396 KOps/s | |
test_items_stack_nested_leaf | 0.1516ms | 78.8556μs | 12.6814 KOps/s | 12.1246 KOps/s | |
test_items_stack_nested_locked | 0.5232ms | 0.2818ms | 3.5492 KOps/s | 3.6070 KOps/s | |
test_keys | 40.2450μs | 3.8410μs | 260.3482 KOps/s | 258.0627 KOps/s | |
test_keys_nested | 0.2236ms | 0.1382ms | 7.2383 KOps/s | 7.2557 KOps/s | |
test_keys_nested_locked | 1.8924ms | 0.1421ms | 7.0385 KOps/s | 6.9729 KOps/s | |
test_keys_nested_leaf | 0.1954ms | 0.1168ms | 8.5638 KOps/s | 8.4732 KOps/s | |
test_keys_stack_nested | 0.2632ms | 0.1380ms | 7.2467 KOps/s | 7.2096 KOps/s | |
test_keys_stack_nested_leaf | 0.1940ms | 0.1165ms | 8.5841 KOps/s | 8.5619 KOps/s | |
test_keys_stack_nested_locked | 0.2716ms | 0.1433ms | 6.9805 KOps/s | 6.9838 KOps/s | |
test_values | 10.3745μs | 1.1408μs | 876.5977 KOps/s | 886.0799 KOps/s | |
test_values_nested | 0.1030ms | 50.2454μs | 19.9023 KOps/s | 19.9900 KOps/s | |
test_values_nested_locked | 0.1063ms | 49.8776μs | 20.0491 KOps/s | 19.6765 KOps/s | |
test_values_nested_leaf | 96.8730μs | 45.5500μs | 21.9539 KOps/s | 22.1089 KOps/s | |
test_values_stack_nested | 0.1234ms | 51.2046μs | 19.5295 KOps/s | 19.7817 KOps/s | |
test_values_stack_nested_leaf | 81.2430μs | 45.9924μs | 21.7427 KOps/s | 22.1695 KOps/s | |
test_values_stack_nested_locked | 0.1070ms | 50.7105μs | 19.7198 KOps/s | 19.7070 KOps/s | |
test_membership | 14.5180μs | 1.3387μs | 746.9694 KOps/s | 747.5523 KOps/s | |
test_membership_nested | 24.4260μs | 3.4373μs | 290.9224 KOps/s | 283.4041 KOps/s | |
test_membership_nested_leaf | 32.7520μs | 3.6036μs | 277.4973 KOps/s | 282.6925 KOps/s | |
test_membership_stacked_nested | 22.3620μs | 3.4710μs | 288.0978 KOps/s | 275.6100 KOps/s | |
test_membership_stacked_nested_leaf | 33.4660μs | 3.3717μs | 296.5854 KOps/s | 285.6704 KOps/s | |
test_membership_nested_last | 33.1830μs | 4.1477μs | 241.0975 KOps/s | 235.3935 KOps/s | |
test_membership_nested_leaf_last | 40.9170μs | 4.1415μs | 241.4602 KOps/s | 237.2536 KOps/s | |
test_membership_stacked_nested_last | 22.6420μs | 4.6763μs | 213.8459 KOps/s | 237.9690 KOps/s | |
test_membership_stacked_nested_leaf_last | 23.7550μs | 4.7765μs | 209.3593 KOps/s | 236.2330 KOps/s | |
test_nested_getleaf | 57.3680μs | 10.5205μs | 95.0526 KOps/s | 93.5018 KOps/s | |
test_nested_get | 34.0550μs | 9.8531μs | 101.4907 KOps/s | 99.3775 KOps/s | |
test_stacked_getleaf | 49.8140μs | 10.4885μs | 95.3421 KOps/s | 94.5139 KOps/s | |
test_stacked_get | 43.4220μs | 9.8181μs | 101.8529 KOps/s | 99.1389 KOps/s | |
test_nested_getitemleaf | 52.7690μs | 11.0262μs | 90.6930 KOps/s | 91.1412 KOps/s | |
test_nested_getitem | 49.4230μs | 10.1650μs | 98.3763 KOps/s | 95.7216 KOps/s | |
test_stacked_getitemleaf | 76.4730μs | 10.9125μs | 91.6381 KOps/s | 91.3376 KOps/s | |
test_stacked_getitem | 55.9660μs | 10.2204μs | 97.8436 KOps/s | 98.3052 KOps/s | |
test_lock_nested | 1.0450ms | 0.3407ms | 2.9348 KOps/s | 2.9415 KOps/s | |
test_lock_stack_nested | 0.7890ms | 0.3149ms | 3.1759 KOps/s | 3.2146 KOps/s | |
test_unlock_nested | 0.8273ms | 0.3466ms | 2.8854 KOps/s | 2.8869 KOps/s | |
test_unlock_stack_nested | 0.4641ms | 0.3195ms | 3.1300 KOps/s | 3.1120 KOps/s | |
test_flatten_speed | 0.4427ms | 96.8796μs | 10.3221 KOps/s | 10.0021 KOps/s | |
test_unflatten_speed | 0.8856ms | 0.4155ms | 2.4068 KOps/s | 2.3949 KOps/s | |
test_common_ops | 3.2860ms | 0.7425ms | 1.3468 KOps/s | 1.5006 KOps/s | |
test_creation | 19.0360μs | 1.8738μs | 533.6743 KOps/s | 522.0483 KOps/s | |
test_creation_empty | 30.2570μs | 11.4212μs | 87.5562 KOps/s | 123.9516 KOps/s | |
test_creation_nested_1 | 41.8090μs | 14.2660μs | 70.0969 KOps/s | 92.2576 KOps/s | |
test_creation_nested_2 | 68.3580μs | 17.2887μs | 57.8411 KOps/s | 71.3705 KOps/s | |
test_clone | 62.1470μs | 13.2800μs | 75.3011 KOps/s | 73.0715 KOps/s | |
test_getitem[int] | 37.5810μs | 11.3907μs | 87.7909 KOps/s | 89.0418 KOps/s | |
test_getitem[slice_int] | 54.5320μs | 22.7303μs | 43.9941 KOps/s | 44.2242 KOps/s | |
test_getitem[range] | 83.2970μs | 60.8884μs | 16.4235 KOps/s | 15.5893 KOps/s | |
test_getitem[tuple] | 53.0500μs | 19.2587μs | 51.9245 KOps/s | 53.3465 KOps/s | |
test_getitem[list] | 0.1175ms | 41.3017μs | 24.2121 KOps/s | 24.5596 KOps/s | |
test_setitem_dim[int] | 73.6790μs | 37.9921μs | 26.3212 KOps/s | 32.1556 KOps/s | |
test_setitem_dim[slice_int] | 0.1339ms | 64.9067μs | 15.4067 KOps/s | 17.2817 KOps/s | |
test_setitem_dim[range] | 0.1583ms | 88.3648μs | 11.3167 KOps/s | 12.3871 KOps/s | |
test_setitem_dim[tuple] | 84.8590μs | 53.9393μs | 18.5394 KOps/s | 21.5754 KOps/s | |
test_setitem | 82.4110μs | 20.2072μs | 49.4874 KOps/s | 52.6856 KOps/s | |
test_set | 77.3380μs | 20.1520μs | 49.6229 KOps/s | 54.1608 KOps/s | |
test_set_shared | 1.1434ms | 0.1441ms | 6.9397 KOps/s | 6.9962 KOps/s | |
test_update | 0.1735ms | 22.9222μs | 43.6258 KOps/s | 51.6338 KOps/s | |
test_update_nested | 76.4460μs | 31.7547μs | 31.4914 KOps/s | 34.8937 KOps/s | |
test_update__nested | 83.3890μs | 24.9232μs | 40.1232 KOps/s | 39.5220 KOps/s | |
test_set_nested | 76.0030μs | 21.4571μs | 46.6046 KOps/s | 50.9132 KOps/s | |
test_set_nested_new | 66.1550μs | 25.7568μs | 38.8247 KOps/s | 41.3957 KOps/s | |
test_select | 0.1102ms | 41.1621μs | 24.2942 KOps/s | 24.7429 KOps/s | |
test_select_nested | 0.1032ms | 58.9348μs | 16.9679 KOps/s | 16.4169 KOps/s | |
test_exclude_nested | 0.2179ms | 0.1189ms | 8.4121 KOps/s | 8.3531 KOps/s | |
test_empty[True] | 0.4755ms | 0.3912ms | 2.5560 KOps/s | 2.4855 KOps/s | |
test_empty[False] | 10.2642μs | 1.1666μs | 857.1950 KOps/s | 841.3153 KOps/s | |
test_unbind_speed | 3.4869ms | 0.2603ms | 3.8415 KOps/s | 3.9083 KOps/s | |
test_unbind_speed_stack0 | 0.3945ms | 0.2516ms | 3.9747 KOps/s | 3.9466 KOps/s | |
test_unbind_speed_stack1 | 67.5512ms | 0.7404ms | 1.3507 KOps/s | 1.3486 KOps/s | |
test_split | 67.7288ms | 1.6091ms | 621.4534 Ops/s | 621.5672 Ops/s | |
test_chunk | 67.5346ms | 1.6043ms | 623.3227 Ops/s | 620.5256 Ops/s | |
test_creation[device0] | 0.1823ms | 85.7253μs | 11.6652 KOps/s | 11.8624 KOps/s | |
test_creation_from_tensor | 3.4401ms | 87.2705μs | 11.4586 KOps/s | 11.5101 KOps/s | |
test_add_one[memmap_tensor0] | 67.8170μs | 5.4767μs | 182.5933 KOps/s | 180.9841 KOps/s | |
test_contiguous[memmap_tensor0] | 20.0470μs | 0.6374μs | 1.5690 MOps/s | 1.5225 MOps/s | |
test_stack[memmap_tensor0] | 22.9230μs | 3.4703μs | 288.1586 KOps/s | 282.2621 KOps/s | |
test_memmaptd_index | 0.9376ms | 0.2545ms | 3.9297 KOps/s | 3.9290 KOps/s | |
test_memmaptd_index_astensor | 0.6790ms | 0.3285ms | 3.0441 KOps/s | 2.9594 KOps/s | |
test_memmaptd_index_op | 1.1870ms | 0.6276ms | 1.5934 KOps/s | 1.7337 KOps/s | |
test_serialize_model | 0.1021s | 99.1830ms | 10.0824 Ops/s | 8.6572 Ops/s | |
test_serialize_model_pickle | 0.4507s | 0.3762s | 2.6579 Ops/s | 2.6375 Ops/s | |
test_serialize_weights | 0.1022s | 97.5325ms | 10.2530 Ops/s | 9.4649 Ops/s | |
test_serialize_weights_returnearly | 0.1208s | 0.1161s | 8.6101 Ops/s | 7.2012 Ops/s | |
test_serialize_weights_pickle | 0.4824s | 0.4116s | 2.4294 Ops/s | 2.4287 Ops/s | |
test_serialize_weights_filesystem | 0.1121s | 97.7949ms | 10.2255 Ops/s | 10.5055 Ops/s | |
test_serialize_model_filesystem | 0.1729s | 0.1041s | 9.6047 Ops/s | 9.6728 Ops/s | |
test_reshape_pytree | 50.1140μs | 25.8527μs | 38.6807 KOps/s | 39.4043 KOps/s | |
test_reshape_td | 73.6190μs | 33.4043μs | 29.9363 KOps/s | 29.8799 KOps/s | |
test_view_pytree | 70.1230μs | 25.8468μs | 38.6896 KOps/s | 39.2833 KOps/s | |
test_view_td | 96.9820μs | 37.8814μs | 26.3982 KOps/s | 26.2599 KOps/s | |
test_unbind_pytree | 77.1450μs | 29.5268μs | 33.8676 KOps/s | 34.4862 KOps/s | |
test_unbind_td | 0.3783ms | 37.5019μs | 26.6653 KOps/s | 26.7485 KOps/s | |
test_split_pytree | 66.5350μs | 29.3139μs | 34.1135 KOps/s | 34.4114 KOps/s | |
test_split_td | 0.1182ms | 40.6097μs | 24.6247 KOps/s | 25.0570 KOps/s | |
test_add_pytree | 83.8370μs | 35.8910μs | 27.8621 KOps/s | 27.7999 KOps/s | |
test_add_td | 0.1277ms | 56.6789μs | 17.6433 KOps/s | 19.0739 KOps/s | |
test_distributed | 0.2244ms | 0.1047ms | 9.5525 KOps/s | 9.5834 KOps/s | |
test_tdmodule | 65.6230μs | 18.2734μs | 54.7244 KOps/s | 61.4726 KOps/s | |
test_tdmodule_dispatch | 63.0890μs | 36.6598μs | 27.2778 KOps/s | 31.8624 KOps/s | |
test_tdseq | 39.9750μs | 21.6338μs | 46.2240 KOps/s | 52.8730 KOps/s | |
test_tdseq_dispatch | 62.0470μs | 41.9593μs | 23.8326 KOps/s | 27.1485 KOps/s | |
test_instantiation_functorch | 1.5784ms | 1.3134ms | 761.3794 Ops/s | 749.4898 Ops/s | |
test_instantiation_td | 2.9647ms | 1.0196ms | 980.8008 Ops/s | 977.3345 Ops/s | |
test_exec_functorch | 0.3073ms | 0.1611ms | 6.2080 KOps/s | 6.1495 KOps/s | |
test_exec_functional_call | 0.4818ms | 0.1521ms | 6.5758 KOps/s | 6.5858 KOps/s | |
test_exec_td | 0.2334ms | 0.1478ms | 6.7672 KOps/s | 6.5305 KOps/s | |
test_exec_td_decorator | 0.8135ms | 0.2259ms | 4.4270 KOps/s | 4.4377 KOps/s | |
test_vmap_mlp_speed[True-True] | 0.6165ms | 0.4972ms | 2.0112 KOps/s | 2.0498 KOps/s | |
test_vmap_mlp_speed[True-False] | 0.8183ms | 0.5003ms | 1.9987 KOps/s | 1.9899 KOps/s | |
test_vmap_mlp_speed[False-True] | 0.6600ms | 0.4060ms | 2.4630 KOps/s | 2.4869 KOps/s | |
test_vmap_mlp_speed[False-False] | 0.9779ms | 0.4199ms | 2.3815 KOps/s | 2.4945 KOps/s | |
test_vmap_mlp_speed_decorator[True-True] | 1.2878ms | 0.5755ms | 1.7377 KOps/s | 1.7680 KOps/s | |
test_vmap_mlp_speed_decorator[True-False] | 0.8282ms | 0.5761ms | 1.7359 KOps/s | 1.7635 KOps/s | |
test_vmap_mlp_speed_decorator[False-True] | 0.7609ms | 0.4691ms | 2.1317 KOps/s | 2.1426 KOps/s | |
test_vmap_mlp_speed_decorator[False-False] | 0.9249ms | 0.4692ms | 2.1315 KOps/s | 2.1399 KOps/s | |
test_to_module_speed[True] | 2.6298ms | 1.6774ms | 596.1729 Ops/s | 588.4042 Ops/s | |
test_to_module_speed[False] | 2.1811ms | 1.6386ms | 610.2804 Ops/s | 602.5468 Ops/s | |
test_tc_init | 75.7530μs | 31.4529μs | 31.7936 KOps/s | 43.7118 KOps/s | |
test_tc_init_nested | 0.1447ms | 64.1801μs | 15.5811 KOps/s | 21.5780 KOps/s | |
test_tc_first_layer_tensor | 7.3040μs | 0.7097μs | 1.4090 MOps/s | 1.4047 MOps/s | |
test_tc_first_layer_nontensor | 2.7095μs | 0.6797μs | 1.4712 MOps/s | 1.4693 MOps/s | |
test_tc_second_layer_tensor | 43.1110μs | 1.8425μs | 542.7319 KOps/s | 549.5050 KOps/s | |
test_tc_second_layer_nontensor | 19.4860μs | 1.6805μs | 595.0527 KOps/s | 601.7405 KOps/s | |
test_unbind | 93.4236ms | 7.7441ms | 129.1299 Ops/s | 140.9556 Ops/s | |
test_full_like | 16.0136ms | 11.2676ms | 88.7503 Ops/s | 91.9865 Ops/s | |
test_zeros_like | 11.7348ms | 5.6212ms | 177.8965 Ops/s | 172.9256 Ops/s | |
test_ones_like | 11.4117ms | 6.1943ms | 161.4393 Ops/s | 152.3032 Ops/s | |
test_clone | 13.5948ms | 7.8982ms | 126.6106 Ops/s | 124.8532 Ops/s | |
test_squeeze | 68.5590μs | 13.8578μs | 72.1615 KOps/s | 71.7936 KOps/s | |
test_unsqueeze | 0.1094ms | 60.2247μs | 16.6045 KOps/s | 16.7907 KOps/s | |
test_split | 0.1840ms | 0.1128ms | 8.8619 KOps/s | 8.8456 KOps/s | |
test_permute | 0.2034ms | 0.1275ms | 7.8419 KOps/s | 7.8272 KOps/s | |
test_stack | 27.9748ms | 22.4954ms | 44.4535 Ops/s | 42.8654 Ops/s | |
test_cat | 26.5585ms | 22.6047ms | 44.2386 Ops/s | 42.3218 Ops/s |

Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_plain_set_nested | 24.5610μs | 13.3638μs | 74.8290 KOps/s | 84.5508 KOps/s | |
test_plain_set_stack_nested | 29.7600μs | 13.4681μs | 74.2496 KOps/s | 82.3225 KOps/s | |
test_plain_set_nested_inplace | 37.2510μs | 14.7638μs | 67.7333 KOps/s | 75.4902 KOps/s | |
test_plain_set_stack_nested_inplace | 32.0910μs | 14.7780μs | 67.6681 KOps/s | 75.2133 KOps/s | |
test_items | 29.8500μs | 4.6983μs | 212.8443 KOps/s | 211.3241 KOps/s | |
test_items_nested | 0.3944ms | 0.3423ms | 2.9218 KOps/s | 2.9309 KOps/s | |
test_items_nested_locked | 0.3794ms | 0.3404ms | 2.9377 KOps/s | 2.9101 KOps/s | |
test_items_nested_leaf | 0.1119ms | 82.6052μs | 12.1058 KOps/s | 12.0719 KOps/s | |
test_items_stack_nested | 0.4276ms | 0.3403ms | 2.9384 KOps/s | 2.8715 KOps/s | |
test_items_stack_nested_leaf | 0.1173ms | 85.3348μs | 11.7186 KOps/s | 11.8771 KOps/s | |
test_items_stack_nested_locked | 0.3940ms | 0.3463ms | 2.8877 KOps/s | 2.8909 KOps/s | |
test_keys | 16.6210μs | 4.3754μs | 228.5485 KOps/s | 226.3028 KOps/s | |
test_keys_nested | 91.4420μs | 69.3572μs | 14.4181 KOps/s | 14.7772 KOps/s | |
test_keys_nested_locked | 0.6711ms | 74.7867μs | 13.3714 KOps/s | 13.4133 KOps/s | |
test_keys_nested_leaf | 83.3520μs | 59.8996μs | 16.6946 KOps/s | 16.8238 KOps/s | |
test_keys_stack_nested | 93.0320μs | 68.9363μs | 14.5061 KOps/s | 14.8035 KOps/s | |
test_keys_stack_nested_leaf | 85.7830μs | 60.0135μs | 16.6629 KOps/s | 17.3523 KOps/s | |
test_keys_stack_nested_locked | 0.1024ms | 75.0754μs | 13.3199 KOps/s | 13.5881 KOps/s | |
test_values | 13.6450μs | 1.8447μs | 542.0925 KOps/s | 556.1785 KOps/s | |
test_values_nested | 59.5610μs | 35.7670μs | 27.9587 KOps/s | 27.9353 KOps/s | |
test_values_nested_locked | 61.2420μs | 37.7276μs | 26.5058 KOps/s | 26.4395 KOps/s | |
test_values_nested_leaf | 50.6920μs | 31.7544μs | 31.4917 KOps/s | 31.2478 KOps/s | |
test_values_stack_nested | 65.1120μs | 36.6199μs | 27.3076 KOps/s | 27.1275 KOps/s | |
test_values_stack_nested_leaf | 54.2810μs | 32.6134μs | 30.6622 KOps/s | 30.5220 KOps/s | |
test_values_stack_nested_locked | 68.4510μs | 38.4566μs | 26.0034 KOps/s | 26.0369 KOps/s | |
test_membership | 4.7273μs | 0.7230μs | 1.3831 MOps/s | 1.3287 MOps/s | |
test_membership_nested | 20.4100μs | 2.6428μs | 378.3802 KOps/s | 382.1033 KOps/s | |
test_membership_nested_leaf | 34.7210μs | 2.6103μs | 383.0937 KOps/s | 383.3434 KOps/s | |
test_membership_stacked_nested | 59.9510μs | 2.6667μs | 374.9905 KOps/s | 379.8160 KOps/s | |
test_membership_stacked_nested_leaf | 21.9890μs | 2.6236μs | 381.1561 KOps/s | 380.5488 KOps/s | |
test_membership_nested_last | 34.9610μs | 3.1441μs | 318.0536 KOps/s | 316.5679 KOps/s | |
test_membership_nested_leaf_last | 23.1400μs | 3.1620μs | 316.2532 KOps/s | 315.9758 KOps/s | |
test_membership_stacked_nested_last | 34.8310μs | 3.1595μs | 316.5057 KOps/s | 278.8927 KOps/s | |
test_membership_stacked_nested_leaf_last | 18.1200μs | 3.1455μs | 317.9142 KOps/s | 279.1427 KOps/s | |
test_nested_getleaf | 37.0700μs | 8.3163μs | 120.2456 KOps/s | 119.6132 KOps/s | |
test_nested_get | 24.0710μs | 7.8862μs | 126.8043 KOps/s | 127.3757 KOps/s | |
test_stacked_getleaf | 72.2910μs | 8.3856μs | 119.2522 KOps/s | 118.8385 KOps/s | |
test_stacked_get | 39.3700μs | 7.8811μs | 126.8864 KOps/s | 126.8712 KOps/s | |
test_nested_getitemleaf | 24.1500μs | 8.5387μs | 117.1135 KOps/s | 115.9477 KOps/s | |
test_nested_getitem | 33.7110μs | 8.0305μs | 124.5257 KOps/s | 123.4512 KOps/s | |
test_stacked_getitemleaf | 23.9210μs | 8.5638μs | 116.7710 KOps/s | 116.3016 KOps/s | |
test_stacked_getitem | 36.2310μs | 8.0535μs | 124.1697 KOps/s | 123.8736 KOps/s | |
test_lock_nested | 58.7555ms | 0.4201ms | 2.3805 KOps/s | 2.4530 KOps/s | |
test_lock_stack_nested | 0.3415ms | 0.3152ms | 3.1730 KOps/s | 3.2463 KOps/s | |
test_unlock_nested | 61.3299ms | 0.4191ms | 2.3861 KOps/s | 2.4391 KOps/s | |
test_unlock_stack_nested | 0.3611ms | 0.3229ms | 3.0966 KOps/s | 3.1688 KOps/s | |
test_flatten_speed | 0.4375ms | 0.1026ms | 9.7441 KOps/s | 9.7677 KOps/s | |
test_unflatten_speed | 0.3419ms | 0.2987ms | 3.3478 KOps/s | 3.3510 KOps/s | |
test_common_ops | 1.1731ms | 0.6285ms | 1.5911 KOps/s | 1.7706 KOps/s | |
test_creation | 37.0500μs | 1.6352μs | 611.5478 KOps/s | 601.1347 KOps/s | |
test_creation_empty | 26.4300μs | 9.5200μs | 105.0425 KOps/s | 146.6825 KOps/s | |
test_creation_nested_1 | 29.5800μs | 11.4262μs | 87.5185 KOps/s | 114.4068 KOps/s | |
test_creation_nested_2 | 35.9300μs | 13.6895μs | 73.0488 KOps/s | 91.2872 KOps/s | |
test_clone | 0.1157ms | 12.4712μs | 80.1848 KOps/s | 82.2242 KOps/s | |
test_getitem[int] | 30.1010μs | 11.6715μs | 85.6789 KOps/s | 89.8068 KOps/s | |
test_getitem[slice_int] | 52.4610μs | 22.4516μs | 44.5403 KOps/s | 46.1356 KOps/s | |
test_getitem[range] | 89.3320μs | 55.1968μs | 18.1170 KOps/s | 20.7828 KOps/s | |
test_getitem[tuple] | 38.6810μs | 19.9925μs | 50.0187 KOps/s | 51.4897 KOps/s | |
test_getitem[list] | 0.1122ms | 36.1518μs | 27.6612 KOps/s | 29.9965 KOps/s | |
test_setitem_dim[int] | 47.3210μs | 31.3387μs | 31.9094 KOps/s | 37.7022 KOps/s | |
test_setitem_dim[slice_int] | 73.6910μs | 53.1686μs | 18.8081 KOps/s | 21.0604 KOps/s | |
test_setitem_dim[range] | 89.3720μs | 69.9720μs | 14.2914 KOps/s | 15.3349 KOps/s | |
test_setitem_dim[tuple] | 66.8020μs | 46.1338μs | 21.6761 KOps/s | 24.0112 KOps/s | |
test_setitem | 42.8410μs | 18.1301μs | 55.1568 KOps/s | 62.1054 KOps/s | |
test_set | 50.9820μs | 17.5432μs | 57.0021 KOps/s | 64.6057 KOps/s | |
test_set_shared | 1.6765ms | 0.1001ms | 9.9940 KOps/s | 10.0862 KOps/s | |
test_update | 90.5030μs | 20.8117μs | 48.0500 KOps/s | 57.7072 KOps/s | |
test_update_nested | 77.1810μs | 26.5951μs | 37.6009 KOps/s | 44.1831 KOps/s | |
test_update__nested | 70.9310μs | 23.6050μs | 42.3639 KOps/s | 43.4563 KOps/s | |
test_set_nested | 66.3910μs | 18.9760μs | 52.6982 KOps/s | 59.7658 KOps/s | |
test_set_nested_new | 68.7420μs | 21.9009μs | 45.6602 KOps/s | 50.4502 KOps/s | |
test_select | 86.7310μs | 35.3916μs | 28.2553 KOps/s | 30.6814 KOps/s | |
test_select_nested | 0.6659ms | 56.8931μs | 17.5768 KOps/s | 17.8547 KOps/s | |
test_exclude_nested | 0.1801ms | 0.1112ms | 8.9953 KOps/s | 8.9227 KOps/s | |
test_empty[True] | 0.3947ms | 0.3502ms | 2.8555 KOps/s | 2.8473 KOps/s | |
test_empty[False] | 3.0280μs | 1.0154μs | 984.8386 KOps/s | 972.2961 KOps/s | |
test_to | 0.1042ms | 78.7360μs | 12.7007 KOps/s | 13.1643 KOps/s | |
test_to_nonblocking | 95.0420μs | 64.5730μs | 15.4864 KOps/s | 16.1539 KOps/s | |
test_unbind_speed | 0.3181ms | 0.2777ms | 3.6007 KOps/s | 3.7232 KOps/s | |
test_unbind_speed_stack0 | 0.3486ms | 0.2759ms | 3.6251 KOps/s | 3.7099 KOps/s | |
test_unbind_speed_stack1 | 75.4396ms | 0.8238ms | 1.2139 KOps/s | 1.2162 KOps/s | |
test_split | 75.3590ms | 1.8083ms | 552.9999 Ops/s | 571.0213 Ops/s | |
test_chunk | 1.7768ms | 1.6801ms | 595.1922 Ops/s | 617.3525 Ops/s | |
test_creation[device0] | 0.1275ms | 59.7726μs | 16.7301 KOps/s | 17.2265 KOps/s | |
test_creation_from_tensor | 0.1610ms | 55.3790μs | 18.0574 KOps/s | 17.6264 KOps/s | |
test_add_one[memmap_tensor0] | 76.9920μs | 7.3479μs | 136.0935 KOps/s | 132.0483 KOps/s | |
test_contiguous[memmap_tensor0] | 9.6600μs | 0.6552μs | 1.5262 MOps/s | 1.5219 MOps/s | |
test_stack[memmap_tensor0] | 31.9410μs | 5.5800μs | 179.2129 KOps/s | 194.1153 KOps/s | |
test_memmaptd_index | 1.1583ms | 0.3145ms | 3.1801 KOps/s | 3.3419 KOps/s | |
test_memmaptd_index_astensor | 0.6540ms | 0.3881ms | 2.5766 KOps/s | 2.4771 KOps/s | |
test_memmaptd_index_op | 1.1542ms | 0.7253ms | 1.3787 KOps/s | 1.5324 KOps/s | |
test_serialize_model | 0.1744s | 0.1040s | 9.6135 Ops/s | 8.5968 Ops/s | |
test_serialize_model_pickle | 1.3496s | 1.2351s | 0.8097 Ops/s | 0.8067 Ops/s | |
test_serialize_weights | 0.1724s | 0.1020s | 9.8009 Ops/s | 9.5686 Ops/s | |
test_serialize_weights_returnearly | 87.7731ms | 72.8831ms | 13.7206 Ops/s | 12.0681 Ops/s | |
test_serialize_weights_pickle | 1.3481s | 1.1698s | 0.8549 Ops/s | 0.8009 Ops/s | |
test_reshape_pytree | 50.0810μs | 27.2832μs | 36.6526 KOps/s | 37.5124 KOps/s | |
test_reshape_td | 0.2341ms | 33.1579μs | 30.1588 KOps/s | 31.1347 KOps/s | |
test_view_pytree | 52.8810μs | 27.1391μs | 36.8471 KOps/s | 38.1981 KOps/s | |
test_view_td | 0.2300ms | 38.3395μs | 26.0827 KOps/s | 25.0422 KOps/s | |
test_unbind_pytree | 57.8710μs | 32.8436μs | 30.4473 KOps/s | 30.7080 KOps/s | |
test_unbind_td | 0.4554ms | 43.3968μs | 23.0432 KOps/s | 23.6403 KOps/s | |
test_split_pytree | 63.9420μs | 36.5393μs | 27.3678 KOps/s | 28.1274 KOps/s | |
test_split_td | 0.2425ms | 42.6505μs | 23.4464 KOps/s | 24.3631 KOps/s | |
test_add_pytree | 79.0320μs | 39.3767μs | 25.3957 KOps/s | 25.4823 KOps/s | |
test_add_td | 0.3738ms | 53.6881μs | 18.6261 KOps/s | 21.5144 KOps/s | |
test_distributed | 1.5195ms | 72.8684μs | 13.7234 KOps/s | 14.7186 KOps/s | |
test_tdmodule | 30.2600μs | 15.5336μs | 64.3766 KOps/s | 71.0025 KOps/s | |
test_tdmodule_dispatch | 47.7220μs | 30.8669μs | 32.3971 KOps/s | 36.9315 KOps/s | |
test_tdseq | 31.8710μs | 17.0884μs | 58.5191 KOps/s | 64.2854 KOps/s | |
test_tdseq_dispatch | 49.6610μs | 33.3036μs | 30.0267 KOps/s | 33.0341 KOps/s | |
test_instantiation_functorch | 1.5140ms | 1.4380ms | 695.4272 Ops/s | 697.2978 Ops/s | |
test_instantiation_td | 1.4969ms | 0.9991ms | 1.0009 KOps/s | 998.1144 Ops/s | |
test_exec_functorch | 0.2272ms | 0.1526ms | 6.5510 KOps/s | 6.5879 KOps/s | |
test_exec_functional_call | 0.2042ms | 0.1476ms | 6.7770 KOps/s | 7.0866 KOps/s | |
test_exec_td | 0.1834ms | 0.1445ms | 6.9210 KOps/s | 7.1162 KOps/s | |
test_exec_td_decorator | 0.4047ms | 0.2159ms | 4.6309 KOps/s | 4.5915 KOps/s | |
test_vmap_mlp_speed[True-True] | 1.3955ms | 0.6014ms | 1.6629 KOps/s | 1.7180 KOps/s | |
test_vmap_mlp_speed[True-False] | 0.7211ms | 0.5874ms | 1.7025 KOps/s | 1.7259 KOps/s | |
test_vmap_mlp_speed[False-True] | 0.6398ms | 0.5315ms | 1.8816 KOps/s | 1.9459 KOps/s | |
test_vmap_mlp_speed[False-False] | 0.6130ms | 0.5303ms | 1.8857 KOps/s | 1.9475 KOps/s | |
test_vmap_mlp_speed_decorator[True-True] | 1.0169ms | 0.6492ms | 1.5403 KOps/s | 1.5024 KOps/s | |
test_vmap_mlp_speed_decorator[True-False] | 0.7522ms | 0.6457ms | 1.5487 KOps/s | 1.5464 KOps/s | |
test_vmap_mlp_speed_decorator[False-True] | 0.7397ms | 0.5678ms | 1.7613 KOps/s | 1.7444 KOps/s | |
test_vmap_mlp_speed_decorator[False-False] | 0.6843ms | 0.5665ms | 1.7654 KOps/s | 1.7495 KOps/s | |
test_vmap_transformer_speed[True-True] | 7.8796ms | 7.7120ms | 129.6688 Ops/s | 124.2438 Ops/s | |
test_vmap_transformer_speed[True-False] | 8.1541ms | 7.7593ms | 128.8780 Ops/s | 126.1784 Ops/s | |
test_vmap_transformer_speed[False-True] | 8.0496ms | 7.7658ms | 128.7697 Ops/s | 127.1041 Ops/s | |
test_vmap_transformer_speed[False-False] | 7.7065ms | 7.6234ms | 131.1754 Ops/s | 126.5953 Ops/s | |
test_vmap_transformer_speed_decorator[True-True] | 18.7108ms | 18.6302ms | 53.6762 Ops/s | 51.9538 Ops/s | |
test_vmap_transformer_speed_decorator[True-False] | 18.7658ms | 18.6493ms | 53.6212 Ops/s | 51.9454 Ops/s | |
test_vmap_transformer_speed_decorator[False-True] | 19.3984ms | 18.6391ms | 53.6505 Ops/s | 52.2029 Ops/s | |
test_vmap_transformer_speed_decorator[False-False] | 19.3457ms | 18.6760ms | 53.5448 Ops/s | 52.2122 Ops/s | |
test_to_module_speed[True] | 2.7817ms | 1.5195ms | 658.0960 Ops/s | 644.4116 Ops/s | |
test_to_module_speed[False] | 1.6257ms | 1.5063ms | 663.8648 Ops/s | 656.7854 Ops/s | |
test_tc_init | 51.4710μs | 26.8346μs | 37.2653 KOps/s | 46.7276 KOps/s | |
test_tc_init_nested | 90.6430μs | 54.5892μs | 18.3186 KOps/s | 22.2731 KOps/s | |
test_tc_first_layer_tensor | 0.7568μs | 0.3585μs | 2.7894 MOps/s | 2.7672 MOps/s | |
test_tc_first_layer_nontensor | 1.4200μs | 0.3864μs | 2.5882 MOps/s | 2.5488 MOps/s | |
test_tc_second_layer_tensor | 4.0262μs | 1.0205μs | 979.9558 KOps/s | 881.8082 KOps/s | |
test_tc_second_layer_nontensor | 4.6337μs | 0.8491μs | 1.1777 MOps/s | 1.1429 MOps/s | |
test_unbind | 0.1044s | 6.4875ms | 154.1438 Ops/s | 198.5695 Ops/s | |
test_full_like | 13.7326ms | 13.1930ms | 75.7980 Ops/s | 75.1571 Ops/s | |
test_zeros_like | 8.2989ms | 7.9172ms | 126.3080 Ops/s | 126.7885 Ops/s | |
test_ones_like | 8.3215ms | 7.8906ms | 126.7326 Ops/s | 125.8928 Ops/s | |
test_clone | 9.5156ms | 9.2769ms | 107.7947 Ops/s | 108.5138 Ops/s | |
test_squeeze | 96.5720μs | 10.8372μs | 92.2747 KOps/s | 90.0511 KOps/s | |
test_unsqueeze | 96.6620μs | 54.1298μs | 18.4741 KOps/s | 19.2624 KOps/s | |
test_split | 0.1411ms | 0.1018ms | 9.8207 KOps/s | 10.0716 KOps/s | |
test_permute | 0.1483ms | 0.1136ms | 8.8058 KOps/s | 8.8297 KOps/s | |
test_stack | 27.1761ms | 26.9349ms | 37.1266 Ops/s | 37.4528 Ops/s | |
test_cat | 28.0263ms | 26.8492ms | 37.2450 Ops/s | 37.5766 Ops/s |
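The Change column is empty in both tables as reproduced here. Assuming it compared **Ops** against **Ops on Repo HEAD** (higher throughput is better), it can be recomputed per row as sketched below. The two Ops columns sometimes use different units (e.g. `1.0009 KOps/s` vs `998.1144 Ops/s` for `test_instantiation_td`), so values are normalized first. `parse_ops`, `relative_change` and the unit table are hypothetical helpers, not part of the benchmark tooling.

```python
from typing import Tuple

# Assumed scale factors for the throughput units that appear in the tables.
UNITS = {"Ops/s": 1.0, "KOps/s": 1e3, "MOps/s": 1e6}


def parse_ops(cell: str) -> float:
    """Convert a cell such as '87.5562 KOps/s' to operations per second."""
    value, unit = cell.split()
    return float(value) * UNITS[unit]


def relative_change(row: str) -> Tuple[str, float]:
    """Percent change of 'Ops' relative to 'Ops on Repo HEAD' for one table row.

    Positive means this branch has higher throughput than HEAD.
    """
    cells = [c.strip() for c in row.split("|")]
    name, ops, ops_head = cells[0], parse_ops(cells[3]), parse_ops(cells[4])
    return name, 100.0 * (ops - ops_head) / ops_head


row = "test_creation_empty | 30.2570μs | 11.4212μs | 87.5562 KOps/s | 123.9516 KOps/s | |"
name, change = relative_change(row)
print(f"{name}: {change:+.1f}%")  # test_creation_empty: -29.4%
```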
Description
Consolidates the storage of a tensordict into a single storage.
This is aimed at making serialization of a tensordict faster with mmap and pickle.
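A minimal sketch of what packing every leaf into one storage involves, assuming a flat `dict` of CPU tensors rather than a nested TensorDict. `consolidate_tensors` and `unconsolidate` are hypothetical helpers, not the `TensorDict.consolidate` implementation. Each leaf is copied as raw bytes into one `uint8` buffer, padded so that every leaf starts at a multiple of its element size (a requirement for viewing bytes as a wider dtype). The recorded offsets, dtypes and shapes are then enough to rebuild every leaf as a zero-copy view, so a single buffer is all that needs to be pickled or mapped.

```python
import torch


def consolidate_tensors(tensors):
    """Copy every tensor into one flat uint8 buffer.

    Returns the buffer plus, per key, (byte offset, nbytes, dtype, shape).
    Assumes all tensors are on the same device (CPU here).
    """
    chunks, metadata, offset = [], {}, 0
    for key, t in tensors.items():
        # Viewing uint8 bytes as a wider dtype later requires the byte offset
        # to be a multiple of that dtype's element size, so pad if needed.
        pad = -offset % t.element_size()
        if pad:
            chunks.append(torch.zeros(pad, dtype=torch.uint8))
            offset += pad
        # Row-major values reinterpreted as raw bytes; torch.cat copies them below.
        raw = t.detach().contiguous().reshape(-1).view(torch.uint8)
        metadata[key] = (offset, raw.numel(), t.dtype, t.shape)
        chunks.append(raw)
        offset += raw.numel()
    return torch.cat(chunks), metadata


def unconsolidate(buffer, metadata):
    """Rebuild each tensor as a view into the shared buffer (no copy)."""
    return {
        key: buffer.narrow(0, start, nbytes).view(dtype).view(shape)
        for key, (start, nbytes, dtype, shape) in metadata.items()
    }


leaves = {
    "obs": torch.randn(4, 3),
    "action": torch.arange(4),
    "reward": torch.zeros(4, 1, dtype=torch.float64),
}
buffer, metadata = consolidate_tensors(leaves)
rebuilt = unconsolidate(buffer, metadata)
```

Every tensor in `rebuilt` shares `buffer`'s storage, which is the property that lets one file (or one pickled storage) stand in for the whole structure.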
Micro-benchmarks: `num_threads` doesn't seem to help on devgpu (but there is some potential benefit when working with DTensors!)
This benchmark is a bit more realistic.
https://gist.github.com/vmoens/205f5173b3cba297389916655d07d1ce
The outcome is that the fastest serialization of the model is `tensordict.memmap_(filepath, num_threads=8)`, with a 150 ms execution time. The second fastest is `td.consolidate(filename=filepath)`. Unfortunately, `torch.save` performs equally poorly (about 450-500 ms) for both state-dicts and consolidated TDs in this example.
cc @shagunsodhani @albanD @mikaylagawarecki @dstaay-fb @jsuarez5341 @dtsaras @jbschlosser @FrancescoSaverioZuppichini