Xla parallel proxy #12

Merged
merged 10 commits into main from xla-parallel-proxy on Apr 8, 2024

Conversation

tengomucho (Collaborator) commented Apr 5, 2024

What does this PR do?

This PR does several things:

  • adds an example with gemma-2b that shows how inference works

  • adds an implementation, ModelProxy, that allows launching and communicating with models running in parallel (a sketch of the pattern follows below)

  • adds a test and wires it into a CI workflow.


Originally imported from the gemma repository
https://github.com/google/gemma_pytorch.git
at commit cf8658c.
This allows executing models in parallel and interacting with them from
the caller thread.
To see its debug output, you can launch the test with debugging enabled
this way:

DEBUG=1 pytest -s tests/test_parallel_proxy.py
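
For context, the proxy follows a command/response pattern: worker processes (in this PR, spawned via xmp.spawn) wait for commands, and the caller thread sends a command and reads back the result. Below is a minimal, self-contained sketch of that pattern using plain multiprocessing queues; the command names, the echo worker, and the stop() method are illustrative stand-ins, not the PR's actual API.

import multiprocessing as mp
from enum import Enum
from typing import Dict, Optional


class ModelCommand(Enum):
    PREFILL = 1
    DECODE = 2
    STOP = 3


def _worker(command_queue: mp.Queue, result_queue: mp.Queue):
    # In the PR this loop would drive the actual model; here it just echoes.
    while True:
        command, data = command_queue.get()
        if command == ModelCommand.STOP:
            break
        result_queue.put({"command": command.name, "data": data})


class ModelProxy:
    """Launches a worker process and forwards commands to it from the caller thread."""

    def __init__(self):
        self._commands = mp.Queue()
        self._results = mp.Queue()
        self._process = mp.Process(target=_worker, args=(self._commands, self._results))
        self._process.start()

    def send(self, command: ModelCommand, data: Optional[Dict] = None) -> Dict:
        # Post the command and block until the worker answers.
        self._commands.put((command, data or {}))
        return self._results.get()

    def stop(self):
        self._commands.put((ModelCommand.STOP, None))
        self._process.join()


if __name__ == "__main__":
    proxy = ModelProxy()
    print(proxy.send(ModelCommand.PREFILL, {"input_ids": [1, 2, 3]}))
    proxy.stop()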
@tengomucho tengomucho force-pushed the xla-parallel-proxy branch from 7317e4b to c8d7bbe on April 5, 2024 09:52
@tengomucho tengomucho force-pushed the xla-parallel-proxy branch from c8d7bbe to 747ad57 on April 5, 2024 10:19
@tengomucho tengomucho marked this pull request as ready for review April 5, 2024 10:29
@tengomucho tengomucho requested a review from mfuntowicz April 5, 2024 10:29
xmp.spawn(_mp_fn, args=(args), join=True, daemon=False)


class ModelProxy:
Member


I would rename this, it's not clear imo it handles the MP logic, wdyt?

What about something like DistributedModel or DistributedTpuModel?


def send(self, command: ModelCommand, data: Dict = None):
    # First wait until model is ready to receive commands
    debug(f" MM Command {command} waiting for model to be ready")
Member


Let's remove the DEBUG/debug reference and leverage python logging
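
For illustration, a minimal sketch of what that swap could look like, assuming a module-level logger simply replaces the DEBUG-gated debug() helper; the call site is taken from the snippet above, everything else is illustrative.

import logging

# A module-level logger replaces the custom DEBUG-gated debug() helper.
logger = logging.getLogger(__name__)


class ModelProxy:
    def send(self, command, data=None):
        # Verbosity is now controlled through standard logging configuration,
        # e.g. logging.basicConfig(level=logging.DEBUG), instead of DEBUG=1.
        logger.debug("Command %s waiting for model to be ready", command)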

Additional review comments were resolved on:

  • .github/workflows/test-pytorch-xla-tpu.yml (outdated)
  • optimum/tpu/xla_parallel_proxy.py (outdated)
  • pyproject.toml
  • tests/test_parallel_proxy.py
@tengomucho tengomucho requested a review from mfuntowicz April 8, 2024 08:08
@tengomucho tengomucho merged commit 7b48145 into main Apr 8, 2024
2 checks passed
@mfuntowicz mfuntowicz deleted the xla-parallel-proxy branch April 8, 2024 08:29