
[BugFix] computation of log prob in composite distribution for batched samples #1054

Open · wants to merge 1 commit into main
Conversation

@thomasbbrunner commented Oct 22, 2024

Description

As far as I can tell, there's a bug in the latest version of the CompositeDistribution when calculating the log prob.

Motivation and Context

The log prob is erroneously flattened according to the sample's ndim, i.e. the ndim of the batch shape at the root level of the sample tensordict.

However, the batch shape of the sample's root does not always match the batch shape of the distribution itself!

This is the case, for instance, in a multi-agent setup.

Let's take an environment with two grouped agents and a policy with two categorical heads as an example. The tensordict in this case looks something like (edited for clarity):

>>> td = policy_module(env.reset())
TensorDict(
    fields={
        agents: TensorDict(
            fields={
                action: TensorDict(
                    fields={
                        head_0: TensorDict(
                            fields={
                                action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
                                action_log_prob: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),),
                        head_1: TensorDict(
                            fields={
                                action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
                                action_log_prob: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),)},
                    batch_size=torch.Size([2]),),
                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 70]), device=cpu, dtype=torch.float32, is_shared=False),
                params: TensorDict(
                    fields={
                        head_0: TensorDict(
                            fields={
                                logits: Tensor(shape=torch.Size([2, 9]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),),
                        head_1: TensorDict(
                            fields={
                                logits: Tensor(shape=torch.Size([2, 9]), device=cpu, dtype=torch.float32, is_shared=False)},
                            batch_size=torch.Size([2]),)},
                    batch_size=torch.Size([2]),),
                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2]),),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),)
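
For reference, here is a standalone sketch that approximates the setup above with CompositeDistribution directly (all names and shapes are illustrative, not the actual policy code):

import torch
from tensordict import TensorDict
from tensordict.nn.distributions import CompositeDistribution

# mirrors ("agents", "params") above: empty root batch size,
# while each head's logits carry batch shape [2]
params = TensorDict(
    {
        "head_0": {"logits": torch.randn(2, 9)},
        "head_1": {"logits": torch.randn(2, 9)},
    },
    batch_size=[],
)
dist = CompositeDistribution(
    params,
    distribution_map={
        "head_0": torch.distributions.Categorical,
        "head_1": torch.distributions.Categorical,
    },
)
sample = dist.sample()      # TensorDict holding one action per head
lp = dist.log_prob(sample)  # pre-fix: a single float; expected: shape [2]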

With the current code, the resulting sample_log_prob is:

>>> dist = policy_module.get_dist(td)
>>> dist.log_prob(td)
tensor(-16.9137, grad_fn=<AddBackward0>)

Note that it outputs a single float. This does not make sense, as our logits fields have a batch shape of (2,)!

The entropy is correctly computed:

>>> dist.entropy(td)
tensor([8.2005, 8.2096], grad_fn=<AddBackward0>)

With the changes in this PR, the log prob is computed just like the entropy, resulting in the correct shape:

>>> dist = policy_module.get_dist(td)
>>> dist.log_prob(td)
tensor([-9.3534, -8.6486], grad_fn=<AddBackward0>)
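
Conceptually, the fix aggregates the log prob the same way the entropy is aggregated: an elementwise sum over the per-head values, which preserves the batch shape of the underlying distributions instead of flattening to the root tensordict's (empty) batch shape. A rough, self-contained sketch of that aggregation (hypothetical stand-ins, not the exact patch):

import torch
from torch.distributions import Categorical

# stand-ins for the composite's heads and their samples
dists = {"head_0": Categorical(logits=torch.randn(2, 9)),
         "head_1": Categorical(logits=torch.randn(2, 9))}
samples = {name: d.sample() for name, d in dists.items()}

# elementwise sum over heads keeps the distributions' batch shape [2]
lp = sum(d.log_prob(samples[name]) for name, d in dists.items())
print(lp.shape)  # torch.Size([2])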

Types of changes

  • Bug fix (non-breaking change which fixes an issue)

Checklist

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label Oct 22, 2024
@thomasbbrunner (Author)

@albertbou92 tagging you, as this might be of interest.

@thomasbbrunner (Author)

Some tests are failing, I'll address this ASAP!

@vmoens (Contributor) commented Oct 22, 2024

This will go in a minor release, thanks for flagging.

@vmoens changed the title from "Fix computation of log prob in composite distribution for batched samples." to "[BugFix] computation of log prob in composite distribution for batched samples" Oct 22, 2024
@vmoens (Contributor) left a comment

Hey,
So I reviewed these changes and I think the current behaviour is accurate, but I'm open to changing my mind.

  1. The "advised" way of using log_prob is going to be to write the tensors in the TD from now on. If that is done, they should have the proper shape.
  2. We can also return an aggregate, but we will require it to have the same shape as the root tensordict. I thought about making it less restrictive and allowing the lp to have the shape of the parent tensordict (e.g. ("group0", "agent1")), but then what happens if we sum the lps from ("group0", "agent0"), ("group0", "agent1"), ("group1", "agent0") and ("group1", "agent1"), where the first two have, say, shape [2] and the last two have shape [3]? In this case we cannot sum them elementwise; the only thing we can do is reduce the lp to a single float (see the sketch below). Because the content of the tensordict should not condition the shape that you get, I think that having an lp with the shape of the sample tensordict (as is the case now) is the most robust behaviour.
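
A minimal illustration of the shape clash in point 2, with made-up shapes:

import torch

lp_group0 = torch.randn(2)  # summed lp for ("group0", ...) agents, shape [2]
lp_group1 = torch.randn(3)  # summed lp for ("group1", ...) agents, shape [3]

# lp_group0 + lp_group1     # RuntimeError: shapes [2] and [3] don't broadcast
total = lp_group0.sum() + lp_group1.sum()  # only a scalar reduction works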

The main consideration I have is whether this is BC-breaking: did you encounter a problem when switching to 0.6?

Happy to hear your thoughts on this!

@louisfaury

Hi Vincent,

Some thoughts after your comment. I think that, indeed, the main pain point is "where" the summed lp should be written. In the example above, I do believe it should not be in the root, but in the agents tensordict.

Let's build on the concrete example you laid out. I do not believe it makes sense to sum log-probs across group_0 and group_1. When we split agents into groups, I think it's safe to assume we do this because their I/O shapes are incompatible, and hence the joint action of each group should be generated by a different distribution. In that case, it feels natural to have an aggregate log_prob under both the group_0 and group_1 tensordicts.

For the same example, I do not see the learning-theoretic usefulness of the scalar log-prob. If we think of training this multi-agent dynamical system with PPO, the scalar sample_log_prob aggregated over all the agents is useless for learning; only the one aggregated at the group level is.
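
To make that concrete, here is a hedged PPO-flavoured sketch (all tensors hypothetical): the clipped objective needs one importance ratio per group-level sample, which a single scalar log prob cannot provide.

import torch

eps = 0.2
advantage = torch.randn(2)                       # one advantage per agent
new_lp, old_lp = torch.randn(2), torch.randn(2)  # group-level log probs, shape [2]

ratio = torch.exp(new_lp - old_lp)  # per-sample importance ratio
loss = -torch.min(
    ratio * advantage,
    ratio.clamp(1 - eps, 1 + eps) * advantage,
).mean()
# with a scalar log prob there is only one ratio, erasing per-agent credit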

@vmoens added the "bug" label Oct 25, 2024
@vmoens (Contributor) commented Oct 25, 2024

> Some thoughts after your comment. I think that, indeed, the main pain point is "where" the summed lp should be written. In the example above, I do believe it should not be in the root, but in the agents tensordict.

That's already the case, no?

# inside CompositeDistribution.log_prob: each head's log prob is written
# next to its sample, i.e. under the nested tensordict, not at the root
d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))
