
[JAX] Custom Op Workspace Tensors from XLA Buffers #532

Merged: 7 commits merged into NVIDIA:main from the jax-gemm-workspace-xla-buffer branch on Jan 29, 2024

Conversation

@denera (Collaborator) commented on Nov 22, 2023

Previously, custom ops allocated their workspace tensors via WorkspaceManager in jax/csrc/utils.h, which relied on direct cudaMalloc().

This PR removes the WorkspaceManager and exposes new Python interfaces in the custom-op primitives to determine workspace sizes, so that the same workspace memory can instead be requested from XLA.
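The following is a minimal sketch of the idea (not the actual TransformerEngine code; the primitive and sizing rule are illustrative): the primitive's abstract evaluation advertises the workspace as an extra byte-typed output, so XLA allocates and owns the buffer, and the Python wrapper simply discards it.

import numpy as np
from jax import core

def layernorm_fwd_abstract(x_aval, gamma_aval):
    # Real output plus a byte-typed workspace output that XLA will allocate.
    out_aval = core.ShapedArray(x_aval.shape, x_aval.dtype)
    # Placeholder sizing rule; the real size would come from querying the NVTE kernel.
    wkspace_aval = core.ShapedArray((4 * x_aval.shape[0],), np.uint8)
    return out_aval, wkspace_aval

# The outer wrapper binds the primitive and drops the workspace result:
#   out, _ = layernorm_fwd_p.bind(x, gamma)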

@denera denera self-assigned this Nov 22, 2023
@nouiz (Collaborator) commented on Nov 22, 2023

@mingxu1067 to review.

out_types = [
ir.RankedTensorType.get(out_shape, output_type),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(work_meta[0], te_dtype_to_ir_dtype(work_meta[1])),
Collaborator:

These changes make the number of outputs from the lowering different from the abstract. I assume there would be an XLA assertion; could you test this to see if there are any potential issues?

Besides, these workspace tensors are never used, so they might be removed during XLA compilation. We might have to test this as well.

@denera (Collaborator, Author) replied on Nov 23, 2023

The workspace tensors I added appear to be used in the C++ custom ops and NVTE kernels they call. They were being allocated with cudaMalloc in WorkspaceManager, which is what we want to replace with this PR. Did I misread/misunderstand something about this?

It is entirely possible I unintentionally skipped over additional changes needed in the abstract (or elsewhere) to make this work correctly. Do we need dummy outputs in the abstract to make sure we don't lose the workspaces to XLA optimization?

Collaborator:

Yes, the purpose of this PR is to remove cudaMalloc in WorkspaceManager.

I am just thinking there might be two issues:

  1. The mismatched number of outputs between lowering and abstract might trigger some assertion. If so, we would have to separate the abstracts for inner_p and outer_p, and add dummy outputs to the abstract of inner_p (see the sketch after this list).
  2. Unused workspace tensors might be removed by one of XLA's memory optimization passes. But after thinking it through more carefully, we might not need to worry about this, since XLA has little knowledge of custom calls and therefore usually skips those optimizations for custom calls.
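A minimal sketch of the inner/outer split described in point 1 (illustrative names and a placeholder size, not the actual TE primitives): the inner primitive's abstract carries the workspace output so it matches the lowering's out_types, while the outer primitive's abstract exposes only the real results.

import numpy as np
from jax import core

def inner_abstract(x_aval):
    out_aval = core.ShapedArray(x_aval.shape, x_aval.dtype)
    wkspace_aval = core.ShapedArray((1024,), np.uint8)  # placeholder workspace size
    return out_aval, wkspace_aval  # same count and order as the lowering's out_types

def outer_abstract(x_aval):
    out_aval, _ = inner_abstract(x_aval)  # drop the dummy workspace output
    return out_aval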

Collaborator:

  1. If JAX doesn't raise an error due to a differing number of inputs/outputs, please report this and give a repro (it doesn't need to be minimal). I can look at adding the missing check in JAX.

  2. XLA, like all compilers, can remove a full operation if none of its outputs are used and the operation is side-effect free. But it can't remove only part of an operation's outputs unless it knows that the instruction supports this. So yes, XLA can't do that for a custom_call, but in practice it doesn't do it for native operations either. I don't know of any such exception in XLA.

@zlsh80826 (Collaborator) commented on Nov 23, 2023

Would it be better not to expose the workspace data type to cpp_extensions.py? We could do something like JAX does and only expose the required workspace size to the Python side, using something like ir.IntegerType.get_signless(8) as the data type (see the jaxlib workspace reference).

Another thought: we could calculate the workspace size in modules.cpp before returning to cpp_extensions.py, so that we return only a single size_t instead of a vector.
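A small sketch of the byte-buffer pattern described above (an assumption about how it could look, not the final TE code): only a byte count crosses the C++/Python boundary, and the lowering declares the workspace as a flat signless-i8 tensor, as jaxlib does. The helper below is hypothetical and must be called inside a lowering rule, where an MLIR context is active.

from jax.interpreters.mlir import ir

def i8_workspace_type(num_bytes):
    # Flat i8 tensor; the workspace dtype never needs to be exposed to Python callers.
    return ir.RankedTensorType.get([num_bytes], ir.IntegerType.get_signless(8))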

Resolved (outdated) review threads: transformer_engine/jax/csrc/modules.h (x2), transformer_engine/jax/csrc/modules.cpp
@denera (Collaborator, Author) commented on Dec 15, 2023

/te-ci jax

@denera denera force-pushed the jax-gemm-workspace-xla-buffer branch from 830b67f to 04b14b2 on December 15, 2023 23:57
@denera denera marked this pull request as ready for review December 15, 2023 23:57
@denera denera force-pushed the jax-gemm-workspace-xla-buffer branch 2 times, most recently from cfe0dda to a8c3fdc on December 19, 2023 22:02
@denera (Collaborator, Author) commented on Dec 19, 2023

/te-ci jax

@denera denera force-pushed the jax-gemm-workspace-xla-buffer branch 2 times, most recently from 6bc5b3f to 1e6d94a on December 20, 2023 19:53
@denera (Collaborator, Author) commented on Dec 20, 2023

@nouiz @mingxu1067 @zlsh80826 This is passing LayerNorm and FusedAttn tests on my end. I'm running the CI but this should be ready for final review now. Thanks!

@denera (Collaborator, Author) commented on Dec 20, 2023

/te-ci jax

@nouiz (Collaborator) left a review comment:

LGTM at high level.

out_types = [
ir.RankedTensorType.get(out_shape, output_type),
ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype))
Collaborator:

Suggested change:
- ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype))
+ ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype))

Comment on lines 370 to 371
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
Collaborator:

Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size
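For reference, .size on a ShapedArray is exactly the product of its shape dimensions, so the suggested one-liners here (and in the analogous suggestions below) are drop-in replacements for the reduce() calls:

import operator
from functools import reduce
import numpy as np
from jax import core

aval = core.ShapedArray((4, 256), np.float32)
assert aval.size == reduce(operator.mul, aval.shape) == 1024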

Comment on lines 567 to 570
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
dgamma_part_size = reduce(operator.mul, dgamma_part_aval.shape)
dbeta_part_size = reduce(operator.mul, dbeta_part_aval.shape)
Collaborator:

Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
- dgamma_part_size = reduce(operator.mul, dgamma_part_aval.shape)
- dbeta_part_size = reduce(operator.mul, dbeta_part_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size
+ dgamma_part_size = dgamma_part_aval.size
+ dbeta_part_size = dbeta_part_aval.size

Comment on lines 764 to 765
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
Collaborator:

Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size

Comment on lines 935 to 937
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
dgamma_part_size = reduce(operator.mul, dgamma_part_aval.shape)
Collaborator:

Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
- dgamma_part_size = reduce(operator.mul, dgamma_part_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size
+ dgamma_part_size = dgamma_part_aval.size

Comment on lines 572 to 566
out_types = [
ir.RankedTensorType.get(x_shape, x_type.element_type),
ir.RankedTensorType.get(g_shape, g_type.element_type),
ir.RankedTensorType.get(b_shape, b_type.element_type),
ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
ir.RankedTensorType.get(dgamma_part_aval.shape,
jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)),
ir.RankedTensorType.get(dbeta_part_aval.shape,
jax_dtype_to_ir_dtype(dbeta_part_aval.dtype))
]
Collaborator:

This would work, but the current enumerated list is also fine by me.

Suggested change:
- out_types = [
-     ir.RankedTensorType.get(x_shape, x_type.element_type),
-     ir.RankedTensorType.get(g_shape, g_type.element_type),
-     ir.RankedTensorType.get(b_shape, b_type.element_type),
-     ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)),
-     ir.RankedTensorType.get(barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype)),
-     ir.RankedTensorType.get(dgamma_part_aval.shape,
-                             jax_dtype_to_ir_dtype(dgamma_part_aval.dtype)),
-     ir.RankedTensorType.get(dbeta_part_aval.shape,
-                             jax_dtype_to_ir_dtype(dbeta_part_aval.dtype))
- ]
+ out_types = [
+     ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
+     for output in ctx.avals_out
+ ]

batch_size = reduce(operator.mul, batch_shape)

wkspace_aval = ctx.avals_out[-1]
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
Collaborator:

Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
+ wkspace_size = wkspace_aval.size

Comment on lines 3364 to 3365
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
Collaborator:

Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size

Comment on lines 3608 to 3609
wkspace_size = reduce(operator.mul, wkspace_aval.shape)
barrier_size = reduce(operator.mul, barrier_aval.shape)
Collaborator:

Suggested change:
- wkspace_size = reduce(operator.mul, wkspace_aval.shape)
- barrier_size = reduce(operator.mul, barrier_aval.shape)
+ wkspace_size = wkspace_aval.size
+ barrier_size = barrier_aval.size

@denera denera force-pushed the jax-gemm-workspace-xla-buffer branch 6 times, most recently from 12bd429 to 088a79b on January 17, 2024 17:41
@denera (Collaborator, Author) commented on Jan 17, 2024

/te-ci jax

@denera denera force-pushed the jax-gemm-workspace-xla-buffer branch from 088a79b to d5918b1 on January 22, 2024 15:44
@denera (Collaborator, Author) commented on Jan 22, 2024

/te-ci jax

@denera (Collaborator, Author) commented on Jan 22, 2024

@cyanguwa There are some (minor) changes to the common fused attn kernels here and it would be great to get your feedback on them to make sure there won't be unintended consequences. Thanks!

@denera denera requested a review from cyanguwa January 22, 2024 15:46
@denera denera force-pushed the jax-gemm-workspace-xla-buffer branch from 2c1703a to 199d27f on January 23, 2024 21:47
@denera denera added the enhancement New feature or request label Jan 23, 2024
@denera denera merged commit 4077ccc into NVIDIA:main Jan 29, 2024
9 checks passed
Oleg-Goncharov pushed a commit to Oleg-Goncharov/TransformerEngine that referenced this pull request Jan 30, 2024
* Removed cudaMalloc/WorkspaceManager in JAX csrc. JAX custom ops now request buffers from XLA for their workspace tensors.

Signed-off-by: Alp Dener <adener@nvidia.com>

* removed unused GEMM C++ API in TE-JAX

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed typo in layernorm_geglu_fp8_mlp and removed unnecessary shape reductions in primitives

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed import order for linting

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed custom op errors due to incorrect static arg nums in JAX jit

Signed-off-by: Alp Dener <adener@nvidia.com>

* shifted cudnnSetStream further down the kernel to avoid error when executing dummy kernel call with nullptr stream

Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed linting errors for blank lines

Signed-off-by: Alp Dener <adener@nvidia.com>

---------

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Oleg-Goncharov pushed a second commit to Oleg-Goncharov/TransformerEngine that referenced this pull request on Jan 30, 2024 (same commit message as above).