diff --git a/haiku/_src/base.py b/haiku/_src/base.py index a14e15ad8..58f339e3c 100644 --- a/haiku/_src/base.py +++ b/haiku/_src/base.py @@ -723,4 +723,5 @@ def with_rng(key: PRNGKey): Returns: Context manager under which the given sequence is active. """ + assert_context("with_rng") return current_frame().rng_stack(PRNGSequence(key)) diff --git a/haiku/_src/base_test.py b/haiku/_src/base_test.py index 7e90bd3e0..03578d467 100644 --- a/haiku/_src/base_test.py +++ b/haiku/_src/base_test.py @@ -329,6 +329,12 @@ def test_with_rng(self, seed): self.assertNotEqual(without_decorator_out, expected_output) self.assertEqual(with_decorator_out, expected_output) + def test_with_rng_no_transform(self): + with self.assertRaisesRegex(ValueError, + "must be used as part of an `hk.transform`"): + with base.with_rng(jax.random.PRNGKey(428)): + pass + def test_new_context(self): with base.new_context() as ctx: pass