[JAX] Custom Op Workspace Tensors from XLA Buffers #532
Conversation
@mingxu1067 to review.
out_types = [
    ir.RankedTensorType.get(out_shape, output_type),
    ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
    ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
    ir.RankedTensorType.get(work_meta[0], te_dtype_to_ir_dtype(work_meta[1])),
These changes make the number of outputs from `lowering` different from `abstract`. I assume there would be an XLA assertion; could you test this to see if there are any potential issues?
Besides, these workspace tensors are never used, so they might be removed during XLA compilation. We might have to test this as well.
The workspace tensors I added appear to be used in the C++ custom ops and NVTE kernels they call. They were being allocated with cudaMalloc in WorkspaceManager, which is what we want to replace with this PR. Did I misread/misunderstand something about this?
It is entirely possible I unintentionally skipped over additional changes needed in the abstract (or elsewhere) to make this work correctly. Do we need dummy outputs in the abstract to make sure we don't lose the workspaces to XLA optimization?
Yes, the purpose of this PR is to remove `cudaMalloc` in `WorkspaceManager`.
I am just thinking there might be two issues:
- The mismatching number of outputs between `lowering` and `abstract` might trigger some assertion. If so, we would have to separate the `abstract` for `inner_p` and `outer_p`, then add dummy outputs to the `abstract` of `inner_p` (see the sketch after this list).
- Unused workspace tensors might be removed by one of XLA's memory optimization passes. On further thought, we may not need to worry about this, since XLA has little knowledge of custom calls and therefore usually skips those optimizations for them.
- If JAX doesn't raise an error due to a different number of inputs/outputs, please raise this and give a repro (it doesn't need to be a minimal one). I can look at adding the missing check in JAX.
- XLA, like all compilers, can remove a full operation if none of its outputs are used and the operation is side-effect free. But it can't remove only part of an operation's outputs unless it knows the instruction supports this. So yes, XLA can't do that for a custom_call, but in practice it doesn't do that for native operations either. I don't know of any such exception in XLA.
Would it be better to not expose the workspace data type to the […] Another thing is that we could calculate the workspace size in modules.cpp before returning to cpp_extensions.py, so that we only need to return a single […]
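If the size were computed in modules.cpp and returned as a single byte count (a hypothetical interface, not what the PR currently does), the Python side could stay dtype-agnostic, for example:

```python
# Hedged sketch: with only a byte count coming back from the C++ extension,
# the workspace can be requested from XLA as an opaque uint8 buffer.
from jax.core import ShapedArray
import jax.numpy as jnp

def _workspace_aval(wkspace_bytes: int) -> ShapedArray:
    return ShapedArray((wkspace_bytes,), jnp.uint8)
```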
/te-ci jax
/te-ci jax
@nouiz @mingxu1067 @zlsh80826 This is passing LayerNorm and FusedAttn tests on my end. I'm running the CI, but this should be ready for final review now. Thanks!
/te-ci jax
LGTM at high level.
out_types = [
    ir.RankedTensorType.get(out_shape, output_type),
    ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
    ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
    ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
    ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype))
Suggested change:
- ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype))
+ ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size
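For reference, `ShapedArray.size` is the product of the shape, so the replacement is behavior-preserving; a quick check (the shape below is made up):

```python
import operator
from functools import reduce
import jax.numpy as jnp
from jax.core import ShapedArray

wkspace_aval = ShapedArray((128, 64), jnp.float32)  # hypothetical workspace shape
assert wkspace_aval.size == reduce(operator.mul, wkspace_aval.shape) == 8192
```

`.size` also handles rank-0 avals, where `reduce` without an initial value would raise a TypeError.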
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
dgamma_part_size = reduce(operator.mul, dgamma_part_aval.shape)
dbeta_part_size = reduce(operator.mul, dbeta_part_aval.shape)
Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
- dgamma_part_size = reduce(operator.mul, dgamma_part_aval.shape)
- dbeta_part_size = reduce(operator.mul, dbeta_part_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size
+ dgamma_part_size = dgamma_part_aval.size
+ dbeta_part_size = dbeta_part_aval.size
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
dgamma_part_size = reduce(operator.mul, dgamma_part_aval.shape)
Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
- dgamma_part_size = reduce(operator.mul, dgamma_part_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size
+ dgamma_part_size = dgamma_part_aval.size
out_types = [
    ir.RankedTensorType.get(x_shape, x_type.element_type),
    ir.RankedTensorType.get(g_shape, g_type.element_type),
    ir.RankedTensorType.get(b_shape, b_type.element_type),
    ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
    ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
    ir.RankedTensorType.get(dgamma_part_aval.shape,
                            jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)),
    ir.RankedTensorType.get(dbeta_part_aval.shape,
                            jax_dtype_to_ir_dtype(dbeta_part_aval.dtype))
]
This would work, but the current enumerated list is also fine to me.
Suggested change:
- out_types = [
-     ir.RankedTensorType.get(x_shape, x_type.element_type),
-     ir.RankedTensorType.get(g_shape, g_type.element_type),
-     ir.RankedTensorType.get(b_shape, b_type.element_type),
-     ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
-     ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
-     ir.RankedTensorType.get(dgamma_part_aval.shape,
-                             jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)),
-     ir.RankedTensorType.get(dbeta_part_aval.shape,
-                             jax_dtype_to_ir_dtype(dbeta_part_aval.dtype))
- ]
+ out_types = [
+     ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
+     for output in ctx.avals_out
+ ]
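A sketch of how the suggested comprehension could sit inside a lowering rule (assuming `ctx` is the lowering context JAX passes in and an MLIR context is active):

```python
# Hedged sketch: derive the MLIR result types directly from the abstract
# outputs, so the lowering stays in sync with abstract whenever workspace
# outputs are added or removed.
from jax.interpreters import mlir
from jaxlib.mlir import ir

def _out_types_from_avals(ctx):
    return [
        ir.RankedTensorType.get(aval.shape, mlir.dtype_to_ir_type(aval.dtype))
        for aval in ctx.avals_out
    ]
```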
batch_size = reduce(operator.mul, batch_shape)

wkspace_aval = ctx.avals_out[-1]
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
+ wkspace_size = wkspace_aval.size
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size
/te-ci jax
/te-ci jax
@cyanguwa There are some (minor) changes to the common fused attn kernels here, and it would be great to get your feedback on them to make sure there won't be unintended consequences. Thanks!
* Removed cudaMalloc/WorkspaceManager in JAX csrc. JAX custom ops now request buffers from XLA for their workspace tensors.
* Removed unused GEMM C++ API in TE-JAX.
* Fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives.
* Fixed import order for linting.
* Fixed custom op errors due to incorrect static arg nums in JAX jit.
* Shifted cudnnSetStream further down the kernel to avoid an error when executing a dummy kernel call with a nullptr stream.
* Fixed linting errors for blank lines.

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Previously, custom ops allocated their workspace tensors via `WorkspaceManager` in `jax/csrc/utils.h`, which relied on direct `cudaMalloc()`. This PR removes the `WorkspaceManager` and exposes new Python interfaces to determine workspace sizes in the custom-op primitives, so that the same workspace memory can be requested from XLA instead.