Add Helper Function for evaluating log_likelihood #3373

Open

julian-8897 opened this issue Jun 6, 2024 · 4 comments

@julian-8897

Issue Description

I'm currently training some Bayesian neural networks using HMC. Would it be useful to include the calculation of the log likelihood as a helper function? Probably something like this, using conditioning and a trace:

import pyro
import torch


def log_likelihood(model, posterior_samples, x, y):
    log_likelihoods = []

    sample_count = next(iter(posterior_samples.values())).shape[0]

    for i in range(sample_count):
        # Set the parameters of the model to the values in the i-th sample
        conditioned_model = pyro.condition(
            model, data={k: torch.tensor(v[i]) for k, v in posterior_samples.items()}
        )

        # Compute the log likelihood of the data given these parameters
        trace = pyro.poutine.trace(conditioned_model).get_trace(  # type: ignore
            torch.from_numpy(x).to(torch.float32), torch.from_numpy(y).to(torch.float32)
        )
        log_likelihoods.append(trace.log_prob_sum())

    # Average the log likelihoods over all samples
    return torch.stack(log_likelihoods).mean()
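
For context, here's roughly how I'd expect to call it on HMC samples (a sketch; bnn_model, x_train, y_train are placeholders for a real model and dataset):

from pyro.infer import MCMC, NUTS

# Hypothetical usage; bnn_model, x_train, y_train stand in for a real model and data.
mcmc = MCMC(NUTS(bnn_model), num_samples=500, warmup_steps=200)
mcmc.run(
    torch.from_numpy(x_train).to(torch.float32),
    torch.from_numpy(y_train).to(torch.float32),
)
avg_log_lik = log_likelihood(bnn_model, mcmc.get_samples(), x_train, y_train)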
@fritzo
Member

fritzo commented Jun 8, 2024

I believe the above code computes the posterior predictive log density, which includes both prior and likelihood. In the past, when I've computed log-likelihood, I've manually masked out the prior sites. I'm unsure whether it's practical to automatically mask out prior sites in a way that is correct for reparametrization and other auxiliary variables.

Maybe a first step could be adding a log-likelihood computation to a couple existing tutorials, then seeing if there's a general implementation (that is e.g. batchable)?
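
For concreteness, one naive way to pick out just the likelihood terms is to sum the log-probabilities of only the observed sites in a trace. A minimal sketch, with the same caveat about reparametrization and auxiliary variables (model, posterior_sample, and args are placeholders):

import pyro

def observed_log_prob_sum(model, posterior_sample, *args):
    # Condition the model on one posterior sample and record an execution trace.
    conditioned = pyro.condition(model, data=posterior_sample)
    trace = pyro.poutine.trace(conditioned).get_trace(*args)
    trace.compute_log_prob()
    # Sum log-probabilities over observed sites only, skipping prior sites.
    return sum(
        site["log_prob_sum"]
        for site in trace.nodes.values()
        if site["type"] == "sample" and site["is_observed"]
    )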

@julian-8897
Author

@fritzo thanks for pointing out my mistake, much appreciated! I was just wondering, how did you mask out your prior sites systematically (I'm very much new to Pyro)? You're right, it might be worth writing up a couple of tutorials for the log-likelihood computation. Do you have any recommendations of where to start?

@fritzo
Member

fritzo commented Jun 12, 2024

how did you mask out your prior sites systematically?

I've enclosed the top of a hierarchical model in a boolean poutine.mask(mask=___), e.g.

import pyro
from pyro import poutine
from pyro.distributions import LogNormal, Normal


def example_model(data, include_prior: bool = True):
    # Sample top level variables from the prior.
    with poutine.mask(mask=include_prior):
        loc = pyro.sample("loc", Normal(0, 1))
        scale = pyro.sample("scale", LogNormal(0, 1))
    # Observe data.
    pyro.sample("data", Normal(loc, scale), obs=data)

do you have any recommendations of where to start?

Gosh, there are over 50 tutorials on https://pyro.ai/examples. You might pick a domain you're interested in and add a section at the end. Then "likelihood" should still show up in search results.

@julian-8897
Author

@fritzo Thanks for this. Great, I can add a section to one of the tutorials; it might be useful!
