feat: conditional coupled neural spline flow (#51)
* Added option to use neural spline coupling flow as conditional flow

* Updated description

* Updated article reference

* Hint that there are more papers
VincentStimper authored Oct 25, 2023
1 parent ee48e44 commit 9607072
Showing 5 changed files with 257 additions and 41 deletions.
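
In practical terms, the coupled neural spline flow classes gain a `num_context_channels` argument, and their `forward`/`inverse` methods accept an optional `context` tensor, so they can be used as conditional flows. Below is a minimal usage sketch based on the constructor and method signatures visible in this diff; the sizes, variable names, and the `nf.flows` shortcut are illustrative assumptions rather than part of the commit.

```python
import torch
import normflows as nf

latent_size = 2   # flow dimension (num_input_channels)
context_size = 3  # number of context/conditional channels

# Conditional coupled neural spline flow
flow = nf.flows.CoupledRationalQuadraticSpline(
    latent_size,
    2,                                  # num_blocks of the parameter NN
    64,                                 # num_hidden_channels
    num_context_channels=context_size,
)

z = torch.randn(5, latent_size)
context = torch.randn(5, context_size)

# forward and inverse both take the optional context argument
z_fwd, log_det_fwd = flow(z, context)
z_rec, log_det_inv = flow.inverse(z_fwd, context)
print(torch.allclose(z_rec, z, atol=1e-5))  # round trip is (numerically) the identity
```
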
4 changes: 2 additions & 2 deletions README.md
@@ -203,7 +203,7 @@ to the existing documentation. Once you finished coding and testing, please

## Used by

-The package has been used in several research papers, which are listed below.
+The package has been used in several research papers. Some of them are listed below.

> Andrew Campbell, Wenlong Chen, Vincent Stimper, José Miguel Hernández-Lobato, and Yichuan Zhang.
> [A gradient based strategy for Hamiltonian Monte Carlo hyperparameter optimization](https://proceedings.mlr.press/v139/campbell21a.html).
@@ -219,7 +219,7 @@ The package has been used in several research papers, which are listed below.
> Laurence I. Midgley, Vincent Stimper, Gregor N. C. Simm, Bernhard Schölkopf, José Miguel Hernández-Lobato.
> [Flow Annealed Importance Sampling Bootstrap](https://arxiv.org/abs/2208.01893).
-> arXiv preprint arXiv:2208.01893, 2022.
+> The Eleventh International Conference on Learning Representations, 2023.
>
> [Code available on GitHub.](https://github.com/lollcat/fab-torch)
217 changes: 205 additions & 12 deletions examples/conditional_flow.ipynb

Large diffs are not rendered by default.
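
The notebook itself is not rendered here. For orientation, the following is a small, self-contained sketch of the kind of conditional density model the new option enables: a stack of conditional coupled splines over a standard Gaussian base, trained by maximum likelihood on (x, context) pairs. Only the `forward`/`inverse` signatures from this commit are used; everything else (sizes, toy data, keeping a single coupling mask) is an illustrative assumption and not taken from the notebook.

```python
import torch
import normflows as nf

latent_size, context_size, K = 2, 1, 4

# Stack of conditional coupled neural spline flows. A real model would
# typically alternate the coupling mask or interleave permutation layers;
# a plain stack keeps the sketch short.
flows = torch.nn.ModuleList([
    nf.flows.CoupledRationalQuadraticSpline(
        latent_size, 2, 64, num_context_channels=context_size
    )
    for _ in range(K)
])

# Standard Gaussian base distribution
base = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(latent_size), torch.ones(latent_size)), 1
)

def log_prob(x, context):
    # Map data to the base space via the inverse passes, accumulating log-dets
    z, log_det_sum = x, torch.zeros(len(x))
    for flow in reversed(flows):
        z, log_det = flow.inverse(z, context)
        log_det_sum = log_det_sum + log_det
    return base.log_prob(z) + log_det_sum

optimizer = torch.optim.Adam(flows.parameters(), lr=1e-3)

# One toy maximum-likelihood step on synthetic (x, context) pairs
x = torch.randn(128, latent_size)
context = torch.rand(128, context_size)
optimizer.zero_grad()
loss = -log_prob(x, context).mean()
loss.backward()
optimizer.step()
```
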

20 changes: 13 additions & 7 deletions normflows/flows/flow_test.py
@@ -11,31 +11,37 @@ class FlowTest(unittest.TestCase):
def assertClose(self, actual, expected, atol=None, rtol=None):
assert_close(actual, expected, atol=atol, rtol=rtol)

-def checkForward(self, flow, inputs):
+def checkForward(self, flow, inputs, context=None):
# Do forward transform
-outputs, log_det = flow(inputs)
+if context is None:
+outputs, log_det = flow(inputs)
+else:
+outputs, log_det = flow(inputs, context)
# Check type
assert outputs.dtype == inputs.dtype
# Check shape
assert outputs.shape == inputs.shape
# Return results
return outputs, log_det

-def checkInverse(self, flow, inputs):
+def checkInverse(self, flow, inputs, context=None):
# Do inverse transform
-outputs, log_det = flow.inverse(inputs)
+if context is None:
+outputs, log_det = flow.inverse(inputs)
+else:
+outputs, log_det = flow.inverse(inputs, context)
# Check type
assert outputs.dtype == inputs.dtype
# Check shape
assert outputs.shape == inputs.shape
# Return results
return outputs, log_det

-def checkForwardInverse(self, flow, inputs, atol=None, rtol=None):
+def checkForwardInverse(self, flow, inputs, context=None, atol=None, rtol=None):
# Check forward
-outputs, log_det = self.checkForward(flow, inputs)
+outputs, log_det = self.checkForward(flow, inputs, context)
# Check inverse
-input_, log_det_ = self.checkInverse(flow, outputs)
+input_, log_det_ = self.checkInverse(flow, outputs, context)
# Check identity
self.assertClose(input_, inputs, atol, rtol)
ld_id = log_det + log_det_
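
With the extra `context` parameter, the shared helpers can exercise conditional flows directly. A hypothetical test built on them might look as follows, assuming `FlowTest` and `CoupledRationalQuadraticSpline` can be imported from the installed package; the class name and sizes are illustrative.

```python
import unittest
import torch
from normflows.flows import CoupledRationalQuadraticSpline
from normflows.flows.flow_test import FlowTest

class ConditionalSplineTest(FlowTest):
    def test_forward_inverse_with_context(self):
        batch_size, latent_size, context_size = 3, 2, 5
        flow = CoupledRationalQuadraticSpline(
            latent_size, 2, 16, num_context_channels=context_size
        )
        inputs = torch.randn((batch_size, latent_size))
        context = torch.randn((batch_size, context_size))
        # Context is threaded through forward, inverse, and the identity check
        self.checkForwardInverse(flow, inputs, context)

if __name__ == "__main__":
    unittest.main()
```
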
24 changes: 14 additions & 10 deletions normflows/flows/neural_spline/wrapper.py
@@ -22,6 +22,7 @@ def __init__(
num_input_channels,
num_blocks,
num_hidden_channels,
+num_context_channels=None,
num_bins=8,
tails="linear",
tail_bound=3.0,
@@ -35,6 +36,7 @@ def __init__(
num_input_channels (int): Flow dimension
num_blocks (int): Number of residual blocks of the parameter NN
num_hidden_channels (int): Number of hidden units of the NN
+num_context_channels (int): Number of context/conditional channels
num_bins (int): Number of bins
tails (str): Behaviour of the tails of the distribution, can be linear, circular for periodic distribution, or None for distribution on the compact interval
tail_bound (float): Bound of the spline tails
@@ -48,7 +50,7 @@ def transform_net_create_fn(in_features, out_features):
return ResidualNet(
in_features=in_features,
out_features=out_features,
-context_features=None,
+context_features=num_context_channels,
hidden_features=num_hidden_channels,
num_blocks=num_blocks,
activation=activation(),
@@ -66,12 +68,12 @@ def transform_net_create_fn(in_features, out_features):
apply_unconditional_transform=True,
)

-def forward(self, z):
-z, log_det = self.prqct.inverse(z)
+def forward(self, z, context=None):
+z, log_det = self.prqct.inverse(z, context)
return z, log_det.view(-1)

-def inverse(self, z):
-z, log_det = self.prqct(z)
+def inverse(self, z, context=None):
+z, log_det = self.prqct(z, context)
return z, log_det.view(-1)
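
The only wiring needed for conditioning is `context_features` on the parameter network. Below is an isolated sketch of the network that `transform_net_create_fn` now builds, assuming `ResidualNet` can be imported from `normflows.nets`; the feature counts are illustrative (for a 2D flow with 8 bins and linear tails, the transformed half needs 3 * 8 - 1 = 23 spline parameters).

```python
import torch
from torch import nn
from normflows.nets import ResidualNet

# Parameter network for one coupling layer of a 2D flow with 3 context channels
net = ResidualNet(
    in_features=1,        # identity (untransformed) half of the input
    out_features=23,      # spline parameters of the transformed half
    context_features=3,   # = num_context_channels
    hidden_features=64,
    num_blocks=2,
    activation=nn.ReLU(),
)

half = torch.randn(5, 1)
context = torch.randn(5, 3)
params = net(half, context)  # spline parameters now depend on the context
print(params.shape)          # torch.Size([5, 23])
```
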


@@ -86,6 +88,7 @@ def __init__(
num_blocks,
num_hidden_channels,
ind_circ,
+num_context_channels=None,
num_bins=8,
tail_bound=3.0,
activation=nn.ReLU,
@@ -100,6 +103,7 @@ def __init__(
num_input_channels (int): Flow dimension
num_blocks (int): Number of residual blocks of the parameter NN
num_hidden_channels (int): Number of hidden units of the NN
+num_context_channels (int): Number of context/conditional channels
ind_circ (Iterable): Indices of the circular coordinates
num_bins (int): Number of bins
tail_bound (float or Iterable): Bound of the spline tails
@@ -134,7 +138,7 @@ def transform_net_create_fn(in_features, out_features):
net = ResidualNet(
in_features=in_features,
out_features=out_features,
-context_features=None,
+context_features=num_context_channels,
hidden_features=num_hidden_channels,
num_blocks=num_blocks,
activation=activation(),
@@ -162,12 +166,12 @@ def transform_net_create_fn(in_features, out_features):
apply_unconditional_transform=True,
)

-def forward(self, z):
-z, log_det = self.prqct.inverse(z)
+def forward(self, z, context=None):
+z, log_det = self.prqct.inverse(z, context)
return z, log_det.view(-1)

-def inverse(self, z):
-z, log_det = self.prqct(z)
+def inverse(self, z, context=None):
+z, log_det = self.prqct(z, context)
return z, log_det.view(-1)
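
The circular variant gets the same option. A construction sketch mirroring the updated test below; the sizes, the circular index, and the `nf.flows` path are illustrative assumptions.

```python
import torch
import normflows as nf

latent_size, context_size = 2, 3
ind_circ = [0]  # the first coordinate is treated as circular/periodic

flow = nf.flows.CircularCoupledRationalQuadraticSpline(
    latent_size, 2, 64, ind_circ,
    tail_bound=3.0,
    num_context_channels=context_size,
)

z = 6 * torch.rand(5, latent_size) - 3  # points inside the tail bound [-3, 3]
context = torch.randn(5, context_size)
z_out, log_det = flow(z, context)
```
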


33 changes: 23 additions & 10 deletions normflows/flows/neural_spline/wrapper_test.py
@@ -17,10 +17,16 @@ def test_normal_nsf(self):
for latent_size in [2, 5]:
for flow_cls in [CoupledRationalQuadraticSpline,
AutoregressiveRationalQuadraticSpline]:
-with self.subTest(latent_size=latent_size, flow_cls=flow_cls):
-flow = flow_cls(latent_size, hidden_units, hidden_layers)
-inputs = torch.randn((batch_size, latent_size))
-self.checkForwardInverse(flow, inputs)
+for context_feature in [None, 3]:
+with self.subTest(latent_size=latent_size, flow_cls=flow_cls):
+flow = flow_cls(latent_size, hidden_units, hidden_layers,
+num_context_channels=context_feature)
+inputs = torch.randn((batch_size, latent_size))
+if context_feature is None:
+context = None
+else:
+context = torch.randn((batch_size, context_feature))
+self.checkForwardInverse(flow, inputs, context)

def test_circular_nsf(self):
batch_size = 3
@@ -32,12 +38,19 @@ def test_circular_nsf(self):
for latent_size, ind_circ, tail_bound in params:
for flow_cls in [CircularCoupledRationalQuadraticSpline,
CircularAutoregressiveRationalQuadraticSpline]:
-with self.subTest(latent_size=latent_size, ind_circ=ind_circ,
-tail_bound=tail_bound, flow_cls=flow_cls):
-flow = flow_cls(latent_size, hidden_units, hidden_layers,
-ind_circ, tail_bound=tail_bound)
-inputs = 6 * torch.rand((batch_size, latent_size)) - 3
-self.checkForwardInverse(flow, inputs)
+for context_feature in [None, 3]:
+with self.subTest(latent_size=latent_size, ind_circ=ind_circ,
+tail_bound=tail_bound, flow_cls=flow_cls,
+context_feature=context_feature):
+flow = flow_cls(latent_size, hidden_units, hidden_layers,
+ind_circ, tail_bound=tail_bound,
+num_context_channels=context_feature)
+inputs = 6 * torch.rand((batch_size, latent_size)) - 3
+if context_feature is None:
+context = None
+else:
+context = torch.randn((batch_size, context_feature))
+self.checkForwardInverse(flow, inputs, context)


if __name__ == "__main__":
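
To exercise the new context-conditioned sub-tests locally, the relevant test modules can be run with the standard library test runner; a small sketch, assuming the test modules ship with the installed package.

```python
import unittest

# Load and run the neural spline wrapper tests (including the new
# conditional sub-tests) together with the generic flow test helpers
suite = unittest.defaultTestLoader.loadTestsFromNames([
    "normflows.flows.neural_spline.wrapper_test",
    "normflows.flows.flow_test",
])
unittest.TextTestRunner(verbosity=2).run(suite)
```
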
