From 58e912966ac86e6b09b82ec705f7fb4418172f97 Mon Sep 17 00:00:00 2001
From: Neil Mehta
Date: Sat, 8 Mar 2025 08:55:49 -0500
Subject: [PATCH 1/3] top_p refactor

---
 llms/mlx_lm/sample_utils.py     | 16 ++++++++----
 llms/tests/test_sample_utils.py | 45 ++++++++++++++++++++-------------
 2 files changed, 39 insertions(+), 22 deletions(-)

diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index 23e08d97a..d7049f7dc 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -169,19 +169,18 @@ def min_p_sampling(
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
+def top_p_sampling(logits: mx.array, top_p: float) -> mx.array:
     """
     Apply top-p (nucleus) sampling to logits.
 
     Args:
         logits: The logits from the model's output.
         top_p: The cumulative probability threshold for top-p filtering.
-        temperature: Temperature parameter for softmax distribution reshaping.
     Returns:
         token selected based on the top-p criterion.
     """
     # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
-    probs = mx.softmax(logits * (1 / temperature), axis=-1)
+    probs = mx.softmax(logits, axis=-1)
 
     # sort probs in ascending order
     sorted_indices = mx.argsort(probs, axis=-1)
@@ -196,8 +195,15 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
         0,
     )
 
-    sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
-    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
+    # Create a mapping to rearrange back to original indices
+    # Use argsort of sorted_indices to get the inverse permutation
+    inverse_indices = mx.argsort(sorted_indices, axis=-1)
+
+    # Rearrange top_probs back to original order
+    original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1)
+
+    # Convert back to logits and return
+    return mx.log(mx.where(original_order_probs > 0, original_order_probs, 0))
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py
index f12abbf48..5a3d8847a 100644
--- a/llms/tests/test_sample_utils.py
+++ b/llms/tests/test_sample_utils.py
@@ -8,31 +8,42 @@ class TestSampleUtils(unittest.TestCase):
     def test_top_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        temperature = 1.0
-        token = top_p_sampling(logits, 0.3, temperature).item()
-        self.assertEqual(token, 0)
+        actual_logits = top_p_sampling(logits, 0.3)
+        actual_probs = mx.softmax(actual_logits.squeeze())
+        self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])
 
-        token = top_p_sampling(logits, 0.95, temperature).item()
-        self.assertTrue(token in (0, 3))
+        actual_logits = top_p_sampling(logits, 0.95)
+        actual_probs = mx.softmax(actual_logits.squeeze())
+        self.assertEqual(probs.squeeze().tolist(), actual_probs.tolist())
 
         probs = mx.array([0.0, 0.5, 0.4, 0.1])[None]
         logits = mx.log(probs)
-
-        token = top_p_sampling(logits, 0.4, temperature).item()
-        self.assertEqual(token, 1)
-
-        token = top_p_sampling(logits, 0.6, temperature).item()
-        self.assertTrue(token in (1, 2))
-
-        token = top_p_sampling(logits, 0.95, temperature).item()
-        self.assertTrue(token in (1, 2, 3))
+        actual_logits = top_p_sampling(logits, 0.4)
+        actual_probs = mx.softmax(actual_logits.squeeze())
+        self.assertEqual(actual_probs.tolist(), [0.0, 1.0, 0.0, 0.0])
+
+        actual_logits = top_p_sampling(logits, 0.6)
+        actual_probs = mx.softmax(actual_logits.squeeze())
+        self.assertEqual(
+            [round(p, 4) for p in actual_probs.tolist()], [0.0, 0.5556, 0.4444, 0.0]
+        )
+
+        actual_logits = top_p_sampling(logits, 0.95)
+        actual_probs = mx.softmax(actual_logits.squeeze())
+        actual_rounded = [round(p, 4) for p in actual_probs.tolist()]
+        expected_rounded = [0.0, 0.5, 0.4, 0.1]
+        self.assertEqual(actual_rounded, expected_rounded)
+        self.assertAlmostEqual(sum(actual_probs.tolist()), 1.0)
 
         # Batch mode works
-        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
+        probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.1, 0.1]])
         logits = mx.log(probs)
-        tokens = top_p_sampling(logits, 0.5, temperature)
-        self.assertEqual(tokens.tolist(), [0, 1])
+        actual_logits = top_p_sampling(logits, 0.5)
+        actual_probs = mx.softmax(actual_logits, axis=-1)
+        self.assertEqual(
+            actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
+        )
 
     def test_min_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
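After this patch, top_p_sampling no longer draws a token; it returns the logits with everything outside the nucleus pushed to -inf, and the caller samples from the result. A minimal sketch of the new calling pattern (not part of the patch; it assumes mlx and mlx_lm are installed and reflects the tree after PATCH 1/3):

    import mlx.core as mx
    from mlx_lm.sample_utils import top_p_sampling

    # Toy batch of one distribution over a 4-token vocabulary.
    logits = mx.log(mx.array([0.0, 0.5, 0.4, 0.1])[None])

    # Tokens outside the 0.6 nucleus come back with -inf logits.
    filtered = top_p_sampling(logits, 0.6)

    # Sampling is now the caller's job, e.g. a plain categorical draw.
    token = mx.random.categorical(filtered, axis=-1)
    print(token.item())
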
From 932b7c0510476b320b30352998eaa6acfcb9068e Mon Sep 17 00:00:00 2001
From: Neil Mehta
Date: Sat, 8 Mar 2025 10:12:28 -0500
Subject: [PATCH 2/3] top_k and min_p refactor

---
 llms/mlx_lm/sample_utils.py     | 20 +++++-----
 llms/tests/test_sample_utils.py | 69 ++++++++++++++++++---------------
 2 files changed, 49 insertions(+), 40 deletions(-)

diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index d7049f7dc..5ad3d2c54 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -88,7 +88,6 @@ def logit_bias_processor(_, logits):
 def top_k_sampling(
     logprobs: mx.array,
     top_k: int,
-    temperature=1.0,
 ) -> mx.array:
     """
     Sample from only the top K tokens ranked by probability.
@@ -103,12 +102,11 @@ def top_k_sampling(
             f"`top_k` has to be an integer in the (0, {vocab_size}] interval,"
             f" but is {top_k}."
         )
-    logprobs = logprobs * (1 / temperature)
     mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:]
     masked_logprobs = mx.put_along_axis(
         logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1
     )
-    return mx.random.categorical(masked_logprobs, axis=-1)
+    return masked_logprobs
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
@@ -116,7 +114,6 @@ def min_p_sampling(
     logprobs: mx.array,
     min_p: float,
     min_tokens_to_keep: int = 1,
-    temperature=1.0,
 ) -> mx.array:
     """
     Apply min-p sampling to the logprobs.
@@ -144,8 +141,6 @@ def min_p_sampling(
         )
 
     # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
-    logprobs = logprobs * (1 / temperature)
-
     # Indices sorted in decreasing order
     sorted_indices = mx.argsort(-logprobs, axis=-1)
     sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
@@ -163,9 +158,16 @@ def min_p_sampling(
     # Create pool of tokens with probability less than scaled min_p
     selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)
 
-    # Return sampled tokens
-    sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
-    return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
+    # Create a mapping to rearrange back to original indices
+    # Use argsort of sorted_indices to get the inverse permutation
+    inverse_indices = mx.argsort(sorted_indices, axis=-1)
+
+    # Rearrange selected_logprobs back to original order
+    original_order_logprobs = mx.take_along_axis(
+        selected_logprobs, inverse_indices, axis=-1
+    )
+
+    return original_order_logprobs
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py
index 5a3d8847a..19b65e4f5 100644
--- a/llms/tests/test_sample_utils.py
+++ b/llms/tests/test_sample_utils.py
@@ -9,28 +9,28 @@ def test_top_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        actual_logits = top_p_sampling(logits, 0.3)
-        actual_probs = mx.softmax(actual_logits.squeeze())
+        new_logits = top_p_sampling(logits, 0.3)
+        actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])
 
-        actual_logits = top_p_sampling(logits, 0.95)
-        actual_probs = mx.softmax(actual_logits.squeeze())
+        new_logits = top_p_sampling(logits, 0.95)
+        actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(probs.squeeze().tolist(), actual_probs.tolist())
 
         probs = mx.array([0.0, 0.5, 0.4, 0.1])[None]
         logits = mx.log(probs)
-        actual_logits = top_p_sampling(logits, 0.4)
-        actual_probs = mx.softmax(actual_logits.squeeze())
+        new_logits = top_p_sampling(logits, 0.4)
+        actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(actual_probs.tolist(), [0.0, 1.0, 0.0, 0.0])
 
-        actual_logits = top_p_sampling(logits, 0.6)
-        actual_probs = mx.softmax(actual_logits.squeeze())
+        new_logits = top_p_sampling(logits, 0.6)
+        actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(
             [round(p, 4) for p in actual_probs.tolist()], [0.0, 0.5556, 0.4444, 0.0]
         )
 
-        actual_logits = top_p_sampling(logits, 0.95)
-        actual_probs = mx.softmax(actual_logits.squeeze())
+        new_logits = top_p_sampling(logits, 0.95)
+        actual_probs = mx.softmax(new_logits.squeeze())
         actual_rounded = [round(p, 4) for p in actual_probs.tolist()]
         expected_rounded = [0.0, 0.5, 0.4, 0.1]
         self.assertEqual(actual_rounded, expected_rounded)
@@ -39,8 +39,8 @@ def test_top_p_sampling(self):
         # Batch mode works
         probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.1, 0.1]])
         logits = mx.log(probs)
-        actual_logits = top_p_sampling(logits, 0.5)
-        actual_probs = mx.softmax(actual_logits, axis=-1)
+        new_logits = top_p_sampling(logits, 0.5)
+        actual_probs = mx.softmax(new_logits, axis=-1)
         self.assertEqual(
             actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
         )
@@ -48,43 +48,50 @@ def test_top_p_sampling(self):
     def test_min_p_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        temperature = 1.0
-        token = min_p_sampling(logits, 0.8)
-        self.assertEqual(token, 0)
+        new_logits = min_p_sampling(logits, 0.8)
+        actual_probs = mx.softmax(new_logits.squeeze())
+        self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])
 
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        temperature = 1.0
-        for _ in range(5):
-            token = min_p_sampling(logits, 0.05)
-            self.assertTrue(token in (0, 3))
+        new_logits = min_p_sampling(logits, 0.05)
+        actual_probs = mx.softmax(new_logits.squeeze())
+        self.assertEqual(actual_probs.tolist(), mx.squeeze(probs).tolist())
 
         # Batch mode works
         probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
         logits = mx.log(probs)
-        tokens = min_p_sampling(logits, 0.7)
-        self.assertEqual(tokens.tolist(), [0, 1])
+        new_logits = min_p_sampling(logits, 0.7)
+        actual_probs = mx.softmax(new_logits, axis=-1)
+        self.assertEqual(
+            actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
+        )
 
     def test_top_k_sampling(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        token = top_k_sampling(logits, 1).item()
-        self.assertEqual(token, 0)
+        new_logits = top_k_sampling(logits, 1)
+        actual_probs = mx.softmax(new_logits.squeeze())
+        self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])
 
-        probs = mx.array([0.5, 0.0, 0.0, 0.5])[None]
-        tokens = set()
-        for _ in range(100):
-            token = top_k_sampling(logits, 2)
-            tokens.add(token.item())
-        self.assertEqual(tokens, {0, 3})
+        probs = mx.array([0.6, 0.0, 0.1, 0.3])[None]
+        logits = mx.log(probs)
+        new_logits = top_k_sampling(logits, 2)
+        actual_probs = mx.softmax(new_logits.squeeze())
+        self.assertEqual(
+            [round(p, 4) for p in actual_probs.tolist()], [0.6667, 0.0, 0.0, 0.3333]
+        )
 
         # Batch mode works
         probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
         logits = mx.log(probs)
-        tokens = top_k_sampling(logits, 1)
-        self.assertEqual(tokens.tolist(), [0, 1])
+        new_logits = top_k_sampling(logits, 1)
+        actual_probs = mx.softmax(new_logits, axis=-1)
+        self.assertEqual(
+            actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
+        )
 
 
 if __name__ == "__main__":
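With this patch, top_k_sampling and min_p_sampling likewise return masked logits instead of drawing tokens, so the filters can be composed before a single draw. A small sketch of that composition (not part of the patch; it assumes mlx and mlx_lm are installed and reflects the tree after PATCH 2/3, i.e. before the renames in the next patch):

    import mlx.core as mx
    from mlx_lm.sample_utils import min_p_sampling, top_k_sampling

    logits = mx.log(mx.array([0.5, 0.3, 0.15, 0.05])[None])

    # Keep only the 3 highest-probability tokens ...
    filtered = top_k_sampling(logits, 3)
    # ... then drop tokens below 20% of the most likely remaining token.
    filtered = min_p_sampling(filtered, 0.2)

    # One categorical draw at the very end, as make_sampler will do in PATCH 3/3.
    token = mx.random.categorical(filtered, axis=-1)
    print(token.item())
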
From 956da0ddc7485cab97a01487fcc04c5174a5e3b7 Mon Sep 17 00:00:00 2001
From: Neil Mehta
Date: Sat, 8 Mar 2025 14:08:55 -0500
Subject: [PATCH 3/3] Create sampler chain

---
 llms/mlx_lm/sample_utils.py     | 33 ++++++++++++++++++++-----------
 llms/tests/test_sample_utils.py | 32 ++++++++++++++++----------------
 2 files changed, 38 insertions(+), 27 deletions(-)

diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index 5ad3d2c54..d62c7f755 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -35,14 +35,25 @@ def make_sampler(
     """
     if temp == 0:
         return lambda x: mx.argmax(x, axis=-1)
-    elif top_p > 0 and top_p < 1.0:
-        return lambda x: top_p_sampling(x, top_p, temp)
-    elif min_p != 0.0:
-        return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
-    elif top_k > 0:
-        return lambda x: top_k_sampling(x, top_k, temp)
-    else:
-        return lambda x: categorical_sampling(x, temp)
+
+    # Create sampler chain
+    sampling_methods = []
+    if top_k > 0:
+        sampling_methods.append(lambda x: apply_top_k(x, top_k))
+    if top_p > 0 and top_p < 1.0:
+        sampling_methods.append(lambda x: apply_top_p(x, top_p))
+    if min_p != 0.0:
+        sampling_methods.append(lambda x: apply_min_p(x, min_p, min_tokens_to_keep))
+
+    # Apply the sampling methods
+    def sampler(logits):
+        for method in sampling_methods:
+            logits = method(logits)
+
+        # Return the sampled token
+        return categorical_sampling(logits, temp)
+
+    return sampler
 
 
 def make_logits_processors(
@@ -85,7 +96,7 @@ def logit_bias_processor(_, logits):
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def top_k_sampling(
+def apply_top_k(
     logprobs: mx.array,
     top_k: int,
 ) -> mx.array:
@@ -110,7 +121,7 @@
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def min_p_sampling(
+def apply_min_p(
     logprobs: mx.array,
     min_p: float,
     min_tokens_to_keep: int = 1,
@@ -171,7 +182,7 @@
 
 
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
-def top_p_sampling(logits: mx.array, top_p: float) -> mx.array:
+def apply_top_p(logits: mx.array, top_p: float) -> mx.array:
     """
     Apply top-p (nucleus) sampling to logits.
 
diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py
index 19b65e4f5..a8664fd9a 100644
--- a/llms/tests/test_sample_utils.py
+++ b/llms/tests/test_sample_utils.py
@@ -1,35 +1,35 @@
 import unittest
 
 import mlx.core as mx
-from mlx_lm.sample_utils import min_p_sampling, top_k_sampling, top_p_sampling
+from mlx_lm.sample_utils import apply_min_p, apply_top_k, apply_top_p
 
 
 class TestSampleUtils(unittest.TestCase):
-    def test_top_p_sampling(self):
+    def test_apply_top_p(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        new_logits = top_p_sampling(logits, 0.3)
+        new_logits = apply_top_p(logits, 0.3)
         actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])
 
-        new_logits = top_p_sampling(logits, 0.95)
+        new_logits = apply_top_p(logits, 0.95)
         actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(probs.squeeze().tolist(), actual_probs.tolist())
 
         probs = mx.array([0.0, 0.5, 0.4, 0.1])[None]
         logits = mx.log(probs)
-        new_logits = top_p_sampling(logits, 0.4)
+        new_logits = apply_top_p(logits, 0.4)
         actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(actual_probs.tolist(), [0.0, 1.0, 0.0, 0.0])
 
-        new_logits = top_p_sampling(logits, 0.6)
+        new_logits = apply_top_p(logits, 0.6)
         actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(
             [round(p, 4) for p in actual_probs.tolist()], [0.0, 0.5556, 0.4444, 0.0]
         )
 
-        new_logits = top_p_sampling(logits, 0.95)
+        new_logits = apply_top_p(logits, 0.95)
         actual_probs = mx.softmax(new_logits.squeeze())
         actual_rounded = [round(p, 4) for p in actual_probs.tolist()]
         expected_rounded = [0.0, 0.5, 0.4, 0.1]
         self.assertEqual(actual_rounded, expected_rounded)
@@ -39,45 +39,45 @@ def test_top_p_sampling(self):
         # Batch mode works
         probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.1, 0.1]])
         logits = mx.log(probs)
-        new_logits = top_p_sampling(logits, 0.5)
+        new_logits = apply_top_p(logits, 0.5)
         actual_probs = mx.softmax(new_logits, axis=-1)
         self.assertEqual(
             actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
         )
 
-    def test_min_p_sampling(self):
+    def test_apply_min_p(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        new_logits = min_p_sampling(logits, 0.8)
+        new_logits = apply_min_p(logits, 0.8)
         actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])
 
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        new_logits = min_p_sampling(logits, 0.05)
+        new_logits = apply_min_p(logits, 0.05)
         actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(actual_probs.tolist(), mx.squeeze(probs).tolist())
 
         # Batch mode works
         probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
         logits = mx.log(probs)
-        new_logits = min_p_sampling(logits, 0.7)
+        new_logits = apply_min_p(logits, 0.7)
         actual_probs = mx.softmax(new_logits, axis=-1)
         self.assertEqual(
             actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
         )
 
-    def test_top_k_sampling(self):
+    def test_apply_top_k(self):
         probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
         logits = mx.log(probs)
-        new_logits = top_k_sampling(logits, 1)
+        new_logits = apply_top_k(logits, 1)
         actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])
 
         probs = mx.array([0.6, 0.0, 0.1, 0.3])[None]
         logits = mx.log(probs)
-        new_logits = top_k_sampling(logits, 2)
+        new_logits = apply_top_k(logits, 2)
         actual_probs = mx.softmax(new_logits.squeeze())
         self.assertEqual(
             [round(p, 4) for p in actual_probs.tolist()], [0.6667, 0.0, 0.0, 0.3333]
         )
@@ -87,7 +87,7 @@ def test_top_k_sampling(self):
         probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
         logits = mx.log(probs)
-        new_logits = top_k_sampling(logits, 1)
+        new_logits = apply_top_k(logits, 1)
         actual_probs = mx.softmax(new_logits, axis=-1)
         self.assertEqual(
             actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
         )
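With the chain in place, make_sampler can combine several filters and still perform a single categorical draw at the end. A minimal usage sketch (not part of the patch; it assumes mlx and mlx_lm are installed and that make_sampler exposes temp, top_k, top_p and min_p keyword parameters, as the hunk above suggests):

    import mlx.core as mx
    from mlx_lm.sample_utils import make_sampler

    # top-k is applied first, then top-p, then min-p; the token is drawn once.
    sampler = make_sampler(temp=0.8, top_k=40, top_p=0.9)

    logits = mx.random.normal((1, 64))  # stand-in for one step of model logits
    token = sampler(logits)
    print(token.item())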