Skip to content

Commit

Permalink
matmul bug fixes. 1) if beta!=0 copy C in. 2) detect additional case that cublas doesn't support.
Browse files Browse the repository at this point in the history
  • Loading branch information
luitjens committed Jan 5, 2024
1 parent 141df30 commit 389ee69
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion include/matx/transforms/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,8 @@ __MATX_INLINE__ auto getCublasSupportedTensor( const Op &in, cudaStream_t stream
(in.Stride(RANK-1) != (index_t)1 && in.Stride(RANK-2) != (index_t)1) ||
// cublas allows 0 strides, but verify that the corresponding size is 1
(in.Stride(RANK-1) == (index_t)0 && in.Size(RANK-1) != (index_t)1) ||
(in.Stride(RANK-2) == (index_t)0 && in.Size(RANK-2) != (index_t)1)
(in.Stride(RANK-2) == (index_t)0 && in.Size(RANK-2) != (index_t)1) ||
in.Stride(RANK-2) == 0 // WAR for CUBLAS bug
) {
supported = false;
}
Expand Down Expand Up @@ -1192,6 +1193,10 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A,
if(!b.isSameView(B_)) {
(b = B_).run(stream);
}

if(beta != 0 && !c.isSameView(C)) {
(c = C).run(stream);
}

#if MATX_ENABLE_CUTLASS != 1
// cublasLt does not allow transpose modes on C. Thus we need to make sure that the right most dimension has a stride of 1.
Expand Down

0 comments on commit 389ee69

Please sign in to comment.