From f9ef888acf253a89235e0ff3b5b228b9b4001164 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 5 Jul 2024 13:25:19 +0100
Subject: [PATCH] [BugFix] Fix key ordering in pointwise ops (#855)

---
 tensordict/base.py      | 109 ++++++++++++++++++++++++--------
 test/test_tensordict.py |   9 ++++
 2 files changed, 80 insertions(+), 38 deletions(-)

diff --git a/tensordict/base.py b/tensordict/base.py
index 124815a95..ccea68481 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -4689,14 +4689,28 @@ def _values_list(
         self,
         include_nested: bool = False,
         leaves_only: bool = False,
+        *,
+        collapse: bool = False,
+        is_leaf: Callable[[Type], bool] | None = None,
+        sorting_keys: List[NestedKey] | None = None,
     ) -> List:
-        return list(
-            self.values(
+        if sorting_keys is None:
+            return list(
+                self.values(
+                    include_nested=include_nested,
+                    leaves_only=leaves_only,
+                    is_leaf=_NESTED_TENSORS_AS_LISTS if not collapse else is_leaf,
+                )
+            )
+        else:
+            keys, vals = self._items_list(
                 include_nested=include_nested,
                 leaves_only=leaves_only,
-                is_leaf=_NESTED_TENSORS_AS_LISTS,
+                is_leaf=is_leaf,
+                collapse=collapse,
             )
-        )
+            source = dict(zip(keys, vals))
+            return [source[key] for key in sorting_keys]
 
     @cache  # noqa: B019
     def _items_list(
@@ -7026,7 +7040,7 @@ def cosh_(self) -> T:
     def add(self, other: TensorDictBase | float, alpha: float | None = None):
         keys, vals = self._items_list(True, True)
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
             other_val = other
         if alpha is not None:
@@ -7044,13 +7058,15 @@ def add(self, other: TensorDictBase | float, alpha: float | None = None):
 
     def add_(self, other: TensorDictBase | float, alpha: float | None = None):
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            keys, val = self._items_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
+            val = self._values_list(True, True)
             other_val = other
         if alpha is not None:
-            torch._foreach_add_(self._values_list(True, True), other_val, alpha=alpha)
+            torch._foreach_add_(val, other_val, alpha=alpha)
         else:
-            torch._foreach_add_(self._values_list(True, True), other_val)
+            torch._foreach_add_(val, other_val)
         return self
 
     def lerp(self, end: TensorDictBase | float, weight: TensorDictBase | float):
@@ -7154,15 +7170,16 @@ def addcmul_(self, other1, other2, value: float | None = 1):
         return self
 
     def sub(self, other: TensorDictBase | float, alpha: float | None = None):
-        keys, vals = self._items_list(True, True)
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            keys, val = self._items_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
+            val = self._values_list(True, True)
             other_val = other
         if alpha is not None:
-            vals = torch._foreach_sub(vals, other_val, alpha=alpha)
+            vals = torch._foreach_sub(val, other_val, alpha=alpha)
         else:
-            vals = torch._foreach_sub(vals, other_val)
+            vals = torch._foreach_sub(val, other_val)
         items = dict(zip(keys, vals))
         return self._fast_apply(
             lambda name, val: items.get(name, val),
@@ -7174,30 +7191,34 @@ def sub(self, other: TensorDictBase | float, alpha: float | None = None):
 
     def sub_(self, other: TensorDictBase | float, alpha: float | None = None):
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            keys, val = self._items_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
+            val = self._values_list(True, True)
             other_val = other
         if alpha is not None:
-            torch._foreach_sub_(self._values_list(True, True), other_val, alpha=alpha)
+            torch._foreach_sub_(val, other_val, alpha=alpha)
         else:
-            torch._foreach_sub_(self._values_list(True, True), other_val)
+            torch._foreach_sub_(val, other_val)
         return self
 
     def mul_(self, other: TensorDictBase | float) -> T:
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            keys, val = self._items_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
+            val = self._values_list(True, True)
             other_val = other
-        torch._foreach_mul_(self._values_list(True, True), other_val)
+        torch._foreach_mul_(val, other_val)
         return self
 
     def mul(self, other: TensorDictBase | float) -> T:
-        keys, vals = self._items_list(True, True)
+        keys, val = self._items_list(True, True)
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
             other_val = other
-        vals = torch._foreach_mul(vals, other_val)
+        vals = torch._foreach_mul(val, other_val)
         items = dict(zip(keys, vals))
         return self._fast_apply(
             lambda name, val: items.get(name, val),
@@ -7209,16 +7230,18 @@ def mul(self, other: TensorDictBase | float) -> T:
 
     def maximum_(self, other: TensorDictBase | float) -> T:
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            keys, val = self._items_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
+            val = self._values_list(True, True)
             other_val = other
-        torch._foreach_maximum_(self._values_list(True, True), other_val)
+        torch._foreach_maximum_(val, other_val)
         return self
 
     def maximum(self, other: TensorDictBase | float) -> T:
         keys, vals = self._items_list(True, True)
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
             other_val = other
         vals = torch._foreach_maximum(vals, other_val)
@@ -7233,16 +7256,18 @@ def maximum(self, other: TensorDictBase | float) -> T:
 
     def minimum_(self, other: TensorDictBase | float) -> T:
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            keys, val = self._items_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
+            val = self._values_list(True, True)
             other_val = other
-        torch._foreach_minimum_(self._values_list(True, True), other_val)
+        torch._foreach_minimum_(val, other_val)
         return self
 
     def minimum(self, other: TensorDictBase | float) -> T:
         keys, vals = self._items_list(True, True)
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
             other_val = other
         vals = torch._foreach_minimum(vals, other_val)
@@ -7257,16 +7282,18 @@ def minimum(self, other: TensorDictBase | float) -> T:
 
     def clamp_max_(self, other: TensorDictBase | float) -> T:
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            keys, val = self._items_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
+            val = self._values_list(True, True)
             other_val = other
-        torch._foreach_clamp_max_(self._values_list(True, True), other_val)
+        torch._foreach_clamp_max_(val, other_val)
         return self
 
     def clamp_max(self, other: TensorDictBase | float) -> T:
         keys, vals = self._items_list(True, True)
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
             other_val = other
         vals = torch._foreach_clamp_max(vals, other_val)
@@ -7281,16 +7308,18 @@ def clamp_max(self, other: TensorDictBase | float) -> T:
 
     def clamp_min_(self, other: TensorDictBase | float) -> T:
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            keys, val = self._items_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
+            val = self._values_list(True, True)
             other_val = other
-        torch._foreach_clamp_min_(self._values_list(True, True), other_val)
+        torch._foreach_clamp_min_(val, other_val)
         return self
 
     def clamp_min(self, other: TensorDictBase | float) -> T:
         keys, vals = self._items_list(True, True)
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
             other_val = other
         vals = torch._foreach_clamp_min(vals, other_val)
@@ -7305,16 +7334,18 @@ def clamp_min(self, other: TensorDictBase | float) -> T:
 
     def pow_(self, other: TensorDictBase | float) -> T:
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            keys, val = self._items_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
+            val = self._values_list(True, True)
             other_val = other
-        torch._foreach_pow_(self._values_list(True, True), other_val)
+        torch._foreach_pow_(val, other_val)
         return self
 
     def pow(self, other: TensorDictBase | float) -> T:
         keys, vals = self._items_list(True, True)
        if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
             other_val = other
         vals = torch._foreach_pow(vals, other_val)
@@ -7329,16 +7360,18 @@ def pow(self, other: TensorDictBase | float) -> T:
 
     def div_(self, other: TensorDictBase | float) -> T:
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            keys, val = self._items_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
+            val = self._values_list(True, True)
             other_val = other
-        torch._foreach_div_(self._values_list(True, True), other_val)
+        torch._foreach_div_(val, other_val)
        return self
 
     def div(self, other: TensorDictBase | float) -> T:
         keys, vals = self._items_list(True, True)
         if _is_tensor_collection(type(other)):
-            other_val = other._values_list(True, True)
+            other_val = other._values_list(True, True, sorting_keys=keys)
         else:
             other_val = other
         vals = torch._foreach_div(vals, other_val)
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index b12517abd..f4ff0ce77 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -2437,6 +2437,15 @@ def dummy_td_1(self):
     def dummy_td_2(self):
         return self.dummy_td_0.apply(lambda x: x + 2)
 
+    def test_ordering(self):
+
+        x0 = TensorDict({"y": torch.zeros(3), "x": torch.ones(3)})
+
+        x1 = TensorDict({"x": torch.ones(3), "y": torch.zeros(3)})
+        assert ((x0 + x1)["x"] == 2).all()
+        assert ((x0 * x1)["x"] == 1).all()
+        assert ((x0 - x1)["x"] == 0).all()
+
     @pytest.mark.parametrize("locked", [True, False])
     def test_add(self, locked):
         td = self.dummy_td_0
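
The bug addressed here: the pointwise ops (add, sub, mul, maximum, minimum,
clamp_max, clamp_min, pow, div and their in-place variants) flatten both
operands into plain value lists for the torch._foreach_* calls, and before
this patch the two lists were paired by iteration position rather than by
key. Two tensordicts holding the same keys in a different insertion order
would therefore combine values under the wrong keys. A minimal sketch of the
now-correct behaviour, mirroring the new test_ordering test (assumes torch
and tensordict are importable):

    import torch
    from tensordict import TensorDict

    # Same keys, different insertion order.
    x0 = TensorDict({"y": torch.zeros(3), "x": torch.ones(3)})
    x1 = TensorDict({"x": torch.ones(3), "y": torch.zeros(3)})

    # With the fix, "x" pairs with "x" and "y" with "y".
    assert ((x0 + x1)["x"] == 2).all()  # 1 + 1
    assert ((x0 * x1)["x"] == 1).all()  # 1 * 1
    assert ((x0 - x1)["x"] == 0).all()  # 1 - 1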
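
The mechanical fix lives in _values_list: when a sorting_keys list is passed,
the values of other are re-ordered to follow the caller's key order before
being handed to the foreach op. A standalone sketch of that re-indexing step
(values_sorted_by is a hypothetical name; in the patch the logic is inlined
in TensorDictBase._values_list):

    # Hypothetical helper reproducing the sorting_keys branch of _values_list.
    def values_sorted_by(keys, vals, sorting_keys):
        # Map each key to its value, then emit values in the requested order.
        source = dict(zip(keys, vals))
        return [source[key] for key in sorting_keys]

    # other stores ("y", "x") while self iterates ("x", "y"):
    assert values_sorted_by(["y", "x"], [0.0, 1.0], ["x", "y"]) == [1.0, 0.0]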