Xla parallel proxy #12
Conversation
Originally imported from the gemma repository https://github.com/google/gemma_pytorch.git, at commit cf8658c.
This allows executing models in parallel and interacting with them from the caller thread. To see how it works in output mode, you can launch the test with debug enabled like this: DEBUG=1 pytest -s tests/test_parallel_proxy.py
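As a rough illustration of the idea only (not the actual ModelProxy API introduced by this PR, whose names and internals differ), a caller can spawn a worker process that owns the model and talk to it through queues:

```python
import multiprocessing as mp

def _model_loop(cmd_queue, out_queue):
    # Hypothetical worker loop: in the real code this would run the model on
    # the TPU side; here it just echoes commands back to the caller.
    while True:
        command, data = cmd_queue.get()
        if command == "stop":
            break
        out_queue.put(f"handled {command!r} with {data!r}")

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    cmd_queue, out_queue = ctx.Queue(), ctx.Queue()
    worker = ctx.Process(target=_model_loop, args=(cmd_queue, out_queue))
    worker.start()
    # The caller thread interacts with the running model through the queues.
    cmd_queue.put(("generate", {"prompt": "Hello"}))
    print(out_queue.get())
    cmd_queue.put(("stop", None))
    worker.join()
```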
Force-pushed from 7317e4b to c8d7bbe.
Force-pushed from c8d7bbe to 747ad57.
optimum/tpu/xla_parallel_proxy.py
Outdated
xmp.spawn(_mp_fn, args=(args), join=True, daemon=False)
...
class ModelProxy:
I would rename this, it's not clear imo it handles the MP logic, wdyt? What about something like DistributedModel or DistributedTpuModel?
optimum/tpu/xla_parallel_proxy.py
Outdated
def send(self, command: ModelCommand, data: Dict = None):
    # First wait until model is ready to receive commands
    debug(f" MM Command {command} waiting for model to be ready")
Let's remove the DEBUG/debug reference and leverage Python logging.
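A minimal sketch of that suggestion, with the original type annotations dropped and the logger name chosen here only for illustration:

```python
import logging

logger = logging.getLogger(__name__)

def send(self, command, data=None):
    # Same trace as before, but routed through the standard logging module;
    # verbosity is then controlled with e.g.
    # logging.basicConfig(level=logging.DEBUG) instead of a DEBUG env variable.
    logger.debug("Command %s waiting for model to be ready", command)
    ...
```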
What does this PR do?
This PR does several things:
- Adds an example with gemma-2b that shows how inference works.
- Adds an implementation that allows launching and communicating with models running in parallel, using ModelProxy.
- Adds a test and wires it into a CI workflow (a rough sketch of what such a test could look like is shown below).
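The real tests/test_parallel_proxy.py in this PR will look different; this only sketches the round-trip shape such a test could take, reusing the toy queue-based worker idea from the illustration above rather than the actual ModelProxy API:

```python
import multiprocessing as mp

def _echo_loop(cmd_queue, out_queue):
    # Stand-in for the model worker: answers a single command and exits.
    command, data = cmd_queue.get()
    out_queue.put((command, data))

def test_parallel_proxy_roundtrip():
    ctx = mp.get_context("spawn")
    cmd_queue, out_queue = ctx.Queue(), ctx.Queue()
    worker = ctx.Process(target=_echo_loop, args=(cmd_queue, out_queue))
    worker.start()
    # Send a command from the caller and check the worker's reply.
    cmd_queue.put(("generate", {"prompt": "Hello"}))
    assert out_queue.get(timeout=30) == ("generate", {"prompt": "Hello"})
    worker.join()
```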
Did you write any new necessary tests?