feat: GraphTransformerProcessor chunking #66
base: main
Conversation
Hi Jan, thank you for adding this, very nice. Looking at the code, it seems that some parts do the same thing but look slightly different. I was wondering if this would also be a good opportunity to reduce code duplication between GraphTransformerProcessorBlock and GraphTransformerMapperBlock, maybe encapsulated in a common routine? I think the differences are the dimensions of x_skip.
I moved the common "attention" part to the GraphTransformerBaseBlock.
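For reference, a minimal sketch of what such a shared routine could look like. GraphTransformerBaseBlock is the class named in the comment above; everything else (the attend method, its signature, and the use of plain nn.MultiheadAttention as a stand-in for the edge-aware graph attention) is illustrative, not the actual anemoi-models code.

```python
import torch
from torch import nn


class GraphTransformerBaseBlock(nn.Module):
    """Sketch of a base block holding the shared "attention" part.

    A plain nn.MultiheadAttention stands in for the real edge-aware
    multi-head graph attention to keep the sketch short.
    """

    def __init__(self, num_channels: int, num_heads: int) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(num_channels, num_heads, batch_first=True)

    def attend(self, query: torch.Tensor, key_value: torch.Tensor) -> torch.Tensor:
        # Shared routine: both subclasses call this instead of
        # duplicating the attention computation.
        out, _ = self.attention(query, key_value, key_value)
        return out


class GraphTransformerProcessorBlock(GraphTransformerBaseBlock):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Processor: self-attention, so x_skip is the single input tensor.
        x_skip = x
        return x_skip + self.attend(x, x)


class GraphTransformerMapperBlock(GraphTransformerBaseBlock):
    def forward(self, x: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        # Mapper: x is a (source, destination) pair, so x_skip has a
        # different dimension -- the difference the comment above notes.
        x_src, x_dst = x
        x_skip = x_dst
        return x_skip + self.attend(x_dst, x_src)
```

Keeping only the x_skip handling in the subclasses is one way to make the dimensionality difference explicit while the attention itself lives in one place.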
What's the connection between this PR and the "shard everything" one?
There is none; I can make it ready for merge.
Describe your changes
This PR adds chunking for the GraphTransformerProcessorBlock to reduce memory usage during inference. The functionality is equivalent to the GraphTransformerMapperBlock chunking and uses the same environment variable, ANEMOI_INFERENCE_NUM_CHUNKS, to control the chunking behaviour; see the sketch below.
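A minimal sketch of the chunking behaviour under these assumptions: ANEMOI_INFERENCE_NUM_CHUNKS comes from the PR description, while run_chunked and block_fn are hypothetical names used only for illustration, not the PR's actual API.

```python
import os

import torch


def run_chunked(block_fn, x: torch.Tensor) -> torch.Tensor:
    """Apply `block_fn` to `x` in chunks along the node dimension.

    Hypothetical helper: reading ANEMOI_INFERENCE_NUM_CHUNKS mirrors the
    mapper-block behaviour described in the PR. One chunk (the default)
    means no chunking; larger values lower peak memory at inference time
    by keeping only one chunk's intermediate activations alive at once.
    """
    num_chunks = int(os.environ.get("ANEMOI_INFERENCE_NUM_CHUNKS", "1"))
    if num_chunks <= 1:
        return block_fn(x)
    # Process each slice of nodes independently and re-concatenate.
    chunks = torch.tensor_split(x, num_chunks, dim=0)
    return torch.cat([block_fn(chunk) for chunk in chunks], dim=0)
```

Setting, e.g., ANEMOI_INFERENCE_NUM_CHUNKS=4 in the environment before running inference would then split the computation into four sequential chunks, bounding the peak activation memory at roughly a quarter of the unchunked case.

Type of change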
Checklist before requesting a review
Tag possible reviewers
@ssmmnn11 @gabrieloks