Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix install #213

Merged
merged 6 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ before importing JAXNS.

# Change Log

7 Dec, 2024 -- JAXNS 2.6.7 released. Fix pip dependencies install.

13 Nov, 2024 -- JAXNS 2.6.6 released. Minor improvements to plotting.

9 Nov, 2024 -- JAXNS 2.6.5 released. Added gradient guided nested sampling. Removed `num_parallel_workers` in favour
Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
project = "jaxns"
copyright = "2024, Joshua G. Albert"
author = "Joshua G. Albert"
release = "2.6.6"
release = "2.6.7"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "jaxns"
version = "2.6.6"
version = "2.6.7"
description = "Nested Sampling in JAX"
readme = "README.md"
requires-python = ">=3.9"
Expand All @@ -19,10 +19,12 @@ classifiers = [
"Operating System :: OS Independent"
]
urls = { "Homepage" = "https://github.com/joshuaalbert/jaxns" }
dynamic = ["dependencies"]

[project.optional-dependencies]
# Define the extras here; they will be loaded dynamically from setup.py
notebooks = [] # Placeholders; extras will load from setup.py
examples = [] # Placeholders; extras will load from setup.py
tests = [] # Placeholders; extras will load from setup.py

[tool.setuptools]
include-package-data = true
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ def load_requirements(file_name):
install_requires=load_requirements("requirements.txt"),
extras_require={
"examples": load_requirements("requirements-examples.txt"),
},
tests_require=load_requirements("requirements-tests.txt"),
"tests": load_requirements("requirements-tests.txt"),
}
)
18 changes: 16 additions & 2 deletions src/jaxns/framework/special_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"Poisson",
"UnnormalisedDirichlet",
"Empirical",
"TruncationWrapper"
"TruncationWrapper",
"ExplicitDensityPrior",
]


Expand Down Expand Up @@ -72,6 +73,7 @@ def _quantile(self, U):
sample = jnp.less(U, probs)
return sample.astype(self.dtype)


class Beta(SpecialPrior):
def __init__(self, *, concentration0=None, concentration1=None, name: Optional[str] = None):
super(Beta, self).__init__(name=name)
Expand Down Expand Up @@ -443,7 +445,8 @@ class Empirical(SpecialPrior):
Represents the empirical distribution of a set of 1D samples, with arbitrary batch dimension.
"""

def __init__(self, *, samples: jax.Array, resolution: int = 100, name: Optional[str] = None):
def __init__(self, *, samples: jax.Array, support_min: Optional[FloatArray] = None,
support_max: Optional[FloatArray] = None, resolution: int = 100, name: Optional[str] = None):
super(Empirical, self).__init__(name=name)
if len(np.shape(samples)) < 1:
raise ValueError("Samples must have at least one dimension")
Expand All @@ -452,6 +455,17 @@ def __init__(self, *, samples: jax.Array, resolution: int = 100, name: Optional[
if resolution < 1:
raise ValueError("Resolution must be at least 1")
samples = jnp.asarray(samples)
# Add 1 point for each support endpoint
endpoints = []
if support_min is not None:
endpoints.append(support_min)
if support_max is not None:
endpoints.append(support_max)
if len(endpoints) > 0:
samples = jnp.concatenate([samples, jnp.asarray(endpoints)])

resolution = min(resolution, len(samples) - 1)

self._q = jnp.linspace(0., 100., resolution + 1)
self._percentiles = jnp.reshape(jnp.percentile(samples, self._q, axis=-1), (resolution + 1, -1))

Expand Down
8 changes: 4 additions & 4 deletions src/jaxns/framework/tests/test_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,15 @@ def test_forced_identifiability():


def test_empirical():
samples = jax.random.normal(jax.random.PRNGKey(42), shape=(5, 2000), dtype=mp_policy.measure_dtype)
samples = jax.random.normal(jax.random.PRNGKey(42), shape=(2000,), dtype=mp_policy.measure_dtype)
prior = Empirical(samples=samples, resolution=100, name='x')
assert prior._percentiles.shape == (101, 5)
assert prior._percentiles.shape == (101, 1)

x = prior.forward(jnp.ones(prior.base_shape, mp_policy.measure_dtype))
assert x.shape == (5,)
assert x.shape == ()
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
x = prior.forward(jnp.zeros(prior.base_shape, mp_policy.measure_dtype))
assert x.shape == (5,)
assert x.shape == ()
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))

x = prior.forward(0.5 * jnp.ones(prior.base_shape, mp_policy.measure_dtype))
Expand Down
6 changes: 3 additions & 3 deletions src/jaxns/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class NestedSampler:
max_samples: Optional[Union[int, float]] = None
num_live_points: Optional[int] = None
num_slices: Optional[int] = None
s: Optional[int] = None
s: Optional[Union[int, float]] = None
k: Optional[int] = None
c: Optional[int] = None
devices: Optional[List[xla_client.Device]] = None
Expand All @@ -70,9 +70,9 @@ def __post_init__(self):
# Determine number of slices per acceptance
if self.num_slices is None:
if self.difficult_model:
self.s = 10 if self.s is None else int(self.s)
self.s = 10 if self.s is None else float(self.s)
else:
self.s = 5 if self.s is None else int(self.s)
self.s = 5 if self.s is None else float(self.s)
if self.s <= 0:
raise ValueError(f"Expected s > 0, got s={self.s}")
self.num_slices = self.model.U_ndims * self.s
Expand Down
Loading