From 746a09a9876f4521478ab2b27b177f4592dc407c Mon Sep 17 00:00:00 2001 From: Pencilcaseman Date: Wed, 5 Oct 2022 23:21:16 +0100 Subject: [PATCH] Many bug fixes, particularly with CUDA Former-commit-id: da8fd10f12266523061c9e88e602314f25a79bbd [formerly 77e5ca128988a028b07a18b9ba705311ca33bc58] Former-commit-id: 982b1a60ecc588a5173e37d3a7aeed5b526414bc --- src/librapid/VERSION.hpp | 2 +- src/librapid/array/arrayBase.hpp | 20 ++++--- src/librapid/cuda/memUtils.hpp | 2 + src/librapid/math/vector.hpp | 98 +++++++++++++++++++++++++------- 4 files changed, 94 insertions(+), 28 deletions(-) diff --git a/src/librapid/VERSION.hpp b/src/librapid/VERSION.hpp index 225c4dd92..31c6084c7 100644 --- a/src/librapid/VERSION.hpp +++ b/src/librapid/VERSION.hpp @@ -1,3 +1,3 @@ #ifndef LIBRAPID_VERSION -#define LIBRAPID_VERSION "0.5.6" +#define LIBRAPID_VERSION "0.5.7" #endif diff --git a/src/librapid/array/arrayBase.hpp b/src/librapid/array/arrayBase.hpp index bb37bfd40..5e73237d8 100644 --- a/src/librapid/array/arrayBase.hpp +++ b/src/librapid/array/arrayBase.hpp @@ -147,11 +147,11 @@ namespace librapid { template class ArrayBase { public: - using Scalar = typename internal::traits::Scalar; - using BaseScalar = typename internal::traits::BaseScalar; - using This = ArrayBase; - using Packet = typename internal::traits::Packet; - using StorageType = typename internal::traits::StorageType; + using Scalar = typename internal::traits::Scalar; + using BaseScalar = typename internal::traits::BaseScalar; + using This = ArrayBase; + using Packet = typename internal::traits::Packet; + using StorageType = typename internal::traits::StorageType; static constexpr ui64 Flags = internal::traits::Flags; friend Derived; @@ -265,8 +265,6 @@ void castKernel({1} *dst, {2} *src, i64 size) {{ size /= sizeof(BaseScalar) * 8; } - fmt::print("Information: {}\n", typeid(BaseScalar).name()); - memory::memcpy( res.storage().heap(), eval().storage().heap(), size); return res; @@ -462,7 +460,7 @@ void castKernel({1} *dst, {2} *src, i64 size) {{ } LR_NODISCARD("Do not ignore the result of an evaluated calculation") - auto eval() const { return derived(); } + auto eval() const { return derived().eval(); } template LR_FORCE_INLINE void loadFrom(i64 index, const OtherDerived &other) { @@ -530,6 +528,12 @@ void castKernel({1} *dst, {2} *src, i64 size) {{ m_extent.str(), other.extent().str()); + // If device differs, we need to copy the data + if constexpr (!std::is_same_v::Device>) { + return assignLazy(other.move()); + } + using Selector = functors::AssignOp; Selector::run(derived(), other.derived()); return derived(); diff --git a/src/librapid/cuda/memUtils.hpp b/src/librapid/cuda/memUtils.hpp index 771e12313..43a35c5eb 100644 --- a/src/librapid/cuda/memUtils.hpp +++ b/src/librapid/cuda/memUtils.hpp @@ -98,6 +98,8 @@ To IGNORE this error, just define LIBRAPID_NO_THREAD_CHECK above LibRapid includ cudaMemcpyAsync(dst, src, sizeof(T) * size, cudaMemcpyDeviceToHost, cudaStream)); } else if constexpr (std::is_same_v && std::is_same_v) { + // fmt::print("Info: {} {} {} {}\n", (void *) dst[0], (void *) dst[1], (void *)src[0], (void *)src[1]); + // Host to Device cudaSafeCall( cudaMemcpyAsync(dst, src, sizeof(T) * size, cudaMemcpyHostToDevice, cudaStream)); diff --git a/src/librapid/math/vector.hpp b/src/librapid/math/vector.hpp index 42085547d..8b73281c5 100644 --- a/src/librapid/math/vector.hpp +++ b/src/librapid/math/vector.hpp @@ -82,20 +82,71 @@ namespace librapid { } LR_FORCE_INLINE - VecImpl operator>(const VecImpl &other) { + VecImpl cmp(const VecImpl &other, const char *mode) const { + // Mode: + // 0: == + // 1: != + // 2: < + // 3: <= + // 4: > + // 5: >= + VecImpl res(*this); + i16 modeInt = *(i16 *)mode; + fmt::print("Info: {:Lb}\n", modeInt); + fmt::print("Info: {:Lb}\n", ('g' << 8) | 't'); for (i64 i = 0; i < Dims; ++i) { - if (res[i] > other[i]) { - res[i] = 1; - } else { - res[i] = 0; + switch (modeInt) { + case 'e' | ('q' << 8): + if (res[i] == other[i]) { + res[i] = 1; + } else { + res[i] = 0; + } + break; + case 'n' | ('e' << 8): + if (res[i] != other[i]) { + res[i] = 1; + } else { + res[i] = 0; + } + break; + case 'l' | ('t' << 8): + if (res[i] < other[i]) { + res[i] = 1; + } else { + res[i] = 0; + } + break; + case 'l' | ('e' << 8): + if (res[i] <= other[i]) { + res[i] = 1; + } else { + res[i] = 0; + } + break; + case 'g' | ('t' << 8): + if (res[i] > other[i]) { + res[i] = 1; + } else { + res[i] = 0; + } + break; + case 'g' | ('e' << 8): + if (res[i] >= other[i]) { + res[i] = 1; + } else { + res[i] = 0; + } + break; + default: LR_ASSERT(false, "Invalid mode {}", mode); } } return res; } LR_FORCE_INLINE - VecImpl cmp(const VecImpl &other, char mode[2]) { + VecImpl cmp(const Scalar &value, const char *mode) const { // Mode: // 0: == // 1: != @@ -105,46 +156,48 @@ namespace librapid { // 5: >= VecImpl res(*this); - i16 modeInt = (mode[1] << 8) | mode[0]; + i16 modeInt = *(i16 *)mode; + fmt::print("Info: {:Lb}\n", modeInt); + fmt::print("Info: {:Lb}\n", ('g' << 8) | 't'); for (i64 i = 0; i < Dims; ++i) { switch (modeInt) { - case ('e' << 8) | 'q': - if (res[i] == other[i]) { + case 'e' | ('q' << 8): + if (res[i] == value) { res[i] = 1; } else { res[i] = 0; } break; - case ('n' << 8) | 'e': - if (res[i] != other[i]) { + case 'n' | ('e' << 8): + if (res[i] != value) { res[i] = 1; } else { res[i] = 0; } break; - case ('l' << 8) | 't': - if (res[i] < other[i]) { + case 'l' | ('t' << 8): + if (res[i] < value) { res[i] = 1; } else { res[i] = 0; } break; - case ('l' << 8) | 'e': - if (res[i] <= other[i]) { + case 'l' | ('e' << 8): + if (res[i] <= value) { res[i] = 1; } else { res[i] = 0; } break; - case ('g' << 8) | 't': - if (res[i] > other[i]) { + case 'g' | ('t' << 8): + if (res[i] > value) { res[i] = 1; } else { res[i] = 0; } break; - case ('g' << 8) | 'e': - if (res[i] >= other[i]) { + case 'g' | ('e' << 8): + if (res[i] >= value) { res[i] = 1; } else { res[i] = 0; @@ -163,6 +216,13 @@ namespace librapid { LR_FORCE_INLINE VecImpl operator==(const VecImpl &other) const { return cmp(other, "eq"); } LR_FORCE_INLINE VecImpl operator!=(const VecImpl &other) const { return cmp(other, "ne"); } + LR_FORCE_INLINE VecImpl operator<(const Scalar &other) const { return cmp(other, "lt"); } + LR_FORCE_INLINE VecImpl operator<=(const Scalar &other) const { return cmp(other, "le"); } + LR_FORCE_INLINE VecImpl operator>(const Scalar &other) const { return cmp(other, "gt"); } + LR_FORCE_INLINE VecImpl operator>=(const Scalar &other) const { return cmp(other, "ge"); } + LR_FORCE_INLINE VecImpl operator==(const Scalar &other) const { return cmp(other, "eq"); } + LR_FORCE_INLINE VecImpl operator!=(const Scalar &other) const { return cmp(other, "ne"); } + LR_NODISCARD("") LR_INLINE Scalar mag2() const { return (m_data * m_data).sum(); } LR_NODISCARD("") LR_INLINE Scalar mag() const { return ::librapid::sqrt(mag2()); } LR_NODISCARD("") LR_INLINE Scalar invMag() const { return 1.0 / mag(); }