Skip to content

Commit

Permalink
[SYCLomatic] Support bgradb epilogue for dpct::experimental::matmul (#…
Browse files Browse the repository at this point in the history
…2607)

Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
  • Loading branch information
zhiweij1 authored Jan 15, 2025
1 parent a92bfbe commit 7bc59b6
Showing 1 changed file with 42 additions and 5 deletions.
47 changes: 42 additions & 5 deletions clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ class matmul_desc_t {
/// scale_type==float && a_type==float && b_type==float && c_type==float.
/// Currently, this function only supports beta==0 or beta==1.
/// Currently, this function only supports the relu, bias, gelu, gelu_bias,
/// gelu_aux, gelu_aux_bias and dgelu epilogue.
/// gelu_aux, gelu_aux_bias, dgelu and bgradb epilogue.
/// NOTE: Non-col-major matrix will be converted to col-major matrix before
/// multiplication and converted back after multiplication.
/// TODO: Impl row-major matmul without layout conversion.
Expand Down Expand Up @@ -331,10 +331,12 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
compute_desc->_epilogue != epilogue_t::gelu_bias &&
compute_desc->_epilogue != epilogue_t::gelu_aux &&
compute_desc->_epilogue != epilogue_t::gelu_aux_bias &&
compute_desc->_epilogue != epilogue_t::dgelu) {
throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only "
"supports relu, bias, gelu, gelu_bias, gelu_aux, "
"gelu_aux_bias and dgelu epilogue currently.");
compute_desc->_epilogue != epilogue_t::dgelu &&
compute_desc->_epilogue != epilogue_t::bgradb) {
throw std::runtime_error(
"dpct::blas_gemm::experimental::matmul() only "
"supports relu, bias, gelu, gelu_bias, gelu_aux, "
"gelu_aux_bias, dgelu and bgradb epilogue currently.");
}

if (!(compute_desc->_scale_type == library_data_t::real_int32 &&
Expand Down Expand Up @@ -559,6 +561,28 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
#endif
}

::dnnl::memory *po_bias_bgradb_mem = nullptr;
auto po_bias_bgradb_md = ::dnnl::memory::desc(
compute_desc->_trans_b == oneapi::mkl::transpose::nontrans
? ::dnnl::memory::dims{N, 1}
: ::dnnl::memory::dims{1, N},
dpct::dnnl::memory_desc_ext::to_dnnl_data_type(
compute_desc->_bias_data_type),
compute_desc->_trans_b == oneapi::mkl::transpose::nontrans
? ::dnnl::memory::dims{1, N}
: ::dnnl::memory::dims{N, 1});
if (compute_desc->_epilogue == epilogue_t::bgradb) {
po_bias_bgradb_mem = new ::dnnl::memory(
po_bias_bgradb_md, handle->get_engine(), DNNL_MEMORY_NONE);
#ifdef DPCT_USM_LEVEL_NONE
detail::type_dispatch<detail::set_buffer_impl>(
compute_desc->_bias_data_type, po_bias_bgradb_mem,
compute_desc->_bias_pointer);
#else
po_bias_bgradb_mem->set_data_handle(compute_desc->_bias_pointer);
#endif
}

::dnnl::memory *po_aux_mem = nullptr;
auto po_aux_md = ::dnnl::memory::desc(
::dnnl::memory::dims{M, N},
Expand Down Expand Up @@ -660,6 +684,17 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
post_op_prim_event =
::dnnl::sycl_interop::execute(dgelu_prim, handle->get_engine_stream(),
dgelu_args, {matmul_prim_event});
} else if (compute_desc->_epilogue == epilogue_t::bgradb) {
auto reduction_pd = ::dnnl::reduction::primitive_desc(
handle->get_engine(), ::dnnl::algorithm::reduction_sum, weights_md,
po_bias_bgradb_md, 0.f, 0.f);
auto reduction_prim = ::dnnl::reduction(reduction_pd);
std::unordered_map<int, ::dnnl::memory> reduction_args;
reduction_args.insert({DNNL_ARG_SRC, *weights_mem});
reduction_args.insert({DNNL_ARG_DST, *po_bias_bgradb_mem});
post_op_prim_event = ::dnnl::sycl_interop::execute(
reduction_prim, handle->get_engine_stream(), reduction_args,
{matmul_prim_event});
}

// end of calling oneDNN
Expand Down Expand Up @@ -700,6 +735,8 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
delete dst_mem;
if (po_bias_mem)
delete po_bias_mem;
if (po_bias_bgradb_mem)
delete po_bias_bgradb_mem;
if (po_aux_mem)
delete po_aux_mem;
::dpct::cs::free((void *)new_a, *q_ptr);
Expand Down

0 comments on commit 7bc59b6

Please sign in to comment.