Utilities for simplifying interactions between PyroSample and plates #3385

Open
eb8680 opened this issue Jul 17, 2024 · 2 comments

@eb8680
Member

eb8680 commented Jul 17, 2024

Problem

PyroModule and PyroSample make it straightforward to compositionally specify probabilistic models with random parameters. However, PyroSample has a somewhat awkward interaction with pyro.plate:

```python
import pyro
import pyro.distributions
import pyro.nn

class Model(pyro.nn.PyroModule):

  @pyro.nn.PyroSample
  def loc(self):
    return pyro.distributions.Normal(0, 1)

  @pyro.nn.PyroSample
  def scale(self):
    return pyro.distributions.LogNormal(0, 1)

  def forward(self, x_obs):
    assert self.scale.shape == ()  # accessing self.scale triggers pyro.sample outside the plate
    with pyro.plate("data", x_obs.shape[0], dim=-1):
      assert self.loc.shape == (x_obs.shape[0],)  # accessing self.loc here triggers pyro.sample inside the plate
      return pyro.sample("x", pyro.distributions.Normal(self.loc, self.scale), obs=x_obs)
```

To ensure that loc and scale are sampled globally, they must be accessed for the first time outside the data plate, as scale is in the example above. Accessing self.loc only inside the plate (e.g. by inlining it in the final line) samples a different loc for each datapoint. This behavior is semantically unambiguous, but it can cause confusion in more complex models and requires a lot of ugly boilerplate code that manually samples the random parameters of submodules in the correct plate context.

For example, in the code below the intuitive behavior for Model.linear is clearly for linear.weight to be sampled outside of the data plate. However, because self.linear is first invoked inside the plate, a separate random copy of linear.weight is sampled for each plate slice:

```python
import pyro
import pyro.distributions as dist
import pyro.nn
import torch

class BayesianLinear(pyro.nn.PyroModule[torch.nn.Linear]):

  @pyro.nn.PyroSample
  def weight(self):
    # torch.nn.Linear stores its weight with shape (out_features, in_features)
    return dist.Normal(0, 1).expand([self.out_features, self.in_features]).to_event(2)

class Model(pyro.nn.PyroModule):
  def __init__(self, num_inputs, num_outputs):
    super().__init__()
    self.linear = BayesianLinear(num_inputs, num_outputs)

  def forward(self, x):
    with pyro.plate("data", x.shape[-2], dim=-1):
      loc = self.linear(x)  # first access to linear.weight happens inside the plate
      assert self.linear.weight.shape[-3] == x.shape[-2]  # one weight per data slice
      return pyro.sample("y", dist.Normal(loc, 1))
```

However, it would not be correct to simply ignore all plates when executing PyroSamples. In this example, we might want to use a multi-sample ELBO estimator when inferring self.linear.weight (e.g. pyro.infer.Trace_ELBO(num_particles=10, vectorize_particles=True)), which is implemented with another plate that should not be ignored.

Proposed fix

It would be nice to have a feature that enables the intuitive behavior in the second example above without breaking backwards compatibility with PyroSample's existing semantics, or its correctness in the presence of enclosing plates such as the one introduced by the multi-sample ELBO.

This could potentially be achieved with a new handler, PyroSamplePlateScope. PyroSample statements executed inside its context would be affected only by plates entered outside of it. Ordinary pyro.sample statements would be unaffected and behave in the usual way:

```python
class Model(pyro.nn.PyroModule):
  def __init__(self, num_inputs, num_outputs):
    super().__init__()
    self.num_inputs = num_inputs
    self.num_outputs = num_outputs
    self.linear = BayesianLinear(num_inputs, num_outputs)

  @pyro.nn.PyroSample
  def scale(self):
    return pyro.distributions.LogNormal(0, 1).expand([self.num_outputs]).to_event(1)

  @PyroSamplePlateScope()
  def forward(self, x):
    with pyro.plate("data", x.shape[-2], dim=-1):
      loc = self.linear(x)
      assert self.linear.weight.shape[-3] == 1  # sampled outside data plate
      assert self.scale.shape[-2] == 1  # sampled outside data plate
      y = pyro.sample("y", dist.Normal(loc, self.scale).to_event(1))
      assert y.shape[-2] == x.shape[-2]  # ordinary pyro.sample statement
      return y
```

@eb8680
Member Author

eb8680 commented Jul 17, 2024

@ordabayevy what do you think about this?

@ordabayevy
Member

ordabayevy commented Jul 23, 2024

This makes sense to me. My two comments are:

  1. Personally, I find it more intuitive to treat PyroSample the same as pyro.sample and not inline it. But I don't use PyroSample much, and I can see the convenience of inlining it if it is used a lot.
  2. As a consideration, should plate scoping be implemented as a context manager, as in PyroSamplePlateScope, or per individual PyroSample (e.g. through infer={"ignored_plates": ...}, which would also work with pyro.sample)? For example, if you want self.loc to be sampled inside the data plate and self.scale to be sampled outside of it:
```python
class Model(pyro.nn.PyroModule):

  @pyro.nn.PyroSample
  def loc(self):
    return pyro.distributions.Normal(0, 1)

  @pyro.nn.PyroSample(infer={"ignored_plates": ["data"]})  # new syntax
  def scale(self):
    return pyro.distributions.LogNormal(0, 1)

  def forward(self, x_obs):
    with pyro.plate("data", x_obs.shape[0], dim=-1):
      return pyro.sample("x", pyro.distributions.Normal(self.loc, self.scale), obs=x_obs)  # self.loc is local and self.scale is global
```

(One drawback of this approach is that ignored_plates is not specified in the forward method but hidden elsewhere, which can make the code harder to read.)
