Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make_sampler creates sampler chain with all sampling parameters #1330

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 43 additions & 24 deletions llms/mlx_lm/sample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,25 @@ def make_sampler(
"""
if temp == 0:
return lambda x: mx.argmax(x, axis=-1)
elif top_p > 0 and top_p < 1.0:
return lambda x: top_p_sampling(x, top_p, temp)
elif min_p != 0.0:
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
elif top_k > 0:
return lambda x: top_k_sampling(x, top_k, temp)
else:
return lambda x: categorical_sampling(x, temp)

# Create sampler chain
sampling_methods = []
if top_k > 0:
sampling_methods.append(lambda x: apply_top_k(x, top_k))
if top_p > 0 and top_p < 1.0:
sampling_methods.append(lambda x: apply_top_p(x, top_p))
if min_p != 0.0:
sampling_methods.append(lambda x: apply_min_p(x, min_p, min_tokens_to_keep))

# Apply the sampling methods
def sampler(logits):
for method in sampling_methods:
logits = method(logits)

# Return the sampled token
return categorical_sampling(logits, temp)

return sampler


def make_logits_processors(
Expand Down Expand Up @@ -85,10 +96,9 @@ def logit_bias_processor(_, logits):


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_k_sampling(
def apply_top_k(
logprobs: mx.array,
top_k: int,
temperature=1.0,
) -> mx.array:
"""
Sample from only the top K tokens ranked by probability.
Expand All @@ -103,20 +113,18 @@ def top_k_sampling(
f"`top_k` has to be an integer in the (0, {vocab_size}] interval,"
f" but is {top_k}."
)
logprobs = logprobs * (1 / temperature)
mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:]
masked_logprobs = mx.put_along_axis(
logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1
)
return mx.random.categorical(masked_logprobs, axis=-1)
return masked_logprobs


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
def apply_min_p(
logprobs: mx.array,
min_p: float,
min_tokens_to_keep: int = 1,
temperature=1.0,
) -> mx.array:
"""
Apply min-p sampling to the logprobs.
Expand Down Expand Up @@ -144,8 +152,6 @@ def min_p_sampling(
)
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605

logprobs = logprobs * (1 / temperature)

# Indices sorted in decreasing order
sorted_indices = mx.argsort(-logprobs, axis=-1)
sorted_logprobs = mx.take_along_axis(logprobs, sorted_indices, axis=-1)
Expand All @@ -163,25 +169,31 @@ def min_p_sampling(
# Create pool of tokens with probability less than scaled min_p
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)

# Return sampled tokens
sorted_tokens = mx.random.categorical(selected_logprobs, axis=-1)[:, None]
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
# Create a mapping to rearrange back to original indices
# Use argsort of sorted_indices to get the inverse permutation
inverse_indices = mx.argsort(sorted_indices, axis=-1)

# Rearrange selected_logprobs back to original order
original_order_logprobs = mx.take_along_axis(
selected_logprobs, inverse_indices, axis=-1
)

return original_order_logprobs


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
def apply_top_p(logits: mx.array, top_p: float) -> mx.array:
"""
Apply top-p (nucleus) sampling to logits.
Args:
logits: The logits from the model's output.
top_p: The cumulative probability threshold for top-p filtering.
temperature: Temperature parameter for softmax distribution reshaping.
Returns:
token selected based on the top-p criterion.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits * (1 / temperature), axis=-1)
probs = mx.softmax(logits, axis=-1)

# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
Expand All @@ -196,8 +208,15 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
0,
)

sorted_tokens = mx.random.categorical(mx.log(top_probs), axis=-1)[:, None]
return mx.take_along_axis(sorted_indices, sorted_tokens, axis=-1).squeeze(1)
# Create a mapping to rearrange back to original indices
# Use argsort of sorted_indices to get the inverse permutation
inverse_indices = mx.argsort(sorted_indices, axis=-1)

# Rearrange top_probs back to original order
original_order_probs = mx.take_along_axis(top_probs, inverse_indices, axis=-1)

# Convert back to logits and return
return mx.log(mx.where(original_order_probs > 0, original_order_probs, 0))


@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
Expand Down
98 changes: 58 additions & 40 deletions llms/tests/test_sample_utils.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,97 @@
import unittest

import mlx.core as mx
from mlx_lm.sample_utils import min_p_sampling, top_k_sampling, top_p_sampling
from mlx_lm.sample_utils import apply_min_p, apply_top_k, apply_top_p


class TestSampleUtils(unittest.TestCase):
def test_top_p_sampling(self):
def test_apply_top_p(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
temperature = 1.0

token = top_p_sampling(logits, 0.3, temperature).item()
self.assertEqual(token, 0)
new_logits = apply_top_p(logits, 0.3)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])

token = top_p_sampling(logits, 0.95, temperature).item()
self.assertTrue(token in (0, 3))
new_logits = apply_top_p(logits, 0.95)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(probs.squeeze().tolist(), actual_probs.tolist())

probs = mx.array([0.0, 0.5, 0.4, 0.1])[None]
logits = mx.log(probs)

token = top_p_sampling(logits, 0.4, temperature).item()
self.assertEqual(token, 1)

token = top_p_sampling(logits, 0.6, temperature).item()
self.assertTrue(token in (1, 2))

token = top_p_sampling(logits, 0.95, temperature).item()
self.assertTrue(token in (1, 2, 3))
new_logits = apply_top_p(logits, 0.4)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(actual_probs.tolist(), [0.0, 1.0, 0.0, 0.0])

new_logits = apply_top_p(logits, 0.6)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(
[round(p, 4) for p in actual_probs.tolist()], [0.0, 0.5556, 0.4444, 0.0]
)

new_logits = apply_top_p(logits, 0.95)
actual_probs = mx.softmax(new_logits.squeeze())
actual_rounded = [round(p, 4) for p in actual_probs.tolist()]
expected_rounded = [0.0, 0.5, 0.4, 0.1]
self.assertEqual(actual_rounded, expected_rounded)
self.assertAlmostEqual(sum(actual_probs.tolist()), 1.0)

# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.1, 0.1]])
logits = mx.log(probs)
tokens = top_p_sampling(logits, 0.5, temperature)
self.assertEqual(tokens.tolist(), [0, 1])
new_logits = apply_top_p(logits, 0.5)
actual_probs = mx.softmax(new_logits, axis=-1)
self.assertEqual(
actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
)

def test_min_p_sampling(self):
def test_apply_min_p(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
temperature = 1.0
token = min_p_sampling(logits, 0.8)
self.assertEqual(token, 0)
new_logits = apply_min_p(logits, 0.8)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])

probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)
temperature = 1.0
for _ in range(5):
token = min_p_sampling(logits, 0.05)
self.assertTrue(token in (0, 3))
new_logits = apply_min_p(logits, 0.05)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(actual_probs.tolist(), mx.squeeze(probs).tolist())

# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)
tokens = min_p_sampling(logits, 0.7)
self.assertEqual(tokens.tolist(), [0, 1])
new_logits = apply_min_p(logits, 0.7)
actual_probs = mx.softmax(new_logits, axis=-1)
self.assertEqual(
actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
)

def test_top_k_sampling(self):
def test_apply_top_k(self):
probs = mx.array([0.9, 0.0, 0.0, 0.1])[None]
logits = mx.log(probs)

token = top_k_sampling(logits, 1).item()
self.assertEqual(token, 0)
new_logits = apply_top_k(logits, 1)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(actual_probs.tolist(), [1.0, 0.0, 0.0, 0.0])

probs = mx.array([0.5, 0.0, 0.0, 0.5])[None]
tokens = set()
for _ in range(100):
token = top_k_sampling(logits, 2)
tokens.add(token.item())
self.assertEqual(tokens, {0, 3})
probs = mx.array([0.6, 0.0, 0.1, 0.3])[None]
logits = mx.log(probs)
new_logits = apply_top_k(logits, 2)
actual_probs = mx.softmax(new_logits.squeeze())
self.assertEqual(
[round(p, 4) for p in actual_probs.tolist()], [0.6667, 0.0, 0.0, 0.3333]
)

# Batch mode works
probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]])
logits = mx.log(probs)

tokens = top_k_sampling(logits, 1)
self.assertEqual(tokens.tolist(), [0, 1])
new_logits = apply_top_k(logits, 1)
actual_probs = mx.softmax(new_logits, axis=-1)
self.assertEqual(
actual_probs.tolist(), [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
)


if __name__ == "__main__":
Expand Down