Skip to content

Commit

Permalink
[BugFix] Fix key ordering in pointwise ops (#855)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 5, 2024
1 parent 60d8a61 commit f9ef888
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 38 deletions.
109 changes: 71 additions & 38 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4689,14 +4689,28 @@ def _values_list(
self,
include_nested: bool = False,
leaves_only: bool = False,
*,
collapse: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
sorting_keys: List[NestedKey] | None = None,
) -> List:
return list(
self.values(
if sorting_keys is None:
return list(
self.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=_NESTED_TENSORS_AS_LISTS if not collapse else is_leaf,
)
)
else:
keys, vals = self._items_list(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=_NESTED_TENSORS_AS_LISTS,
is_leaf=is_leaf,
collapse=collapse,
)
)
source = dict(zip(keys, vals))
return [source[key] for key in sorting_keys]

@cache # noqa: B019
def _items_list(
Expand Down Expand Up @@ -7026,7 +7040,7 @@ def cosh_(self) -> T:
def add(self, other: TensorDictBase | float, alpha: float | None = None):
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
other_val = other
if alpha is not None:
Expand All @@ -7044,13 +7058,15 @@ def add(self, other: TensorDictBase | float, alpha: float | None = None):

def add_(self, other: TensorDictBase | float, alpha: float | None = None):
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
if alpha is not None:
torch._foreach_add_(self._values_list(True, True), other_val, alpha=alpha)
torch._foreach_add_(val, other_val, alpha=alpha)
else:
torch._foreach_add_(self._values_list(True, True), other_val)
torch._foreach_add_(val, other_val)
return self

def lerp(self, end: TensorDictBase | float, weight: TensorDictBase | float):
Expand Down Expand Up @@ -7154,15 +7170,16 @@ def addcmul_(self, other1, other2, value: float | None = 1):
return self

def sub(self, other: TensorDictBase | float, alpha: float | None = None):
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
if alpha is not None:
vals = torch._foreach_sub(vals, other_val, alpha=alpha)
vals = torch._foreach_sub(val, other_val, alpha=alpha)
else:
vals = torch._foreach_sub(vals, other_val)
vals = torch._foreach_sub(val, other_val)
items = dict(zip(keys, vals))
return self._fast_apply(
lambda name, val: items.get(name, val),
Expand All @@ -7174,30 +7191,34 @@ def sub(self, other: TensorDictBase | float, alpha: float | None = None):

def sub_(self, other: TensorDictBase | float, alpha: float | None = None):
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
if alpha is not None:
torch._foreach_sub_(self._values_list(True, True), other_val, alpha=alpha)
torch._foreach_sub_(val, other_val, alpha=alpha)
else:
torch._foreach_sub_(self._values_list(True, True), other_val)
torch._foreach_sub_(val, other_val)
return self

def mul_(self, other: TensorDictBase | float) -> T:
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
torch._foreach_mul_(self._values_list(True, True), other_val)
torch._foreach_mul_(val, other_val)
return self

def mul(self, other: TensorDictBase | float) -> T:
keys, vals = self._items_list(True, True)
keys, val = self._items_list(True, True)
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
other_val = other
vals = torch._foreach_mul(vals, other_val)
vals = torch._foreach_mul(val, other_val)
items = dict(zip(keys, vals))
return self._fast_apply(
lambda name, val: items.get(name, val),
Expand All @@ -7209,16 +7230,18 @@ def mul(self, other: TensorDictBase | float) -> T:

def maximum_(self, other: TensorDictBase | float) -> T:
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
torch._foreach_maximum_(self._values_list(True, True), other_val)
torch._foreach_maximum_(val, other_val)
return self

def maximum(self, other: TensorDictBase | float) -> T:
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
other_val = other
vals = torch._foreach_maximum(vals, other_val)
Expand All @@ -7233,16 +7256,18 @@ def maximum(self, other: TensorDictBase | float) -> T:

def minimum_(self, other: TensorDictBase | float) -> T:
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
torch._foreach_minimum_(self._values_list(True, True), other_val)
torch._foreach_minimum_(val, other_val)
return self

def minimum(self, other: TensorDictBase | float) -> T:
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
other_val = other
vals = torch._foreach_minimum(vals, other_val)
Expand All @@ -7257,16 +7282,18 @@ def minimum(self, other: TensorDictBase | float) -> T:

def clamp_max_(self, other: TensorDictBase | float) -> T:
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
torch._foreach_clamp_max_(self._values_list(True, True), other_val)
torch._foreach_clamp_max_(val, other_val)
return self

def clamp_max(self, other: TensorDictBase | float) -> T:
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
other_val = other
vals = torch._foreach_clamp_max(vals, other_val)
Expand All @@ -7281,16 +7308,18 @@ def clamp_max(self, other: TensorDictBase | float) -> T:

def clamp_min_(self, other: TensorDictBase | float) -> T:
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
torch._foreach_clamp_min_(self._values_list(True, True), other_val)
torch._foreach_clamp_min_(val, other_val)
return self

def clamp_min(self, other: TensorDictBase | float) -> T:
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
other_val = other
vals = torch._foreach_clamp_min(vals, other_val)
Expand All @@ -7305,16 +7334,18 @@ def clamp_min(self, other: TensorDictBase | float) -> T:

def pow_(self, other: TensorDictBase | float) -> T:
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
torch._foreach_pow_(self._values_list(True, True), other_val)
torch._foreach_pow_(val, other_val)
return self

def pow(self, other: TensorDictBase | float) -> T:
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
other_val = other
vals = torch._foreach_pow(vals, other_val)
Expand All @@ -7329,16 +7360,18 @@ def pow(self, other: TensorDictBase | float) -> T:

def div_(self, other: TensorDictBase | float) -> T:
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
keys, val = self._items_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
val = self._values_list(True, True)
other_val = other
torch._foreach_div_(self._values_list(True, True), other_val)
torch._foreach_div_(val, other_val)
return self

def div(self, other: TensorDictBase | float) -> T:
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
other_val = other._values_list(True, True)
other_val = other._values_list(True, True, sorting_keys=keys)
else:
other_val = other
vals = torch._foreach_div(vals, other_val)
Expand Down
9 changes: 9 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2437,6 +2437,15 @@ def dummy_td_1(self):
def dummy_td_2(self):
return self.dummy_td_0.apply(lambda x: x + 2)

def test_ordering(self):
    """Regression test for the key-ordering fix in pointwise ops (#855).

    Builds two TensorDicts holding identical data but with keys inserted
    in opposite order ("y" before "x" vs "x" before "y"), then checks that
    binary pointwise ops pair values by key rather than by position.
    """

    # Keys inserted as ("y", "x"): y = zeros, x = ones.
    x0 = TensorDict({"y": torch.zeros(3), "x": torch.ones(3)})

    # Same data, keys inserted as ("x", "y").
    x1 = TensorDict({"x": torch.ones(3), "y": torch.zeros(3)})
    # If values were paired positionally instead of by key, x0's "y"
    # (zeros) would be combined with x1's "x" (ones) and these would fail.
    assert ((x0 + x1)["x"] == 2).all()
    assert ((x0 * x1)["x"] == 1).all()
    assert ((x0 - x1)["x"] == 0).all()

@pytest.mark.parametrize("locked", [True, False])
def test_add(self, locked):
td = self.dummy_td_0
Expand Down

0 comments on commit f9ef888

Please sign in to comment.