Skip to content

Commit

Permalink
Merge pull request #40 from JakobRobnik/module
Browse files Browse the repository at this point in the history
mclmc/...
  • Loading branch information
JakobRobnik authored Nov 11, 2023
2 parents 945cc58 + de8874a commit 92c8a13
Show file tree
Hide file tree
Showing 22 changed files with 24 additions and 39 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ work and might even find some help from other people.

**Docs**:

For API docs, we use `pdoc`, which you can install with `pip`. Then do e.g. `pdoc sampling/sampler.py -o ./apidocs`. `pdoc` supports markdown, so markdown styled code comments inside triple quotes will be rendered automatically as documentation.
For API docs, we use `pdoc`, which you can install with `pip`. Then do e.g. `pdoc mclmc/sampling/sampler.py -o ./apidocs`. `pdoc` supports markdown, so markdown styled code comments inside triple quotes will be rendered automatically as documentation.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ PKG_VERSION = $(shell python setup.py --version)

test:
JAX_PLATFORM_NAME=cpu pytest --benchmark-disable
mypy sampling/sampler.py
mypy mclmc/sampling/sampler.py

set-bench:
pytest --benchmark-autosave
Expand Down
File renamed without changes.
Empty file added mclmc/sampling/__init__.py
Empty file.
4 changes: 2 additions & 2 deletions sampling/annealing.py → mclmc/sampling/annealing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from sampling import dynamics
from mclmc.sampling import dynamics

from sampling.dynamics import update_momentum
from mclmc.sampling.dynamics import update_momentum


class vmap_target:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions sampling/sampler.py → mclmc/sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import jax
import jax.numpy as jnp
import numpy as np
from sampling import dynamics
from mclmc.sampling import dynamics

from sampling.dynamics import MCLMCInfo, MCLMCState, build_kernel, run_kernel
from mclmc.sampling.dynamics import MCLMCInfo, MCLMCState, build_kernel, run_kernel
from .correlation_length import ess_corr

class Target():
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion notebooks/tutorials/intro_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from sampling.sampler import Sampler"
"from mclmc.sampling.sampler import Sampler"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/tutorials/smc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
"import sys \n",
"sys.path.insert(0, '../../')\n",
"\n",
"from sampling.smc import Sampler\n",
"from mclmc.sampling.smc import Sampler\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
Expand Down
Binary file not shown.
Binary file removed sampling/__pycache__/sampler.cpython-38.pyc
Binary file not shown.
6 changes: 3 additions & 3 deletions speed-bench/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

sys.path.insert(0, './')

import sampling
import mclmc.sampling

from sampling.sampler import Sampler, Target
from sampling.dynamics import update_momentum
from mclmc.sampling.sampler import Sampler, Target
from mclmc.sampling.dynamics import update_momentum
import jax

import jax.numpy as jnp
Expand Down
6 changes: 3 additions & 3 deletions speed-bench/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

sys.path.insert(0, './')

import sampling
import mclmc.sampling

from sampling.sampler import Sampler, Target
from sampling.dynamics import update_momentum
from mclmc.sampling.sampler import Sampler, Target
from mclmc.sampling.dynamics import update_momentum
import jax

import jax.numpy as jnp
Expand Down
2 changes: 1 addition & 1 deletion tests/benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from benchmarks.benchmarks_mchmc import *
from sampling.sampler import Sampler
from mclmc.sampling.sampler import Sampler
#from benchmarks import german_credit
import os
import jax
Expand Down
12 changes: 3 additions & 9 deletions tests/test_annealing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
import sys

sys.path.insert(0, '../../')
sys.path.insert(0, './')

# from sampling.annealing import Sampler
import jax
import jax.numpy as jnp
from sampling.annealing import Annealing
from sampling.sampler import Sampler, Target
import sampling.old_annealing as A
from mclmc.sampling.annealing import Annealing
from mclmc.sampling.sampler import Sampler, Target
import mclmc.sampling.old_annealing as A

temp_schedule = jnp.array([3.0, 2.0, 1.0])

Expand Down
9 changes: 2 additions & 7 deletions tests/test_mclmc.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import sys
import time

from pytest import raises
import pytest

sys.path.insert(0, '../../')
sys.path.insert(0, './')
from sampling.dynamics import leapfrog
from mclmc.sampling.dynamics import leapfrog
from mclmc.sampling.sampler import OutputType, Sampler, Target

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from sampling.sampler import OutputType, Sampler, Target

nlogp = lambda x: 0.5*jnp.sum(jnp.square(x))

Expand Down
6 changes: 1 addition & 5 deletions tests/test_momentum_update.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@

import sys
sys.path.insert(0, './')

from sampling.dynamics import update_momentum
from mclmc.sampling.dynamics import update_momentum
import jax

import jax.numpy as jnp

def update_momentum_unstable(d):
Expand Down
6 changes: 3 additions & 3 deletions tests/tst_diagonal_precond.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import jax.numpy as jnp
import pandas as pd

from sampling.sampler import Sampler
from sampling.benchmark_targets import *
from sampling.grid_search import search_wrapper
from mclmc.sampling.sampler import Sampler
from mclmc.sampling.benchmark_targets import *
from mclmc.sampling.grid_search import search_wrapper



Expand Down

0 comments on commit 92c8a13

Please sign in to comment.