Skip to content

Commit

Permalink
Adding some doc
Browse files Browse the repository at this point in the history
  • Loading branch information
alonkukl committed Feb 15, 2024
1 parent c3b224c commit 256f863
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions qfactorjax/qfactor_sample_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,12 @@ def safe_call_jited_vmaped_state_sample_sweep(
untrys: Array,
training_states_kets: Array,
) -> tuple[Array, Array[float], Array[float], Array[int], Array[bool]]:
"""
We couldn't find a way to check if we are going to allocate more than
the GPU memory, so we created this "safe" function that calls
qfactor-sample and then if OOM exception is caught it recursively
calls qfactor-sample with half the multistarts
"""

try:
results = _jited_loop_vmaped_state_sample_sweep(
Expand Down

0 comments on commit 256f863

Please sign in to comment.