An exercise with both native JaxNS and the Numpyro wrapper #41
-
Q0: In your code you're not JIT compiling. Try:
ns = jaxns.nested_sampling.NestedSampler(log_lik, prior_chain, num_live_points=prior_chain.U_ndims*500)
ns = jax.jit(ns)
Q1: Just to be clear, are you asking about
Q2: Bug in 0.0.7, which I've fixed and will release in 0.0.8.
Q3: You never need to resample for the corner plot; it does it for you. You just feed in
R4: With nested sampling you can't really control the number of samples you get out, since they are weighted. I guess it would be possible to recalculate the effective sample size (ESS) after each run and make a stopping criterion based on having enough ESS. That's not too hard. I'll make an issue for it.
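For R4, a rough sketch of the quantity involved: Kish's effective sample size of weighted samples, ESS = (Σw)² / Σw², computed here from unnormalised log-weights. This also bears on Q1: the raw number of weighted samples returned is not the number of effectively independent ones. This is plain JAX, not a JaxNS function, and it assumes nothing about how JaxNS stores its log-weights.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def effective_sample_size(log_weights):
    """Kish's effective sample size of a set of (unnormalised) log-weights."""
    # Normalise the weights in log-space for numerical stability.
    log_w = log_weights - logsumexp(log_weights)
    # With normalised weights, ESS = (sum w)^2 / sum w^2 = 1 / sum w^2.
    return jnp.exp(-logsumexp(2.0 * log_w))


# Equal weights: the ESS is close to the number of samples (1000 here).
print(effective_sample_size(jnp.zeros(1000)))
# One dominant weight: the ESS collapses towards 1.
print(effective_sample_size(jnp.array([0.0, -20.0, -20.0, -20.0])))
```

A stopping rule like the one suggested in R4 would compare this number against a target; that part is left to the library.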
-
Thanks @Joshuaalbert for your kind answers.
-
@jecampagne, is anything still unanswered here?
-
Here is an exercise fitting a model to observations. Here is the snippet in JaxNS (0.0.7). At the end I list some questions:
Q0: I am running on a CPU machine (i.e. not a GPU) and it takes 3 min 30 s to get the results. This is strange because with the Numpyro wrapper it takes about 30 s to run the NestedSampler and get 5000 samples.
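For Q0, a generic illustration of why the first call of a `jax.jit`-compiled function is slow and later calls are fast: the first call traces and compiles, and later calls with the same shapes and dtypes reuse the compiled program. The straight-line log-likelihood is a hypothetical stand-in for the model in this exercise, and the printed timings depend entirely on the machine.

```python
import time

import jax
import jax.numpy as jnp


def log_likelihood(theta, x, y, sigma):
    """Hypothetical Gaussian log-likelihood for a straight-line model y = a*x + b."""
    a, b = theta
    residual = (y - (a * x + b)) / sigma
    return -0.5 * jnp.sum(residual ** 2)


x = jnp.linspace(0.0, 10.0, 1000)
y = 2.0 * x + 1.0
theta = jnp.array([2.0, 1.0])

jitted = jax.jit(log_likelihood)

start = time.perf_counter()
jitted(theta, x, y, 0.5).block_until_ready()  # first call: trace + compile + run
first = time.perf_counter() - start

start = time.perf_counter()
jitted(theta, x, y, 0.5).block_until_ready()  # same shapes/dtypes: cached program
second = time.perf_counter() - start

print(f"first call {first:.4f} s, second call {second:.6f} s")
```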
Then
crashes
TypeError: percentile requires ndarray or scalar arguments, got <class 'list'> at position 1.
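The message above is the one JAX itself raises when `jnp.percentile` receives a plain Python list as the quantile argument `q` (position 1). Assuming that is what the diagnostics hit in 0.0.7, a minimal reproduction and workaround looks like this:

```python
import jax.numpy as jnp

samples = jnp.linspace(0.0, 1.0, 101)

# Passing the percentile levels as a Python list triggers JAX's
# "percentile requires ndarray or scalar arguments" TypeError.
try:
    jnp.percentile(samples, [16, 50, 84])
except TypeError as err:
    print("TypeError:", err)

# Converting the levels to an array first avoids the error.
levels = jnp.percentile(samples, jnp.asarray([16.0, 50.0, 84.0]))
print(levels)
```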
I get this error with other JaxNS exercises too. I do not know whether this is expected, but after
I cannot use the corner plots provided by the library, so I have written my own.
So,
gives
which looks fine; at least it matches the Numpyro wrapper results and the NUTS sampling results.
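Related to Q3: summaries like the 16/50/84 % levels shown on a corner plot can be read directly from weighted samples, without resampling. The helper below is a hypothetical sketch using one simple convention (interpolating the cumulative normalised weight); other weighted-quantile conventions differ slightly in the tails.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def weighted_quantiles(x, log_weights, qs):
    """Quantiles of 1-D samples x carrying (unnormalised) log-weights."""
    order = jnp.argsort(x)
    x_sorted = x[order]
    # Normalised weights, reordered to follow the sorted samples.
    w = jnp.exp(log_weights[order] - logsumexp(log_weights))
    # Cumulative weight at each sorted sample; interpolate to locate each quantile.
    cdf = jnp.cumsum(w)
    return jnp.interp(jnp.asarray(qs), cdf, x_sorted)


# Equal weights on 1, 2, ..., 100: the 16/50/84 % levels land near 16, 50 and 84.
values = jnp.arange(1.0, 101.0)
print(weighted_quantiles(values, jnp.zeros(100), [0.16, 0.5, 0.84]))
```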
But (see also Q0 above):
Q1) I wonder: where does JaxNS decide to give me 100_000 samples, when sometimes it gives only a few hundred?
Q2) Also, why does percentile crash in the JaxNS diagnostics (print/corner...)?
Q3) Why, in the
https://github.com/Joshuaalbert/jaxns/blob/master/examples/jones_scalar_model.py
example, is there no need to use jaxns.utils.resample
to get the samples for the corner plot?
R4) By the way, a
get_samples(random_key, num_samples)
method in the style of the Numpyro wrapper would be useful. Thanks
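A sketch of the R4 idea: a hypothetical `get_samples` helper (not part of JaxNS) that turns weighted nested-sampling output into `num_samples` equally weighted draws, in the style of the Numpyro wrapper. It assumes the samples come as a dict of arrays sharing a leading axis and that unnormalised log-weights are available.

```python
import jax
import jax.numpy as jnp


def get_samples(key, samples, log_weights, num_samples):
    """Hypothetical helper: equally weighted draws from weighted samples.

    samples: dict mapping parameter name -> array with leading axis of size N
    log_weights: array of shape (N,), unnormalised log-weights
    """
    # Draw indices with probability proportional to exp(log_weights), with replacement.
    idx = jax.random.categorical(key, log_weights, shape=(num_samples,))
    return {name: value[idx] for name, value in samples.items()}


key = jax.random.PRNGKey(0)
samples = {"a": jnp.array([0.0, 1.0, 2.0]), "b": jnp.array([10.0, 11.0, 12.0])}
log_w = jnp.log(jnp.array([0.2, 0.3, 0.5]))
draws = get_samples(key, samples, log_w, 1000)
print({name: value.shape for name, value in draws.items()})
```

Sampling with replacement keeps the helper simple; the number of distinct draws is still bounded by the ESS of the run.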