From 12f213a7f7839b3593f18428cc71bb8f8e24a67d Mon Sep 17 00:00:00 2001 From: Ivy Zhang Date: Mon, 7 Mar 2022 13:06:35 +0800 Subject: [PATCH] [BYOC-DNNL] Support DNNL optimal layout (#10421) * enable dnnl optimal layout for supported ops * verfied cv models with onednnv1.7 * rebase to the latest main branch * fix format related comments * remove unnecessary layout transformation * change deconv into conv_transpose * rename some variables and functions * simplify query_layout * add checkes for query_layout * fix lint * move partition_for_dnnl from dnnl.py to test_dnnl.py * remove unnecessary model test * add more dnnl layout * rename flag in convolution.cc * enhance dnnl layout --- python/tvm/relay/op/contrib/dnnl.py | 246 ++++++++++-- src/relay/backend/contrib/dnnl/codegen.cc | 42 +- .../backend/contrib/dnnl/query_layout.cc | 378 ++++++++++++++++++ src/relay/op/nn/convolution.cc | 48 ++- src/runtime/contrib/dnnl/dnnl_json_runtime.cc | 291 ++++++++++---- src/tir/ir/data_layout.cc | 1 - tests/python/contrib/test_dnnl.py | 158 ++++++-- 7 files changed, 990 insertions(+), 174 deletions(-) create mode 100755 src/relay/backend/contrib/dnnl/query_layout.cc diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 6d4fe0d81260..2bcb2b0ef7f8 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -35,9 +35,9 @@ import logging import tvm.ir -from tvm.relay import transform -from tvm.relay.build_module import bind_params_by_name +from tvm import relay +from ... import _ffi_api from ...dataflow_pattern import wildcard, is_op from .register import register_pattern_table @@ -94,12 +94,12 @@ def _func_wrapper(expr): def make_conv_pattern(conv_name, with_bias=True, with_eltwise=None): - """Create patterns related to conv and deconv. + """Create patterns related to conv and conv_transpose. Parameters ---------- with_bias : bool - Whether attach `bias_add` to `conv / deconv`. + Whether attach `bias_add` to `conv / conv_transpose`. with_eltwise : str The attached elementwise post-op name. Returns @@ -147,12 +147,12 @@ def make_dense_pattern(with_bias=True, with_eltwise=None): return dense_out -def make_dnnl_pattern(op, with_bias, with_eltwise): +def make_dnnl_pattern(op_name, with_bias, with_eltwise): """Create dnnl patterns. Parameters ---------- - op : str + op_name : str The first call node's op name. with_bias : bool Whether attach `bias_add` to `nn.dense`. @@ -163,18 +163,20 @@ def make_dnnl_pattern(op, with_bias, with_eltwise): pattern : Tuple(pattern_name, CallPattern) Created pattern name, along with its CallPattern. 
""" - pat_name = op.replace("nn", "dnnl") + pat_name = op_name.replace("nn", "dnnl") + if "_transpose" in op_name: + pat_name = "dnnl.deconv" + op_name.split("_")[0][-2::] pat_name += "_bias" if with_bias else "" pat_name += ("_" + with_eltwise.split(".")[-1]) if with_eltwise else "" - if "conv" in op: - dnnl_pattern = (pat_name, make_conv_pattern(op, with_bias, with_eltwise)) - elif op == "nn.dense": + if "conv" in op_name: + dnnl_pattern = (pat_name, make_conv_pattern(op_name, with_bias, with_eltwise)) + elif op_name == "nn.dense": dnnl_pattern = (pat_name, make_dense_pattern(with_bias, with_eltwise)) else: logger.warning( "Currently, only conv1d, conv2d, conv2d_transpose, conv3d_transpose and " "dense op are supported, but got %s.", - op, + op_name, ) dnnl_pattern = () return dnnl_pattern @@ -207,39 +209,205 @@ def pattern_table(): return dnnl_patterns -def partition_for_dnnl(mod, params=None): - """Partition the graph greedily offloading supported operators to DNNL. +def get_optimal_layout_for_conv( + data_layout, kernel_layout, weight_shape, out_shape, paddings, strides, dilates, groups +): + """Get the optimal layout of dnnl, given shape of conv2d. Parameters ---------- - mod : Module - The module to run passes on. - params : Optional[Dict[str, NDArray]] - Constant input parameters. + data_layout, kernel_layout,weight_shape, out_shape, paddings, strides, dilates, groups + : String + Input argument. + Returns ------- - mod : Module - Annotated and partitioned module. + layouts : string + The result. """ + return _ffi_api.get_optimal_layout_for_conv( + data_layout, + kernel_layout, + weight_shape, + out_shape, + paddings, + strides, + dilates, + groups, + ) + + +def get_optimal_layout_for_conv_transpose( + data_layout, + kernel_layout, + weight_shape, + out_shape, + paddings, + output_paddings, + strides, + dilates, + groups, +): + """Get the optimal layout of dnnl, given shape of tranposed conv2d. + + Parameters + ---------- + data_layout, kernel_layout, weight_shape, out_shape, paddings, output_paddings, strides, + dilates, groups + : Int, String + Input argument. + + Returns + ------- + layouts : string + The result. 
+ """ + return _ffi_api.get_optimal_layout_for_conv_transpose( + data_layout, + kernel_layout, + weight_shape, + out_shape, + paddings, + output_paddings, + strides, + dilates, + groups, + ) + + +def get_shape(tensor): + """Get tensor's shape.""" + if isinstance(tensor, relay.expr.Var): + return tensor.type_annotation.concrete_shape + if isinstance(tensor, relay.expr.Constant): + return tensor.data.shape + if isinstance(tensor, tvm.ir.tensor_type.TensorType): + return tensor.concrete_shape + if isinstance(tensor, tvm.ir.container.Array): + return tensor[-1].shape + if isinstance(tensor, relay.expr.Call): + return tensor.checked_type.shape + raise TypeError("Unsupport data type: %s" % type(tensor)) + - if params: - mod["main"] = bind_params_by_name(mod["main"], params) - seq = tvm.transform.Sequential( - [ - transform.CanonicalizeOps(), - transform.InferType(), - transform.SimplifyInference(), - transform.FoldConstant(), - transform.FoldScaleAxis(), - # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu` - transform.SimplifyExpr(), - transform.FoldConstant(), - transform.MergeComposite(pattern_table()), - transform.AnnotateTarget("dnnl"), - transform.MergeCompilerRegions(), - transform.PartitionGraph(), - ] +def tag2layout(input_data, is_weight=False, conv_type="Conv1D"): + """Transfer layout, denoted with `a, b, c, d, e`, + into valid layout (NCHW / OIHW) of TVM.""" + if "Conv1D" in conv_type: + data_dic = {"a": "N", "b": "C", "c": "W"} + weight_dic = {"a": "O", "b": "I", "c": "W", "d": "G"} + elif "Conv2D" in conv_type: + data_dic = {"a": "N", "b": "C", "c": "H", "d": "W"} + weight_dic = {"a": "O", "b": "I", "c": "H", "d": "W"} + if "e" in input_data: + weight_dic = {"a": "G", "b": "O", "c": "I", "d": "H", "e": "W"} + elif "Conv3D" in conv_type: + data_dic = {"a": "N", "b": "C", "c": "D", "d": "H", "e": "W"} + weight_dic = {"a": "O", "b": "I", "c": "D", "d": "H", "e": "W", "f": "G"} + + dic = weight_dic if is_weight else data_dic + res = "" + + for i in input_data: + if i.isupper(): + i = i.lower() + res += dic[i] + dic[i] = dic[i].lower() + elif i.islower(): + res += dic[i] + elif i.isdigit(): + res += i + else: + raise ValueError("Unsupport layout format: %s" % input_data) + return res + + +def legalize_group_conv(attrs, inputs, types): + """Legalize group conv / conv_transpose calculation. 
+ Alter weight layout from OIHW to GOIHW / IOHW to GIOHW""" + groups = attrs.groups + data, weight = inputs + if groups == 1: + if "Transpose" not in type(attrs).__name__: + return relay.nn.conv2d(data, weight, **attrs) + return relay.nn.conv2d_transpose(data, weight, **attrs) + OC, IC, H, W = get_shape(weight) + new_attrs = dict(attrs) + weight = relay.reshape(weight, (groups, OC // groups, IC, H, W)) + if "Transpose" not in type(attrs).__name__: + new_attrs["kernel_layout"] = "GOIHW" + return relay.nn.conv2d(data, weight, **new_attrs) + new_attrs["kernel_layout"] = "GIOHW" + return relay.nn.conv2d_transpose(data, weight, **new_attrs) + + +def alter_conv(attrs, inputs, tinfos, out_type): + """The convolution's layout auto-query func for dnnl.""" + + data, weight = inputs + groups = str(attrs.groups) + weight_shape = ",".join([str(x) for x in get_shape(weight)]) + out_shape = ",".join([str(x) for x in get_shape(out_type)]) + paddings = ",".join([str(x) for x in attrs.get_int_tuple("padding")]) + strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")]) + dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")]) + new_attrs = dict(attrs) + conv_type = type(attrs).__name__.split("Attrs")[0] + + res = get_optimal_layout_for_conv( + attrs["data_layout"], + attrs["kernel_layout"], + weight_shape, + out_shape, + paddings, + strides, + dilates, + groups, ) - with tvm.transform.PassContext(opt_level=3): - mod = seq(mod) - return mod + src_df, weight_df, dst_df = res.split(",") + new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type) + new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type) + new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type) + + if conv_type == "Conv1D": + return relay.nn.conv1d(data, weight, **new_attrs) + if conv_type == "Conv2D": + return relay.nn.conv2d(data, weight, **new_attrs) + return relay.nn.conv3d(data, weight, **new_attrs) + + +def alter_conv_transpose(attrs, inputs, tinfos, out_type): + """The transposed convolution's layout auto-query func for dnnl.""" + + data, weight = inputs + weight_shape = ",".join([str(x) for x in get_shape(weight)]) + out_shape = ",".join([str(x) for x in get_shape(out_type)]) + paddings = ",".join([str(x) for x in attrs.get_int_tuple("padding")]) + output_paddings = ",".join([str(x) for x in attrs.get_int_tuple("output_padding")]) + strides = ",".join([str(x) for x in attrs.get_int_tuple("strides")]) + dilates = ",".join([str(x) for x in attrs.get_int_tuple("dilation")]) + groups = str(attrs.groups) + new_attrs = dict(attrs) + conv_type = type(attrs).__name__.split("Attrs")[0] + + res = get_optimal_layout_for_conv_transpose( + attrs["data_layout"], + attrs["kernel_layout"], + weight_shape, + out_shape, + paddings, + output_paddings, + strides, + dilates, + groups, + ) + src_df, weight_df, dst_df = res.split(",") + new_attrs["data_layout"] = tag2layout(src_df, is_weight=False, conv_type=conv_type) + new_attrs["kernel_layout"] = tag2layout(weight_df, is_weight=True, conv_type=conv_type) + new_attrs["out_layout"] = tag2layout(dst_df, is_weight=False, conv_type=conv_type) + + if conv_type == "Conv1DTranspose": + return relay.nn.conv1d_transpose(data, weight, **new_attrs) + if conv_type == "Conv2DTranspose": + return relay.nn.conv2d_transpose(data, weight, **new_attrs) + return relay.nn.conv3d_transpose(data, weight, **new_attrs) diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index 
7971b9cf67d2..41480ed33b0a 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -445,14 +445,30 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { {"relu", "nn.relu"}, {"tanh", "tanh"}, {"sigmoid", "sigmoid"}, + {"nn.deconv2d", "nn.conv2d_transpose"}, + {"nn.deconv3d", "nn.conv3d_transpose"}, }; - std::vector ParsingOpList(std::string op, std::string pattern_name) { - std::vector op_list = {"nn." + op}; - for (auto& t : op_map) { - if (pattern_name.find(t.first) != std::string::npos) { - op_list.push_back(t.second); + std::vector ParsingOpList(const std::string& pattern_name, + std::string interval = "_") { + ICHECK_NE(pattern_name, ""); + std::vector op_list; + size_t pos = 0, start = 0; + while ((pos = pattern_name.find(interval, start)) != std::string::npos) { + std::string op_name = pattern_name.substr(start, pos - start); + if (op_name.find("dnnl") != std::string::npos) { + op_name.replace(op_name.find("dnnl"), 4, "nn"); + if (op_name.find("deconv") != std::string::npos) { + op_name = op_map[op_name]; + } + } else { + op_name = op_map[op_name]; } + if (pos > start) op_list.push_back(op_name); + start = pos + interval.size(); + } + if (pattern_name.size() > start) { + op_list.push_back(op_map[pattern_name.substr(start)]); } return op_list; } @@ -471,28 +487,28 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions."; name = comp.value(); - if (name.find("dnnl.conv2d_transpose") != std::string::npos) { - std::vector op_list = ParsingOpList("conv2d_transpose", name); + if (name.find("dnnl.deconv2d") != std::string::npos) { + std::vector op_list = ParsingOpList(name); call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; - } else if (name.find("dnnl.conv3d_transpose") != std::string::npos) { - std::vector op_list = ParsingOpList("conv3d_transpose", name); + } else if (name.find("dnnl.deconv3d") != std::string::npos) { + std::vector op_list = ParsingOpList(name); call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; } else if (name.find("dnnl.conv1d") != std::string::npos) { - std::vector op_list = ParsingOpList("conv1d", name); + std::vector op_list = ParsingOpList(name); call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; } else if (name.find("dnnl.conv2d") != std::string::npos) { - std::vector op_list = ParsingOpList("conv2d", name); + std::vector op_list = ParsingOpList(name); call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; } else if (name.find("dnnl.conv3d") != std::string::npos) { - std::vector op_list = ParsingOpList("conv3d", name); + std::vector op_list = ParsingOpList(name); call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; } else if (name.find("dnnl.dense") != std::string::npos) { - std::vector op_list = ParsingOpList("dense", name); + std::vector op_list = ParsingOpList(name); call = GetRootCall(fn->body.as(), op_list.size() - 1, op_list); ICHECK(call->op.as()) << "Not op node"; } else { diff --git a/src/relay/backend/contrib/dnnl/query_layout.cc b/src/relay/backend/contrib/dnnl/query_layout.cc new file mode 100755 index 000000000000..7fb1d824c702 --- /dev/null +++ b/src/relay/backend/contrib/dnnl/query_layout.cc @@ -0,0 +1,378 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/contrib/dnnl/query_layout.cc + * \brief layout auto-query func. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../utils.h" +#include "dnnl.hpp" + +using dim_t = dnnl_dim_t; +using dims_t = dnnl_dims_t; + +namespace tvm { +namespace relay { +namespace contrib { + +template +inline void array_set(T* arr, const U& val, size_t size) { + for (size_t i = 0; i < size; ++i) arr[i] = static_cast(val); +} + +template +inline void array_copy(T* dst, const T* src, size_t size) { + for (size_t i = 0; i < size; ++i) dst[i] = src[i]; +} + +template +inline void swap(T& t1, T& t2) { + T tmp(t1); + t1 = t2; + t2 = tmp; +} + +template +inline void simultaneous_sort(T* vals, T* vals_2nd_level, U* keys, size_t size, F comparator) { + if (size == 0) return; + + for (size_t i = 0; i < size - 1; ++i) { + bool swapped = false; + + for (size_t j = 0; j < size - i - 1; j++) { + auto res = comparator(vals[j], vals[j + 1]); + if (res == 0) res = comparator(vals_2nd_level[j], vals_2nd_level[j + 1]); + + if (res > 0) { + swap(vals[j], vals[j + 1]); + swap(vals_2nd_level[j], vals_2nd_level[j + 1]); + swap(keys[j], keys[j + 1]); + swapped = true; + } + } + + if (swapped == false) break; + } +} + +void compute_blocks(dims_t blocks, const dnnl::memory::desc* md) { + using format_kind_t = dnnl_format_kind_t; + const format_kind_t blocked = dnnl_blocked; + if (!(md->data.format_kind == blocked)) { + array_set(blocks, 0, md->data.ndims); + return; + } + array_set(blocks, 1, md->data.ndims); + const auto& bd = md->data.format_desc.blocking; + for (int iblk = 0; iblk < bd.inner_nblks; ++iblk) + blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk]; +} + +inline bool has_runtime_strides(const dnnl::memory::desc* md) { + using format_kind_t = dnnl_format_kind_t; + const format_kind_t blocked = dnnl_blocked; + if (!(md->data.format_kind == blocked)) return false; + for (int d = 0; d < md->data.ndims; ++d) + if (md->data.format_desc.blocking.strides[d] == DNNL_RUNTIME_DIM_VAL) return true; + return false; +} + +std::string md2fmt_tag_str(const dnnl::memory::desc* md) { + const auto& blk = md->data.format_desc.blocking; + + dims_t blocks = {0}; + compute_blocks(blocks, md); + + char dim_chars[DNNL_MAX_NDIMS + 1]; + + dims_t ou_blocks = {0}; + array_copy(ou_blocks, md->data.padded_dims, md->data.ndims); + + bool plain = true; + for (int d = 0; d < md->data.ndims; ++d) { + dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + static_cast(d); + if (blocks[d] != 1) plain = false; + ou_blocks[d] /= blocks[d]; + } + + // Can't report meaningful tag for runtime dimensions. 
+ if (has_runtime_strides(md)) return "*"; + + dims_t strides; + array_copy(strides, blk.strides, md->data.ndims); + + simultaneous_sort(strides, ou_blocks, dim_chars, md->data.ndims, + [](dim_t a, dim_t b) { return b - a; }); + + dim_chars[md->data.ndims] = '\0'; + + std::string s(dim_chars); + + if (!plain) { + for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) { + char c = ('a' + static_cast(blk.inner_idxs[iblk])); + s += (std::to_string(blk.inner_blks[iblk]) + c); + } + } + return s; +} + +dnnl::memory::dims str2dims(const std::string& str_shape, bool dilates = false, + std::string interval = ",") { + // Split strings + std::vector str_dims; + size_t pos = 0, start = 0; + while ((pos = str_shape.find(interval, start)) != std::string::npos) { + std::string str_dim = str_shape.substr(start, pos - start); + if (pos > start) str_dims.push_back(str_dim); + start = pos + interval.size(); + } + if (str_shape.size() > start) { + str_dims.push_back(str_shape.substr(start)); + } + // transfer string to dims + dnnl::memory::dims out_dims; + if (dilates) { + std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims), + [](const std::string& str) { return std::stoi(str) - 1; }); + } else { + std::transform(str_dims.begin(), str_dims.end(), std::back_inserter(out_dims), + [](const std::string& str) { return std::stoi(str); }); + } + return out_dims; +} + +void check_shapes(const std::vector shapes) { + std::regex valid_pat("(\\d*)(,(\\d*))*"); + bool checked = std::regex_match(shapes[0], valid_pat); + for (size_t i = 1; i < shapes.size() - 1; i++) { + checked &= std::regex_match(shapes[i], valid_pat); + } + checked &= std::regex_match(shapes[shapes.size() - 1], std::regex("\\d*")); + if (!checked) { + LOG(FATAL) << "Invalid input args for query dnnl optimal layout."; + } +} + +void check_layout(bool var, bool ref) { + if (var != ref) { + LOG(FATAL) << "Invalid input layout for query dnnl optimal layout."; + } +} + +std::string get_optimal_layout_for_conv(std::string data_layout, std::string kernel_layout, + std::string weight_shape, std::string out_shape, + std::string paddings, std::string strides, + std::string dilates, std::string G) { + check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true); + check_layout(std::regex_match(kernel_layout, std::regex("(G?)OI(D?)(H?)W")), true); + check_shapes({weight_shape, out_shape, paddings, strides, dilates, G}); + + dnnl::engine eng(dnnl::engine::kind::cpu, 0); + dnnl::stream s(eng); + using tag = dnnl::memory::format_tag; + using dt = dnnl::memory::data_type; + + dnnl::memory::dim groups = std::stoi(G); + dnnl::memory::dims weight_dims_ = str2dims(weight_shape); + dnnl::memory::dims weight_dims = weight_dims_; + + if (groups > 1) { + if (weight_dims_.size() == 5) { + weight_dims = {groups * weight_dims_[1], groups * weight_dims_[2], weight_dims_[3], + weight_dims_[4]}; + } else { + weight_dims[1] = weight_dims[1] * groups; + } + } + + dnnl::memory::dims out_dims = str2dims(out_shape); + dnnl::memory::dims padding_dims = str2dims(paddings); + dnnl::memory::dims padding_dims_l(padding_dims.begin(), + padding_dims.begin() + padding_dims.size() / 2); + dnnl::memory::dims padding_dims_r(padding_dims.end() - padding_dims.size() / 2, + padding_dims.end()); + dnnl::memory::dims strides_dims = str2dims(strides); + dnnl::memory::dims dilates_dims = str2dims(dilates, true); + + dnnl::memory::dims input_dims = out_dims; + input_dims[1] = weight_dims[1]; + for (size_t i = 2; i < out_dims.size(); i++) { + dnnl::memory::dim K = 
weight_dims[i]; + dnnl::memory::dim S = strides_dims[i - 2]; + dnnl::memory::dim D = dilates_dims[i - 2]; + dnnl::memory::dim PL = padding_dims_l[i - 2]; + dnnl::memory::dim PR = padding_dims_r[i - 2]; + dnnl::memory::dim DK = 1 + (K - 1) * (D + 1); + input_dims[i] = out_dims[i] * S - PL - PR + DK - 1; + } + + dnnl::memory::dims conv_src_dims = input_dims; + dnnl::memory::dims conv_weights_dims = weight_dims; + if (groups > 1) { + conv_weights_dims = {groups, out_dims[1] / groups, input_dims[1] / groups}; + conv_weights_dims.insert(conv_weights_dims.end(), weight_dims.begin() + 2, weight_dims.end()); + } + + dnnl::memory::dims conv_dst_dims = out_dims; + dnnl::memory::dims conv_strides = strides_dims; + dnnl::memory::dims conv_dilates = dilates_dims; + dnnl::memory::dims conv_padding_l = padding_dims_l; + dnnl::memory::dims conv_padding_r = padding_dims_r; + + auto conv_src_md = dnnl::memory::desc({conv_src_dims}, dt::f32, tag::any); + auto conv_weights_md = dnnl::memory::desc({conv_weights_dims}, dt::f32, tag::any); + auto conv_dst_md = dnnl::memory::desc({conv_dst_dims}, dt::f32, tag::any); + + auto conv_desc = dnnl::convolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_direct, conv_src_md, + conv_weights_md, conv_dst_md, conv_strides, conv_dilates, conv_padding_l, conv_padding_r); + + auto conv_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, eng); + + auto src_format = conv_prim_desc.src_desc(); + auto weights_format = conv_prim_desc.weights_desc(); + auto dst_format = conv_prim_desc.dst_desc(); + std::string src_df, weight_df, dst_df; + + src_df = md2fmt_tag_str(&src_format); + weight_df = md2fmt_tag_str(&weights_format); + dst_df = md2fmt_tag_str(&dst_format); + std::string res = src_df + "," + weight_df + "," + dst_df; + return res; +} + +std::string get_optimal_layout_for_conv_transpose(std::string data_layout, + std::string kernel_layout, + std::string weight_shape, std::string out_shape, + std::string paddings, std::string output_paddings, + std::string strides, std::string dilates, + std::string G) { + check_layout(std::regex_match(data_layout, std::regex("NC(D?)(H?)W")), true); + check_layout(std::regex_match(kernel_layout, std::regex("(G?)((IO)|(OI))(D?)(H?)W")), true); + check_shapes({weight_shape, out_shape, paddings, output_paddings, strides, dilates, G}); + + dnnl::engine eng(dnnl::engine::kind::cpu, 0); + dnnl::stream s(eng); + using tag = dnnl::memory::format_tag; + using dt = dnnl::memory::data_type; + + dnnl::memory::dim groups = std::stoi(G); + dnnl::memory::dims weight_dims_ = str2dims(weight_shape); + dnnl::memory::dims weight_dims = weight_dims_; + if (groups > 1) { + if (weight_dims_.size() == 5) { + weight_dims = {groups * weight_dims_[1], groups * weight_dims_[2], weight_dims_[3], + weight_dims_[4]}; + } else { + weight_dims[1] = weight_dims[1] * groups; + } + } + dnnl::memory::dims out_dims = str2dims(out_shape); + dnnl::memory::dims padding_dims = str2dims(paddings); + dnnl::memory::dims padding_dims_l(padding_dims.begin(), + padding_dims.begin() + padding_dims.size() / 2); + dnnl::memory::dims padding_dims_r(padding_dims.end() - padding_dims.size() / 2, + padding_dims.end()); + dnnl::memory::dims output_padding_dims = str2dims(output_paddings); + dnnl::memory::dims strides_dims = str2dims(strides); + dnnl::memory::dims dilates_dims = str2dims(dilates, true); + + dnnl::memory::dims input_dims = out_dims; + if (out_dims[1] == weight_dims[0]) { + input_dims[1] = weight_dims[1]; + } else { + input_dims[1] = 
weight_dims[0]; + std::swap(weight_dims[0], weight_dims[1]); + } + for (size_t i = 2; i < out_dims.size(); i++) { + dnnl::memory::dim K = weight_dims[i]; + dnnl::memory::dim S = strides_dims[i - 2]; + dnnl::memory::dim D = dilates_dims[i - 2]; + dnnl::memory::dim PL = padding_dims_l[i - 2]; + dnnl::memory::dim PR = padding_dims_r[i - 2]; + dnnl::memory::dim OP = output_padding_dims[i - 2]; + dnnl::memory::dim DK = 1 + (K - 1) * (D + 1); + input_dims[i] = (out_dims[i] - DK + PL + PR - OP) / S + 1; + } + + dnnl::memory::dims deconv_src_dims = input_dims; + dnnl::memory::dims deconv_weights_dims = weight_dims; + if (groups > 1) { + deconv_weights_dims = {groups, out_dims[1] / groups, input_dims[1] / groups}; + deconv_weights_dims.insert(deconv_weights_dims.end(), weight_dims.begin() + 2, + weight_dims.end()); + } + dnnl::memory::dims deconv_dst_dims = out_dims; + dnnl::memory::dims deconv_strides = strides_dims; + dnnl::memory::dims deconv_dilates = dilates_dims; + dnnl::memory::dims deconv_padding_l = padding_dims_l; + dnnl::memory::dims deconv_padding_r = padding_dims_r; + + auto deconv_src_md = dnnl::memory::desc({deconv_src_dims}, dt::f32, tag::any); + auto deconv_weights_md = dnnl::memory::desc({deconv_weights_dims}, dt::f32, tag::any); + auto deconv_dst_md = dnnl::memory::desc({deconv_dst_dims}, dt::f32, tag::any); + + auto deconv_desc = dnnl::deconvolution_forward::desc( + dnnl::prop_kind::forward_inference, dnnl::algorithm::deconvolution_direct, deconv_src_md, + deconv_weights_md, deconv_dst_md, deconv_strides, deconv_dilates, deconv_padding_l, + deconv_padding_r); + + auto deconv_prim_desc = dnnl::deconvolution_forward::primitive_desc(deconv_desc, eng); + + auto src_format = deconv_prim_desc.src_desc(); + auto weights_format = deconv_prim_desc.weights_desc(); + auto dst_format = deconv_prim_desc.dst_desc(); + std::string src_df, weight_df, dst_df; + + src_df = md2fmt_tag_str(&src_format); + weight_df = md2fmt_tag_str(&weights_format); + dst_df = md2fmt_tag_str(&dst_format); + std::string res = src_df + "," + weight_df + "," + dst_df; + return res; +} + +TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv") + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = get_optimal_layout_for_conv(args[0], args[1], args[2], args[3], args[4], args[5], + args[6], args[7]); + }); + +TVM_REGISTER_GLOBAL("relay.ir.get_optimal_layout_for_conv_transpose") + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = get_optimal_layout_for_conv_transpose(args[0], args[1], args[2], args[3], args[4], + args[5], args[6], args[7], args[8]); + }); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 081e077bc180..0c882589e9cb 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -185,13 +185,19 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, const auto* weight = types[1].as(); if (data == nullptr) return false; static const Layout kNCHW("NCHW"); - static const Layout kOIHW("OIHW"); + Layout kOIHW("OIHW"); const auto* param = attrs.as(); ICHECK(param != nullptr); const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); + bool is_dnnl_group_conv = false; + if (param->groups > 1 && kernel_layout.name().find("G") != std::string::npos) { + kOIHW = Layout("GOIHW"); + is_dnnl_group_conv = true; + } + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); if (!trans_in_layout.defined()) { 
reporter->GetDiagCtx().Emit( @@ -203,10 +209,10 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW); if (!trans_kernel_layout.defined()) { - reporter->GetDiagCtx().Emit( - Diagnostic::Error(reporter->GetSpan()) - << "conv2d only support kernel layouts that are convertible from OIHW." - << " The provided layout is: " << kernel_layout); + reporter->GetDiagCtx().Emit(Diagnostic::Error(reporter->GetSpan()) + << "conv2d only support kernel layouts that are convertible from " + << kOIHW << "." + << " The provided layout is: " << kernel_layout); return false; } @@ -244,7 +250,12 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, ICHECK_EQ(param->dilation.size(), 2); Array wshape; - if (is_depthwise) { + if (is_dnnl_group_conv) { + // infer weight's shape for group convolution + wshape = {{param->groups, indexdiv(param->channels, param->groups), + indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0], + param->kernel_size[1]}}; + } else if (is_depthwise) { // infer weight's shape for depthwise convolution wshape = {{dshape_nchw[1], indexdiv(param->channels, dshape_nchw[1]), param->kernel_size[0], param->kernel_size[1]}}; @@ -734,13 +745,19 @@ bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& a if (data == nullptr) return false; static const Layout kNCHW("NCHW"); - static const Layout kIOHW("IOHW"); + Layout kIOHW("IOHW"); const Conv2DTransposeAttrs* param = attrs.as(); ICHECK(param != nullptr); const Layout in_layout(param->data_layout); const Layout kernel_layout(param->kernel_layout); + bool is_dnnl_group_conv = false; + if (param->groups > 1 && kernel_layout.name().find("G") != std::string::npos) { + kIOHW = Layout("GIOHW"); + is_dnnl_group_conv = true; + } + const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW); ICHECK(trans_in_layout.defined()) << "Conv2DTransposed only support input layouts that are convertible from NCHW." @@ -748,8 +765,8 @@ bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& a const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kIOHW); ICHECK(trans_kernel_layout.defined()) - << "Conv2DTransposed only support kernel layouts that are convertible from IOHW." - << " But got " << kernel_layout; + << "Conv2DTransposed only support kernel layouts that are convertible from " << kIOHW << "." + << " But got " << kernel_layout << " " << kIOHW; Layout out_layout(param->out_layout == "" ? 
param->data_layout : param->out_layout); const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW); @@ -766,8 +783,17 @@ bool Conv2DTransposeRel(const Array& types, int num_inputs, const Attrs& a ICHECK_EQ(param->kernel_size.size(), 2); ICHECK_EQ(param->dilation.size(), 2); - Array wshape({dshape_nchw[1], indexdiv(param->channels, param->groups), - param->kernel_size[0], param->kernel_size[1]}); + Array wshape; + if (is_dnnl_group_conv) { + // infer weight's shape for group convolution + wshape = {{param->groups, indexdiv(dshape_nchw[1], param->groups), + indexdiv(param->channels, param->groups), param->kernel_size[0], + param->kernel_size[1]}}; + } else { + // infer weight's shape for depthwise convolution + wshape = {{dshape_nchw[1], indexdiv(param->channels, param->groups), param->kernel_size[0], + param->kernel_size[1]}}; + } wshape = trans_kernel_layout.BackwardShape(wshape); dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index 6d5e5543cd40..706780614279 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -91,12 +91,6 @@ class DNNLJSONRuntime : public JSONRuntimeBase { private: // Build up the engine based on the input graph. - std::map layout_dict{ - {"NCW", tag::ncw}, {"OIW", tag::oiw}, {"GOIW", tag::goiw}, {"NCHW", tag::nchw}, - {"OIHW", tag::oihw}, {"GOIHW", tag::goihw}, {"NCDHW", tag::ncdhw}, {"OIDHW", tag::oidhw}, - {"GOIDHW", tag::goidhw}, {"IOHW", tag::iohw}, {"GIOHW", tag::giohw}, {"IODHW", tag::iodhw}, - {"GIODHW", tag::giodhw}, - }; std::map elt_name2algo{ {"abs", dnnl::algorithm::eltwise_abs}, @@ -112,6 +106,59 @@ class DNNLJSONRuntime : public JSONRuntimeBase { {"clip", dnnl::algorithm::eltwise_clip}, }; + std::map layout_dict{ + {"", tag::any}, + {"NCW", tag::ncw}, + {"NWC", tag::nwc}, + {"OIW", tag::oiw}, + {"GOIW", tag::goiw}, + {"NCHW", tag::nchw}, + {"NHWC", tag::nhwc}, + {"OIHW", tag::oihw}, + {"GOIHW", tag::goihw}, + {"NCDHW", tag::ncdhw}, + {"NDHWC", tag::ndhwc}, + {"OIDHW", tag::oidhw}, + {"GOIDHW", tag::goidhw}, + {"IOHW", tag::iohw}, + {"GIOHW", tag::giohw}, + {"IODHW", tag::iodhw}, + {"GIODHW", tag::giodhw}, + + // Blocking layout. + {"NCW8c", tag::nCw8c}, + {"NCW16c", tag::nCw16c}, + {"OIW16i16o", tag::OIw8i8o}, + {"OIW16i16o", tag::OIw16i16o}, + {"OWI8o", tag::Owi8o}, + {"OWI16o", tag::Owi16o}, + {"NCHW4c", tag::nChw4c}, + {"NCHW8c", tag::nChw8c}, + {"NCHW16c", tag::nChw16c}, + {"OIHW8i8o", tag::OIhw8i8o}, + {"IOHW8i8o", tag::any}, + {"OIHW16i16o", tag::OIhw16i16o}, + {"IOHW16i16o", tag::IOhw16i16o}, + {"GOIHW4i4o", tag::gOIhw4i4o}, + {"GOIHW8i8o", tag::gOIhw8i8o}, + {"GOIHW16i16o", tag::gOIhw16i16o}, + {"OHWI8o", tag::Ohwi8o}, + {"OHWI16o", tag::Ohwi16o}, + {"OHWI32o", tag::Ohwi32o}, + {"OHWI48o", tag::Ohwi48o}, + {"OHWI64o", tag::Ohwi64o}, + {"GOIHW8g", tag::Goihw8g}, + {"GOIHW16g", tag::Goihw16g}, + {"NCDHW8c", tag::nCdhw8c}, + {"NCDHW16c", tag::nCdhw16c}, + {"OIDHW16i16o", tag::OIdhw16i16o}, + {"IODHW16i16o", tag::IOdhw16i16o}, + {"OIDHW8i8o", tag::OIdhw8i8o}, + {"IODHW8i8o", tag::any}, + {"ODHWI8o", tag::Odhwi8o}, + {"ODHWI16o", tag::Odhwi16o}, + }; + bool ParsingOpName(const std::string op_name, dnnl::primitive_attr attr) { // Define RegExp. std::regex bias_add_pat(".*_bias.*"); @@ -136,9 +183,49 @@ class DNNLJSONRuntime : public JSONRuntimeBase { return std::regex_match(op_name, bias_add_pat) ? 
true : false; } - dnnl::memory::dims TransformStr2Dims(std::vector strs, std::string str_name) { + dnnl::memory::dims TransDims2Plain(dnnl::memory::dims input_dims, std::string layout) { + std::vector axis = { + 'N', 'C', 'O', 'I', 'D', 'H', 'W', + }; dnnl::memory::dims out_dims; - if (str_name == "dilates") { + std::string::iterator t = layout.begin(); + // Remove numbers in layout string to match the size of input_dims + while (t != layout.end()) { + if (*t >= '0' && *t <= '9') { + layout.erase(t); + } else { + t++; + } + } + // Push the correct shapes of each axis into the output_dims + for (auto a : axis) { + dnnl::memory::dim shape = 1; + if (layout.find(a) != std::string::npos) { + shape *= input_dims[layout.find(a)]; + char lower_a = std::tolower(a); + if (layout.find(lower_a) != std::string::npos) { + shape *= input_dims[layout.find(lower_a)]; + } + out_dims.push_back(shape); + } + } + // Multiply O and I with G, respectively + if (layout.find("G") != std::string::npos) { + dnnl::memory::dim G = 1; + if (layout.find("g") != std::string::npos) { + G = input_dims[layout.find("g")] * input_dims[layout.find("G")]; + } else { + G = input_dims[layout.find("G")]; + } + out_dims[0] *= G; + out_dims[1] *= G; + } + return out_dims; + } + + dnnl::memory::dims TransformStr2Dims(std::vector strs, bool dilates = false) { + dnnl::memory::dims out_dims; + if (dilates) { std::transform(strs.begin(), strs.end(), std::back_inserter(out_dims), [](const std::string& str) { return std::stoi(str) - 1; }); } else { @@ -153,7 +240,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { stream_ = dnnl::stream(engine_); std::regex conv_pat(".*conv[1-3]d.*"); - std::regex conv_tranpose_pat(".*conv[1-3]d_transpose.*"); + std::regex deconv_pat(".*deconv[1-3]d.*"); + std::regex conv_transpose_pat(".*conv[1-3]d_transpose.*"); std::regex dense_pat(".*dense.*"); std::regex max_pool_pat(".*max_pool[1-3]d"); std::regex avg_pool_pat(".*avg_pool[1-3]d"); @@ -164,7 +252,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { if (node.GetOpType() == "kernel") { ICHECK_EQ(node.GetOpType(), "kernel"); auto op_name = node.GetOpName(); - if (std::regex_match(op_name, conv_tranpose_pat)) { + if (std::regex_match(op_name, deconv_pat) || + std::regex_match(op_name, conv_transpose_pat)) { Deconvolution(nid); } else if (std::regex_match(op_name, conv_pat)) { Convolution(nid); @@ -247,31 +336,51 @@ class DNNLJSONRuntime : public JSONRuntimeBase { std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; // Check layout. - if (layout_dict.find(data_layout) == layout_dict.end() || - layout_dict.find(kernel_layout) == layout_dict.end()) { - LOG(FATAL) << "Unsupported layout for conv: " << data_layout << " " << kernel_layout; + if (layout_dict.find(data_layout) == layout_dict.end()) { + LOG(FATAL) << "Unsupported data layout for conv: " << data_layout; + } + + if (layout_dict.find(kernel_layout) == layout_dict.end()) { + layout_dict.insert({kernel_layout, tag::any}); + LOG(WARNING) << "Unregistered kernel layout for conv: " << data_layout + << ", transfer to tag::any"; } // Memory shapes. 
- dnnl::memory::dims src_dims = input_shape; // {N, IC, ID, IH, IW} - dnnl::memory::dims weights_dims = weight_shape; // {OC, IC, KD, KH, KW} + dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout); + dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout); + dnnl::memory::dims bias_dims = {channels}; + dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); + dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); + dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); + dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); + dnnl::memory::dims dst_dims = src_dims; + dst_dims[1] = channels; + weights_dims_[0] = channels; + for (size_t i = 2; i < src_dims.size(); i++) { + dnnl::memory::dim K = weights_dims_[i]; + dnnl::memory::dim S = strides_dims[i - 2]; + dnnl::memory::dim D = dilates_dims[i - 2]; + dnnl::memory::dim PL = padding_dims_l[i - 2]; + dnnl::memory::dim PR = padding_dims_r[i - 2]; + dnnl::memory::dim DK = 1 + (K - 1) * (D + 1); + dst_dims[i] = (src_dims[i] - DK + PL + PR) / S + 1; + } + + dnnl::memory::dims weights_dims = weights_dims_; if (groups > 1) { - weights_dims = {groups, channels / groups, input_shape[1] / groups}; - weights_dims.insert(weights_dims.end(), weight_shape.begin() + 2, weight_shape.end()); - kernel_layout.insert(0, "G"); + weights_dims = {groups, channels / groups, src_dims[1] / groups}; + weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2, weights_dims_.end()); + if (kernel_layout == "OIHW") { + kernel_layout.insert(0, "G"); + } } - dnnl::memory::dims bias_dims = {channels}; - dnnl::memory::dims dst_dims = out_shape; // {N, OC, OD, OH, OW} - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides, "strides"); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, "dilates"); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l, "padding"); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r, "padding"); // Memory descriptions. auto conv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]); auto conv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]); auto conv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any); - auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, layout_dict[data_layout]); + auto conv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any); // Covn2d description. auto conv_desc = @@ -285,37 +394,37 @@ class DNNLJSONRuntime : public JSONRuntimeBase { dilates_dims, padding_dims_l, padding_dims_r); // Enable elementwise post-ops. - auto conv2d_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, attr, engine_); + auto conv_prim_desc = dnnl::convolution_forward::primitive_desc(conv_desc, attr, engine_); // Push to the network. - auto conv = dnnl::convolution_forward(conv2d_prim_desc); + auto conv = dnnl::convolution_forward(conv_prim_desc); net_.push_back(conv); // Data memory. - auto conv2d_src_memory = BindDNNLMemory(data_entry, conv_src_md); + auto conv_src_memory = BindDNNLMemory(data_entry, conv_src_md); // Weight memory. - auto conv2d_weights_memory = BindDNNLMemory(weight_entry, conv_weights_md); + auto conv_weights_memory = BindDNNLMemory(weight_entry, conv_prim_desc.weights_desc()); // Output memory. - auto conv2d_dst_memory = BindDNNLMemory(out_entry, conv2d_prim_desc.dst_desc()); + auto conv_dst_memory = BindDNNLMemory(out_entry, conv_prim_desc.dst_desc()); // Bias memory. 
- auto conv2d_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_); + auto conv_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_); if (has_bias) { auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, conv2d_bias_memory); + BindDNNLMemory(bias_entry, conv_bias_memory); // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, conv2d_src_memory}, - {DNNL_ARG_WEIGHTS, conv2d_weights_memory}, - {DNNL_ARG_BIAS, conv2d_bias_memory}, - {DNNL_ARG_DST, conv2d_dst_memory}}); + net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory}, + {DNNL_ARG_WEIGHTS, conv_weights_memory}, + {DNNL_ARG_BIAS, conv_bias_memory}, + {DNNL_ARG_DST, conv_dst_memory}}); } else { // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, conv2d_src_memory}, - {DNNL_ARG_WEIGHTS, conv2d_weights_memory}, - {DNNL_ARG_DST, conv2d_dst_memory}}); + net_args_.push_back({{DNNL_ARG_SRC, conv_src_memory}, + {DNNL_ARG_WEIGHTS, conv_weights_memory}, + {DNNL_ARG_DST, conv_dst_memory}}); } } @@ -343,44 +452,63 @@ class DNNLJSONRuntime : public JSONRuntimeBase { str_padding.begin() + str_padding.size() / 2); std::vector str_padding_r(str_padding.end() - str_padding.size() / 2, str_padding.end()); + std::vector str_out_padding = + node.GetAttr>("output_padding"); dnnl::memory::dim groups = std::stoi(node.GetAttr>("groups")[0]); std::string data_layout = node.GetAttr>("data_layout")[0]; std::string kernel_layout = node.GetAttr>("kernel_layout")[0]; // Check layout. - if (layout_dict.find(data_layout) == layout_dict.end() || - layout_dict.find(kernel_layout) == layout_dict.end()) { - LOG(FATAL) << "Unsupported layout: " << data_layout << " " << kernel_layout; + if (layout_dict.find(data_layout) == layout_dict.end()) { + LOG(FATAL) << "Unsupported data layout for deconv: " << data_layout; } - // Memory shapes. - dnnl::memory::dims src_dims = input_shape; // {N, IC, ID, IH, IW} - dnnl::memory::dims weights_dims = weight_shape; // {OC, IC, KD, KH, KW} + if (layout_dict.find(kernel_layout) == layout_dict.end()) { + layout_dict.insert({kernel_layout, tag::any}); + LOG(WARNING) << "Unregistered kernel layout for deconv: " << data_layout + << ", transfer to tag::any"; + } - // Check weight shape, transform to `OIHW` - if (weights_dims[0] == src_dims[1] && weights_dims[1] == channels) { - std::swap(weights_dims[0], weights_dims[1]); + // Memory shapes. 
+ dnnl::memory::dims src_dims = TransDims2Plain(input_shape, data_layout); + dnnl::memory::dims weights_dims_ = TransDims2Plain(weight_shape, kernel_layout); + // legalize shape IOHW with layout OIHW + if (weights_dims_[0] == src_dims[1] && weights_dims_[1] == channels) { + std::swap(weights_dims_[0], weights_dims_[1]); + if (kernel_layout.find("OI") == 0) { + kernel_layout.replace(kernel_layout.find("OI"), 2, "IO"); + } } - if (kernel_layout == "OIDHW") { - kernel_layout = "IODHW"; + dnnl::memory::dims bias_dims = {channels}; + dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); + dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); + dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); + dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); + dnnl::memory::dims out_padding = TransformStr2Dims(str_out_padding); + dnnl::memory::dims dst_dims = src_dims; + dst_dims[1] = channels; + for (size_t i = 2; i < src_dims.size(); i++) { + dnnl::memory::dim K = weights_dims_[i]; + dnnl::memory::dim S = strides_dims[i - 2]; + dnnl::memory::dim D = dilates_dims[i - 2]; + dnnl::memory::dim PL = padding_dims_l[i - 2]; + dnnl::memory::dim PR = padding_dims_r[i - 2]; + dnnl::memory::dim OP = out_padding[i - 2]; + dnnl::memory::dim DK = 1 + (K - 1) * (D + 1); + dst_dims[i] = S * (src_dims[i] - 1) + DK - PL - PR + OP; } + + dnnl::memory::dims weights_dims = weights_dims_; if (groups > 1) { - weights_dims = {groups, channels / groups, input_shape[1] / groups}; - weights_dims.insert(weights_dims.end(), weight_shape.begin() + 2, weight_shape.end()); - kernel_layout.insert(0, "G"); + weights_dims = {groups, channels / groups, src_dims[1] / groups}; + weights_dims.insert(weights_dims.end(), weights_dims_.begin() + 2, weights_dims_.end()); } - dnnl::memory::dims bias_dims = {channels}; - dnnl::memory::dims dst_dims = out_shape; // {N, OC, OD, OH, OW} - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides, "strides"); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, "dilates"); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l, "padding"); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r, "padding"); // Memory descriptions. auto deconv_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[data_layout]); auto deconv_weights_md = dnnl::memory::desc(weights_dims, dt::f32, layout_dict[kernel_layout]); auto deconv_bias_md = dnnl::memory::desc(bias_dims, dt::f32, tag::any); - auto deconv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, layout_dict[data_layout]); + auto deconv_dst_md = dnnl::memory::desc(dst_dims, dt::f32, tag::any); // Transposed covn2d description. auto deconv_desc = @@ -394,38 +522,37 @@ class DNNLJSONRuntime : public JSONRuntimeBase { padding_dims_l, padding_dims_r); // Enable elementwise post-ops. - auto deconv2d_prim_desc = - dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_); + auto deconv_prim_desc = dnnl::deconvolution_forward::primitive_desc(deconv_desc, attr, engine_); // Push to the network. - auto deconv = dnnl::deconvolution_forward(deconv2d_prim_desc); + auto deconv = dnnl::deconvolution_forward(deconv_prim_desc); net_.push_back(deconv); // Data memory. - auto deconv2d_src_memory = BindDNNLMemory(data_entry, deconv_src_md); + auto deconv_src_memory = BindDNNLMemory(data_entry, deconv_src_md); // Weight memory. 
- auto deconv2d_weights_memory = BindDNNLMemory(weight_entry, deconv_weights_md); + auto deconv_weights_memory = BindDNNLMemory(weight_entry, deconv_prim_desc.weights_desc()); // Output memory. - auto deconv2d_dst_memory = BindDNNLMemory(out_entry, deconv2d_prim_desc.dst_desc()); + auto deconv_dst_memory = BindDNNLMemory(out_entry, deconv_prim_desc.dst_desc()); // Bias memory. - auto deconv2d_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_); + auto deconv_bias_memory = dnnl::memory({bias_dims, dt::f32, tag::x}, engine_); if (has_bias) { auto bias_entry = node.GetInputs()[2]; - BindDNNLMemory(bias_entry, deconv2d_bias_memory); + BindDNNLMemory(bias_entry, deconv_bias_memory); // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, deconv2d_src_memory}, - {DNNL_ARG_WEIGHTS, deconv2d_weights_memory}, - {DNNL_ARG_BIAS, deconv2d_bias_memory}, - {DNNL_ARG_DST, deconv2d_dst_memory}}); + net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory}, + {DNNL_ARG_WEIGHTS, deconv_weights_memory}, + {DNNL_ARG_BIAS, deconv_bias_memory}, + {DNNL_ARG_DST, deconv_dst_memory}}); } else { // Bind memory buffers. - net_args_.push_back({{DNNL_ARG_SRC, deconv2d_src_memory}, - {DNNL_ARG_WEIGHTS, deconv2d_weights_memory}, - {DNNL_ARG_DST, deconv2d_dst_memory}}); + net_args_.push_back({{DNNL_ARG_SRC, deconv_src_memory}, + {DNNL_ARG_WEIGHTS, deconv_weights_memory}, + {DNNL_ARG_DST, deconv_dst_memory}}); } } @@ -562,13 +689,13 @@ class DNNLJSONRuntime : public JSONRuntimeBase { : dnnl::algorithm::pooling_avg_exclude_padding; } - dnnl::memory::dims src_dims = input_shape; - dnnl::memory::dims dst_dims = out_shape; - dnnl::memory::dims kernel_dims = TransformStr2Dims(str_kernel, "kernel"); - dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides, "strides"); - dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, "dilates"); - dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l, "padding"); - dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r, "padding"); + dnnl::memory::dims src_dims = TransDims2Plain(input_shape, layout); + dnnl::memory::dims dst_dims = TransDims2Plain(out_shape, layout); + dnnl::memory::dims kernel_dims = TransformStr2Dims(str_kernel); + dnnl::memory::dims strides_dims = TransformStr2Dims(str_strides); + dnnl::memory::dims dilates_dims = TransformStr2Dims(str_dilates, true); + dnnl::memory::dims padding_dims_l = TransformStr2Dims(str_padding_l); + dnnl::memory::dims padding_dims_r = TransformStr2Dims(str_padding_r); // Memory descriptions. auto pool_src_md = dnnl::memory::desc(src_dims, dt::f32, layout_dict[layout]); diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 070cd7077d18..5e3ba83ce000 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -412,7 +412,6 @@ BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { n->src_layout = std::move(src_layout); n->dst_layout = std::move(dst_layout); - // To be consistent with previous behavior, a nullptr layout is created // when argument is invalid. if (GetStoreRule(&n->index_forward_rule, &n->shape_forward_rule, n->src_layout, n->dst_layout)) { diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 4d1972d6a3b0..fb48e05c4d80 100755 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -16,12 +16,16 @@ # under the License. 
import pytest import itertools +import numpy as np + import tvm -import tvm.relay.testing from tvm import relay +from tvm.relay import transform +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.testing.temp_op_attr import TempOpAttr from tvm.relay.op.contrib import dnnl import tvm.testing -import numpy as np + has_dnnl_codegen = pytest.mark.skipif( not tvm.get_global_func("relay.ext.dnnl", True), reason="DNNL codegen not available" @@ -34,6 +38,74 @@ ) +def partition_for_dnnl(mod, params=None, alter_layout=True): + """Partition the graph greedily offloading supported operators to DNNL. + + Parameters + ---------- + mod : Module + The module to run passes on. + params : Optional[Dict[str, NDArray]] + Constant input parameters. + Returns + ------- + mod : Module + Annotated and partitioned module. + """ + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + + with TempOpAttr("nn.conv2d", "FTVMLegalize", dnnl.legalize_group_conv): + with TempOpAttr("nn.conv2d_transpose", "FTVMLegalize", dnnl.legalize_group_conv): + seq = tvm.transform.Sequential( + [ + transform.CanonicalizeOps(), + transform.InferType(), + transform.SimplifyInference(), + transform.FoldConstant(), + transform.FoldScaleAxis(), + # fold consecutive add ops to simplify pattern `conv2d-bias_add-bn-relu` + transform.SimplifyExpr(), + transform.FoldConstant(), + # alter group conv /conv_transpose layout to `GOIHW` / `GIOHW` + transform.Legalize(), + transform.FoldConstant(), + ] + ) + with tvm.transform.PassContext(opt_level=3): + mod = seq(mod) + if alter_layout: + with TempOpAttr("nn.conv1d", "FTVMAlterOpLayout", dnnl.alter_conv): + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", dnnl.alter_conv): + with TempOpAttr("nn.conv3d", "FTVMAlterOpLayout", dnnl.alter_conv): + with TempOpAttr( + "nn.conv2d_transpose", "FTVMAlterOpLayout", dnnl.alter_conv_transpose + ): + with TempOpAttr( + "nn.conv3d_transpose", "FTVMAlterOpLayout", dnnl.alter_conv_transpose + ): + alter_layout_seq = tvm.transform.Sequential( + [ + transform.AlterOpLayout(), + transform.FoldConstant(), + ] + ) + with tvm.transform.PassContext(opt_level=3): + mod = alter_layout_seq(mod) + + byoc_seq = tvm.transform.Sequential( + [ + transform.MergeComposite(dnnl.pattern_table()), + transform.AnnotateTarget("dnnl"), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + ] + ) + with tvm.transform.PassContext(opt_level=3): + mod = byoc_seq(mod) + return mod + + def vmobj_to_list(o): if isinstance(o, tvm.nd.NDArray): return [o.numpy()] @@ -61,13 +133,17 @@ def check_dnnl_used(mod): dev = tvm.cpu() result_dict = dict() for mode in ["graph", "vm"]: - for use_dnnl in [False, True]: - result_key = mode + ("_dnnl" if use_dnnl else "") + for use_dnnl, alter_layout in [(False, False), (True, False), (True, True)]: + result_key = mode + ("_dnnl" if use_dnnl else "") + ("_layout" if alter_layout else "") if use_dnnl: - mod = dnnl.partition_for_dnnl(mod, params) - check_dnnl_used(mod) + processed_mod = partition_for_dnnl(mod, params, alter_layout) + check_dnnl_used(processed_mod) + else: + processed_mod = mod with tvm.transform.PassContext(opt_level=3): - func = relay.create_executor(mode, mod=mod, device=dev, target=target).evaluate() + func = relay.create_executor( + mode, mod=processed_mod, device=dev, target=target + ).evaluate() if run_module: if isinstance(input, dict): result_dict[result_key] = func(**input, **params) @@ -80,13 +156,11 @@ def check_dnnl_used(mod): def run_and_verify_func(config, run_module, target="llvm", 
dtype="float32"): """Test a Relay func by compiling, running, and comparing TVM and DNNL outputs. - Parameters ---------- config : Tuple[relay.Function, Dict[str, NDArray], List[str]] A tuple containing 1) The function to test, 2) A dictionary of var names to input shapes and 3) A list of which vars should be considered params. - run_module: bool If True, the built module will be run after being compiled. """ @@ -97,12 +171,12 @@ def run_and_verify_func(config, run_module, target="llvm", dtype="float32"): for k, v in input_shapes.items() if k not in is_param } - run_and_verify(f, input_dict, params, target, run_module) + run_and_verify(f, input_dict, params, target=target, run_module=run_module) def get_conv1d( x_shape=((1, 3, 224)), - k_shape=(10, 3, 3), + k_shape=(16, 3, 3), groups=1, padding=(1, 1), strides=(1), @@ -222,7 +296,7 @@ def get_conv2d_transpose( out = relay.nn.conv2d_transpose( x, kernel, - channels=k_shape[1], + channels=k_shape[1] * groups, kernel_size=k_shape[2:4], groups=groups, padding=padding, @@ -251,7 +325,7 @@ def get_conv2d_weights_const( dtype="float32", ): x = relay.var("x", shape=(x_shape), dtype=dtype) - kernel = relay.const(np.ones(k_shape).astype(dtype)) + kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype)) out = relay.nn.conv2d( x, kernel, @@ -270,7 +344,7 @@ def get_conv2d_weights_const( def get_conv2d_bias( x_shape=(1, 32, 8, 8), k_shape=(16, 32, 3, 3), activation=None, dtype="float32" ): - conv, dic, param_lst = get_conv2d(x_shape=x_shape, k_shape=k_shape, dtype=dtype) + conv, dic, param_lst = get_conv2d_weights_const(x_shape=x_shape, k_shape=k_shape, dtype=dtype) bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype) out = relay.nn.bias_add(conv, bias) dic["bias"] = (k_shape[0],) @@ -336,7 +410,7 @@ def get_conv3d( dtype="float32", ): x = relay.var("x", shape=(x_shape), dtype=dtype) - kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype)) out = relay.nn.conv3d( x, kernel, @@ -373,7 +447,7 @@ def get_conv3d_transpose( kernel_layout="OIDHW", ): x = relay.var("x", shape=(x_shape), dtype=dtype) - kernel = relay.var("kernel", shape=(k_shape), dtype=dtype) + kernel = relay.const(np.random.randint(0, 1, k_shape).astype(dtype)) out = relay.nn.conv3d_transpose( x, kernel, @@ -466,7 +540,7 @@ def test_dnnl_not_compatible(run_module, target="llvm", dtype="float32"): f = relay.Function([x], out) mod = tvm.IRModule() mod["main"] = f - mod = dnnl.partition_for_dnnl(mod) + mod = partition_for_dnnl(mod) for mode in ["graph", "vm"]: with tvm.transform.PassContext(opt_level=3): func = relay.create_executor(mode, mod=mod, device=tvm.cpu(0), target=target).evaluate() @@ -542,15 +616,22 @@ def get_graph(x_shape, axis): def test_conv1d(run_module, dtype="float32"): - conv1d, dic, param_lst = get_conv1d(channels=10, dtype=dtype) + conv1d, dic, param_lst = get_conv1d(channels=16, dtype=dtype) conv1d = tvm.IRModule.from_expr(conv1d) config = conv1d, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) + x_shape = (1, 32, 224) + k_shape = (16, 32, 3) + conv1d_bias, dic, param_lst = get_conv1d(x_shape, k_shape, dtype=dtype) + conv1d_bias = tvm.IRModule.from_expr(conv1d_bias) + config = conv1d_bias, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + def test_conv1d_pattern(run_module, dtype="float32"): x_shape = (1, 3, 224) - k_shape = (10, 3, 3) + k_shape = (16, 3, 3) activation_lst = [None, "relu", "tanh", "sigmoid"] for a in 
activation_lst: conv1d, dic, param_lst = get_conv1d(x_shape, k_shape, activation=a, dtype=dtype) @@ -566,7 +647,7 @@ def test_conv1d_pattern(run_module, dtype="float32"): def test_conv2d(run_module, dtype="float32"): x_shape = (1, 32, 8, 8) - for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32)]: + for k_shape, groups in [((16, 32, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 2, 3, 3), 16)]: for padding in [(0, 0), (1, 1)]: for strides in [(1, 1), (2, 2)]: for dilation in [(1, 1), (2, 2)]: @@ -592,6 +673,13 @@ def test_conv2d_weights_const(run_module, dtype="float32"): config = conv2d, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) + x_shape = (1, 3, 8, 8) + k_shape = (16, 3, 3, 3) + conv2d, dic, param_lst = get_conv2d_weights_const(x_shape, k_shape, dtype=dtype) + conv2d = tvm.IRModule.from_expr(conv2d) + config = conv2d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + def test_conv2d_pattern(run_module, dtype="float32"): x_shape = (1, 32, 8, 8) @@ -615,14 +703,21 @@ def test_conv2d_pattern(run_module, dtype="float32"): def test_conv2d_transpose(run_module, dtype="float32"): - for padding in [(0, 0), (1, 1)]: - for strides in [(1, 1), (2, 2)]: - conv2d_transpose, dic, param_lst = get_conv2d_transpose( - padding=padding, strides=strides, dtype=dtype - ) - conv2d_transpose = tvm.IRModule.from_expr(conv2d_transpose) - config = conv2d_transpose, dic, param_lst - run_and_verify_func(config, run_module=run_module, dtype=dtype) + x_shape = (1, 32, 8, 8) + for k_shape, groups in [((32, 16, 3, 3), 1), ((32, 1, 3, 3), 32), ((32, 4, 3, 3), 16)]: + for padding in [(0, 0), (1, 1)]: + for strides in [(1, 1), (2, 2)]: + conv2d_transpose, dic, param_lst = get_conv2d_transpose( + x_shape=x_shape, + k_shape=k_shape, + groups=groups, + padding=padding, + strides=strides, + dtype=dtype, + ) + conv2d_transpose = tvm.IRModule.from_expr(conv2d_transpose) + config = conv2d_transpose, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) def test_conv2d_transpose_pattern(run_module, dtype="float32"): @@ -650,6 +745,13 @@ def test_conv3d(run_module, dtype="float32"): config = conv3d, dic, param_lst run_and_verify_func(config, run_module=run_module, dtype=dtype) + conv3d, dic, param_lst = get_conv3d( + x_shape=(1, 3, 8, 8, 8), k_shape=(16, 3, 3, 3, 3), dtype=dtype + ) + conv3d = tvm.IRModule.from_expr(conv3d) + config = conv3d, dic, param_lst + run_and_verify_func(config, run_module=run_module, dtype=dtype) + def test_conv3d_pattern(run_module, dtype="float32"): activation_lst = [None, "relu", "tanh", "sigmoid"]
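
Usage note (illustrative, not part of the patch): the sketch below shows how the new layout-query helpers added to python/tvm/relay/op/contrib/dnnl.py can be exercised directly. It assumes a TVM build with the DNNL (oneDNN) BYOC runtime enabled; the shapes chosen and the blocked tags shown in the comments are examples only, since the tags actually returned depend on the local oneDNN version and CPU ISA.

    # A minimal sketch, assuming TVM is built with the DNNL (oneDNN) BYOC runtime.
    # Query oneDNN's preferred layouts for a conv2d with a (16, 32, 3, 3) OIHW kernel
    # producing a (1, 16, 8, 8) output. All arguments are comma-separated strings,
    # matching how alter_conv() serializes the attrs before calling into C++.
    from tvm.relay.op.contrib import dnnl

    layouts = dnnl.get_optimal_layout_for_conv(
        "NCHW",        # data_layout
        "OIHW",        # kernel_layout
        "16,32,3,3",   # weight_shape
        "1,16,8,8",    # out_shape
        "1,1,1,1",     # paddings
        "1,1",         # strides
        "1,1",         # dilation
        "1",           # groups
    )
    # The result is "src_tag,weight_tag,dst_tag", e.g. "aBcd8b,ABcd8b8a,aBcd8b";
    # the exact tags vary with the oneDNN build and the ISA it targets.
    src_tag, weight_tag, dst_tag = layouts.split(",")

    # Convert the oneDNN tags back into TVM layout strings (e.g. NCHW8c / OIHW8i8o).
    data_layout = dnnl.tag2layout(src_tag, is_weight=False, conv_type="Conv2D")
    kernel_layout = dnnl.tag2layout(weight_tag, is_weight=True, conv_type="Conv2D")
    out_layout = dnnl.tag2layout(dst_tag, is_weight=False, conv_type="Conv2D")
    print(data_layout, kernel_layout, out_layout)

In the updated tests this query is driven automatically: partition_for_dnnl() in tests/python/contrib/test_dnnl.py registers alter_conv / alter_conv_transpose as FTVMAlterOpLayout via TempOpAttr (and legalize_group_conv as FTVMLegalize for grouped cases) before running MergeComposite, AnnotateTarget, MergeCompilerRegions, and PartitionGraph, so the composite DNNL regions already carry the queried layouts when they reach the JSON runtime.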