Skip to content

Commit

Permalink
Init
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Aug 25, 2023
1 parent 5d0ef94 commit 7fd9b53
Show file tree
Hide file tree
Showing 33 changed files with 2,665 additions and 11 deletions.
15 changes: 15 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ include(cmake/utils/Utils.cmake)
include(cmake/utils/Summary.cmake)
include(cmake/utils/Linker.cmake)
include(cmake/utils/FindCUDA.cmake)
include(cmake/utils/FindNCCL.cmake)
include(cmake/utils/FindOpenCL.cmake)
include(cmake/utils/FindVulkan.cmake)
include(cmake/utils/FindLLVM.cmake)
Expand All @@ -25,6 +26,7 @@ endif()
# and add set(OPTION VALUE) to override these build options.
# Alternatively, use cmake -DOPTION=VALUE through the command line.
tvm_option(USE_CUDA "Build with CUDA" OFF)
tvm_option(USE_NCCL "Build with NCCL" OFF)
tvm_option(USE_OPENCL "Build with OpenCL" OFF)
tvm_option(USE_OPENCL_ENABLE_HOST_PTR "Enable OpenCL memory object access to host" OFF)
tvm_option(USE_OPENCL_GTEST "Path to OpenCL specific gtest version for runtime cpp tests." /path/to/opencl/gtest)
Expand Down Expand Up @@ -350,6 +352,7 @@ list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")
tvm_file_glob(GLOB RUNTIME_SRCS
src/runtime/*.cc
src/runtime/vm/*.cc
src/runtime/disco/*.cc
src/runtime/minrpc/*.cc
src/runtime/relax_vm/*.cc
)
Expand Down Expand Up @@ -434,6 +437,13 @@ if(USE_PROFILER)
list(APPEND RUNTIME_SRCS ${RUNTIME_VM_PROFILER_SRCS})
endif(USE_PROFILER)

# NCCL support requires CUDA: locate the NCCL installation and add the
# disco NCCL runtime sources to the runtime build.
if(USE_CUDA AND USE_NCCL)
message(STATUS "Build with NCCL...")
# find_nccl (cmake/utils/FindNCCL.cmake) creates the imported target `nccl`.
find_nccl(${USE_NCCL})
tvm_file_glob(GLOB RUNTIME_NCCL_SRC src/runtime/disco/nccl/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_NCCL_SRC})
endif()

if(USE_AOT_EXECUTOR)
message(STATUS "Build with AOT Executor support...")
file(GLOB RUNTIME_AOT_EXECUTOR_SRCS src/runtime/aot_executor/*.cc)
Expand Down Expand Up @@ -850,3 +860,8 @@ if(USE_CUDA AND USE_CUTLASS)
target_link_libraries(tvm PRIVATE -Wl,--no-as-needed flash_attn)
target_link_libraries(tvm_runtime PRIVATE -Wl,--no-as-needed flash_attn)
endif()

# Link the imported `nccl` target (created by find_nccl in
# cmake/utils/FindNCCL.cmake) into both the runtime and compiler libraries.
if(USE_CUDA AND USE_NCCL)
target_link_libraries(tvm_runtime PRIVATE nccl)
target_link_libraries(tvm PRIVATE nccl)
endif()
6 changes: 6 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@
# - /path/to/cuda: use specific path to cuda toolkit
set(USE_CUDA OFF)

# Whether to enable NCCL support:
# - ON: enable NCCL with cmake's auto search
# - OFF: disable NCCL
# - /path/to/nccl: use specific path to nccl
set(USE_NCCL OFF)

# Whether enable ROCM runtime
#
# Possible values:
Expand Down
1 change: 1 addition & 0 deletions cmake/modules/LibInfo.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ function(add_lib_info src_file)
TVM_INFO_USE_CPP_RTVM="${USE_CPP_RTVM}"
TVM_INFO_USE_CUBLAS="${USE_CUBLAS}"
TVM_INFO_USE_CUDA="${USE_CUDA}"
TVM_INFO_USE_NCCL="${USE_NCCL}"
TVM_INFO_USE_CUDNN="${USE_CUDNN}"
TVM_INFO_USE_CUSTOM_LOGGING="${USE_CUSTOM_LOGGING}"
TVM_INFO_USE_CUTLASS="${USE_CUTLASS}"
Expand Down
56 changes: 56 additions & 0 deletions cmake/utils/FindNCCL.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# This module defines the macro:
#
#   find_nccl(use_nccl)
#
# Arguments:
#   use_nccl - an OFF-like value disables the search entirely;
#              an ON-like value searches the standard system locations;
#              any other value is treated as the root of an NCCL
#              installation and is searched first.
#
# Results (cache variables, reported via FindPackageHandleStandardArgs):
#   Nccl_FOUND       - whether nccl has been found
#   NCCL_INCLUDE_DIR - directory containing nccl.h
#   NCCL_LIBRARY     - path of the nccl library
#
# On success the shared imported target `nccl` is created, carrying both
# the library location and the include directory as usage requirements.
#
# This module assumes that the user has already called find_package(CUDA).

macro(find_nccl use_nccl)
  # NOTE: `return()` inside a macro returns from the CALLER's scope and
  # would abort processing of the including CMakeLists.txt, so the
  # disabled case is handled with a plain conditional instead.
  if(NOT ("${use_nccl}" MATCHES "${IS_FALSE_PATTERN}"))
    if("${use_nccl}" MATCHES "${IS_TRUE_PATTERN}")
      # ON-like value: rely on cmake's standard search locations.
      find_path(NCCL_INCLUDE_DIR NAMES nccl.h)
      find_library(NCCL_LIBRARY NAMES nccl)
    else()
      # Otherwise treat the value as an installation prefix to search first.
      find_path(NCCL_INCLUDE_DIR NAMES nccl.h HINTS ${use_nccl} ${use_nccl}/include)
      find_library(NCCL_LIBRARY NAMES nccl HINTS ${use_nccl} ${use_nccl}/lib)
    endif()
    include(FindPackageHandleStandardArgs)
    find_package_handle_standard_args(Nccl DEFAULT_MSG NCCL_INCLUDE_DIR NCCL_LIBRARY)
    if(Nccl_FOUND)
      message(STATUS "Found NCCL_LIBRARY: ${NCCL_LIBRARY}")
      message(STATUS "Found NCCL_INCLUDE_DIR: ${NCCL_INCLUDE_DIR}")
      # Wrap the raw paths in an imported target so that linking `nccl`
      # transitively provides the include directory as well.
      add_library(nccl SHARED IMPORTED)
      set_target_properties(nccl
        PROPERTIES
        INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}"
        IMPORTED_LOCATION "${NCCL_LIBRARY}")
    else()
      message(STATUS "NCCL not found")
    endif()
    mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY)
  endif()
endmacro()
46 changes: 46 additions & 0 deletions include/tvm/relax/attrs/ccl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/attrs/ccl.h
* \brief Attributes for ccl operators.
*/
#ifndef TVM_RELAX_ATTRS_CCL_H_
#define TVM_RELAX_ATTRS_CCL_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in allreduce operators */
struct AllReduceAttrs : public tvm::AttrsNode<AllReduceAttrs> {
// The reduction kind requested by the operator, e.g. "sum".
String op_type;

TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") {
// NOTE(review): the Python frontend (python/tvm/relax/op/ccl/ccl.py)
// accepts "sum", "prod", "min", "max" and "avg", while this description
// claims only "sum" is supported — confirm which side is authoritative.
TVM_ATTR_FIELD(op_type).describe(
"The type of reduction operation to be applied to the input data. Now only sum is "
"supported.");
}
}; // struct AllReduceAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_CCL_H_
55 changes: 44 additions & 11 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1284,13 +1284,27 @@ namespace parameter_pack {

template <typename... EnumArgs>
struct EnumeratedParamPack {
struct Invoke {
template <template <size_t i, typename TArgument> class Functor, typename... ExtraParams>
static void F(ExtraParams&&... extra_params) {
struct InvokeWithoutArg {
template <template <size_t i, typename TArgument> class Functor, typename ExtraParams>
static void F(ExtraParams&& extra_params) {
using TExpander = int[];
(void)TExpander{
0,
(Functor<EnumArgs::i, typename EnumArgs::T>::F(extra_params...), 0)...,
(Functor<EnumArgs::i, typename EnumArgs::T>::F(std::forward<ExtraParams>(extra_params)),
0)...,
};
}
};
struct InvokeWithArg {
template <template <size_t i, typename TArgument> class Functor, typename ExtraParams,
typename... Params>
static void F(ExtraParams&& extra_params, Params&&... params) {
using TExpander = int[];
(void)TExpander{
0,
(Functor<EnumArgs::i, typename EnumArgs::T>::F(std::forward<ExtraParams>(extra_params),
std::forward<Params>(params)),
0)...,
};
}
};
Expand All @@ -1310,22 +1324,27 @@ struct EnumerateImpl {

template <std::size_t... id>
struct Zipper<std::integer_sequence<std::size_t, id...>> {
using T = EnumeratedParamPack<Item<id, Args>...>;
using WithoutArg = typename EnumeratedParamPack<Item<id, Args>...>::InvokeWithoutArg;
using WithArg = typename EnumeratedParamPack<Item<id, Args>...>::InvokeWithArg;
};

public:
using T = typename Zipper<std::index_sequence_for<Args...>>::T;
using WithoutArg = typename Zipper<std::index_sequence_for<Args...>>::WithoutArg;
using WithArg = typename Zipper<std::index_sequence_for<Args...>>::WithArg;
};

template <typename... Args>
using Enumerate = typename EnumerateImpl<Args...>::T;
using EnumerateWithoutArg = typename EnumerateImpl<Args...>::WithoutArg;

template <typename... Args>
using EnumerateWithArg = typename EnumerateImpl<Args...>::WithArg;

template <typename... Args>
struct ParamPack {
template <template <size_t i, typename TArgument> class Functor, typename... ExtraParams>
static void InvokeWithoutArg(ExtraParams&&... extra_params) {
Enumerate<Args...>::Invoke::template F<Functor, ExtraParams...>(
std::forward<ExtraParams>(extra_params)...);
template <template <size_t i, typename TArgument> class Functor, typename ExtraParams>
static void InvokeWithoutArg(ExtraParams&& extra_params) {
EnumerateWithoutArg<Args...>::template F<Functor, ExtraParams>(
std::forward<ExtraParams>(extra_params));
}
};

Expand Down Expand Up @@ -1622,6 +1641,20 @@ inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
return rv;
}

/*!
 * \brief Functor used with EnumerateWithArg to record one call argument.
 *
 * For the i-th argument of a packed call, forwards the value into the
 * given TVMArgsSetter at position i.
 */
template <int i, typename T>
struct TVMArgsSetterApply {
static TVM_ALWAYS_INLINE void F(TVMArgsSetter* setter, T&& value) {
(*setter)(i, std::forward<T>(value));
}
};

/*!
 * \brief Pack the given arguments into the TVMValue/type-code arrays
 *        used by the packed-function calling convention.
 * \param values Destination array of TVMValue slots; assumed to hold at
 *        least sizeof...(Args) entries — caller must guarantee capacity.
 * \param type_codes Destination array of per-argument type codes,
 *        parallel to \p values.
 * \param args The arguments to pack; each is forwarded to TVMArgsSetter
 *        at its position in the pack via TVMArgsSetterApply.
 */
template <typename... Args>
void TVM_ALWAYS_INLINE PackArgs(TVMValue* values, int* type_codes, Args&&... args) {
TVMArgsSetter setter(values, type_codes);
detail::parameter_pack::EnumerateWithArg<Args...>::template F<TVMArgsSetterApply>(
&setter, std::forward<Args>(args)...);
}

namespace detail {
template <typename R, int nleft, int index, typename F>
struct unpack_call_dispatcher {
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from . import image
from . import memory
from . import nn
from . import ccl

# Register operator gradient functions
from . import _op_gradient
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/relax/op/ccl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""CCL related operators."""
from .ccl import *
20 changes: 20 additions & 0 deletions python/tvm/relax/op/ccl/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Operators serving for Collective Communications Library (CCL) operators"""
import tvm._ffi

tvm._ffi._init_api("relax.op.ccl", __name__)
42 changes: 42 additions & 0 deletions python/tvm/relax/op/ccl/ccl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Relax Collective Communications Library (CCL) operators"""
from . import _ffi_api
from ...expr import Expr


def allreduce(x, op_type: str = "sum"):
"""Allreduce operator
Parameters
----------
x : relax.Expr
The input tensor.
op_type: str
The type of reduction operation to be applied to the input data.
Now "sum", "prod", "min", "max" and "avg" are supported.
Returns
-------
result : relax.Expr
The result of allreduce.
"""
supported_op_types = ["sum", "prod", "min", "max", "avg"]
assert (
op_type in supported_op_types
), f"Allreduce only supports limited reduction operations, including {supported_op_types}, but got {op_type}."
return _ffi_api.allreduce(x, op_type)
1 change: 1 addition & 0 deletions python/tvm/relax/transform/legalize_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""Legalize high-level operator calls in Relax functions to call_tir."""
from . import binary
from . import ccl
from . import create
from . import datatype
from . import grad
Expand Down
Loading

0 comments on commit 7fd9b53

Please sign in to comment.