Skip to content

Commit

Permalink
Many bug fixes, particularly with CUDA
Browse files Browse the repository at this point in the history
Former-commit-id: da8fd10 [formerly 77e5ca1]
Former-commit-id: 982b1a6
  • Loading branch information
Pencilcaseman committed Oct 5, 2022
1 parent c4bed0b commit 746a09a
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/librapid/VERSION.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#ifndef LIBRAPID_VERSION
#define LIBRAPID_VERSION "0.5.6"
#define LIBRAPID_VERSION "0.5.7"
#endif
20 changes: 12 additions & 8 deletions src/librapid/array/arrayBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ namespace librapid {
template<typename Derived, typename Device>
class ArrayBase {
public:
using Scalar = typename internal::traits<Derived>::Scalar;
using BaseScalar = typename internal::traits<Scalar>::BaseScalar;
using This = ArrayBase<Derived, Device>;
using Packet = typename internal::traits<Derived>::Packet;
using StorageType = typename internal::traits<Derived>::StorageType;
using Scalar = typename internal::traits<Derived>::Scalar;
using BaseScalar = typename internal::traits<Scalar>::BaseScalar;
using This = ArrayBase<Derived, Device>;
using Packet = typename internal::traits<Derived>::Packet;
using StorageType = typename internal::traits<Derived>::StorageType;
static constexpr ui64 Flags = internal::traits<This>::Flags;

friend Derived;
Expand Down Expand Up @@ -265,8 +265,6 @@ void castKernel({1} *dst, {2} *src, i64 size) {{
size /= sizeof(BaseScalar) * 8;
}

fmt::print("Information: {}\n", typeid(BaseScalar).name());

memory::memcpy<BaseScalar, D, BaseScalar, Device>(
res.storage().heap(), eval().storage().heap(), size);
return res;
Expand Down Expand Up @@ -462,7 +460,7 @@ void castKernel({1} *dst, {2} *src, i64 size) {{
}

// Evaluate this expression and return the result. Two forms of the one-line body are
// listed below: one returns derived() directly, the other forwards to derived().eval().
LR_NODISCARD("Do not ignore the result of an evaluated calculation")
auto eval() const { return derived(); }
auto eval() const { return derived().eval(); }

template<typename OtherDerived>
LR_FORCE_INLINE void loadFrom(i64 index, const OtherDerived &other) {
Expand Down Expand Up @@ -530,6 +528,12 @@ void castKernel({1} *dst, {2} *src, i64 size) {{
m_extent.str(),
other.extent().str());

// If device differs, we need to copy the data
if constexpr (!std::is_same_v<Device,
typename internal::traits<OtherDerived>::Device>) {
return assignLazy(other.move<Device>());
}

using Selector = functors::AssignOp<Derived, OtherDerived>;
Selector::run(derived(), other.derived());
return derived();
Expand Down
2 changes: 2 additions & 0 deletions src/librapid/cuda/memUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ To IGNORE this error, just define LIBRAPID_NO_THREAD_CHECK above LibRapid includ
cudaMemcpyAsync(dst, src, sizeof(T) * size, cudaMemcpyDeviceToHost, cudaStream));
} else if constexpr (std::is_same_v<d, device::GPU> &&
std::is_same_v<d_, device::CPU>) {
// fmt::print("Info: {} {} {} {}\n", (void *) dst[0], (void *) dst[1], (void *)src[0], (void *)src[1]);

// Host to Device
cudaSafeCall(
cudaMemcpyAsync(dst, src, sizeof(T) * size, cudaMemcpyHostToDevice, cudaStream));
Expand Down
98 changes: 79 additions & 19 deletions src/librapid/math/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,71 @@ namespace librapid {
}

// Compare this vector with `other` element by element.
//
// `mode` is a two-character selector; only mode[0] and mode[1] are read:
//   "eq": ==    "ne": !=    "lt": <    "le": <=    "gt": >    "ge": >=
//
// Returns a vector in which element i is 1 when the selected comparison holds
// between (*this)[i] and other[i], and 0 otherwise. Any other mode triggers
// LR_ASSERT with "Invalid mode".
LR_FORCE_INLINE
VecImpl cmp(const VecImpl &other, const char *mode) const {
    // Pack the two mode characters into one integer so a single switch can dispatch
    // on it. The value is built arithmetically instead of loading through an i16
    // pointer, so it does not depend on the alignment of `mode` or on the host byte
    // order, and it matches the case labels below on every platform.
    const int modeInt = mode[0] | (mode[1] << 8);

    VecImpl res(*this);
    for (i64 i = 0; i < Dims; ++i) {
        switch (modeInt) {
            case 'e' | ('q' << 8): res[i] = (res[i] == other[i]) ? 1 : 0; break;
            case 'n' | ('e' << 8): res[i] = (res[i] != other[i]) ? 1 : 0; break;
            case 'l' | ('t' << 8): res[i] = (res[i] < other[i]) ? 1 : 0; break;
            case 'l' | ('e' << 8): res[i] = (res[i] <= other[i]) ? 1 : 0; break;
            case 'g' | ('t' << 8): res[i] = (res[i] > other[i]) ? 1 : 0; break;
            case 'g' | ('e' << 8): res[i] = (res[i] >= other[i]) ? 1 : 0; break;
            default: LR_ASSERT(false, "Invalid mode {}", mode);
        }
    }
    return res;
}

LR_FORCE_INLINE
VecImpl cmp(const VecImpl &other, char mode[2]) {
VecImpl cmp(const Scalar &value, const char *mode) const {
// Mode:
// 0: ==
// 1: !=
Expand All @@ -105,46 +156,48 @@ namespace librapid {
// 5: >=

VecImpl res(*this);
i16 modeInt = (mode[1] << 8) | mode[0];
i16 modeInt = *(i16 *)mode;
fmt::print("Info: {:Lb}\n", modeInt);
fmt::print("Info: {:Lb}\n", ('g' << 8) | 't');
for (i64 i = 0; i < Dims; ++i) {
switch (modeInt) {
case ('e' << 8) | 'q':
if (res[i] == other[i]) {
case 'e' | ('q' << 8):
if (res[i] == value) {
res[i] = 1;
} else {
res[i] = 0;
}
break;
case ('n' << 8) | 'e':
if (res[i] != other[i]) {
case 'n' | ('e' << 8):
if (res[i] != value) {
res[i] = 1;
} else {
res[i] = 0;
}
break;
case ('l' << 8) | 't':
if (res[i] < other[i]) {
case 'l' | ('t' << 8):
if (res[i] < value) {
res[i] = 1;
} else {
res[i] = 0;
}
break;
case ('l' << 8) | 'e':
if (res[i] <= other[i]) {
case 'l' | ('e' << 8):
if (res[i] <= value) {
res[i] = 1;
} else {
res[i] = 0;
}
break;
case ('g' << 8) | 't':
if (res[i] > other[i]) {
case 'g' | ('t' << 8):
if (res[i] > value) {
res[i] = 1;
} else {
res[i] = 0;
}
break;
case ('g' << 8) | 'e':
if (res[i] >= other[i]) {
case 'g' | ('e' << 8):
if (res[i] >= value) {
res[i] = 1;
} else {
res[i] = 0;
Expand All @@ -163,6 +216,13 @@ namespace librapid {
// Equality / inequality against another vector: each forwards to cmp(other, mode)
// with the two-letter mode string "eq" or "ne" and returns the VecImpl it produces.
LR_FORCE_INLINE VecImpl operator==(const VecImpl &other) const { return cmp(other, "eq"); }
LR_FORCE_INLINE VecImpl operator!=(const VecImpl &other) const { return cmp(other, "ne"); }

// Comparisons against a single scalar value. Each operator forwards to the
// cmp(const Scalar &, const char *) overload with the matching two-letter mode
// string ("lt", "le", "gt", "ge", "eq", "ne") and returns the VecImpl it produces.
LR_FORCE_INLINE VecImpl operator<(const Scalar &other) const { return cmp(other, "lt"); }
LR_FORCE_INLINE VecImpl operator<=(const Scalar &other) const { return cmp(other, "le"); }
LR_FORCE_INLINE VecImpl operator>(const Scalar &other) const { return cmp(other, "gt"); }
LR_FORCE_INLINE VecImpl operator>=(const Scalar &other) const { return cmp(other, "ge"); }
LR_FORCE_INLINE VecImpl operator==(const Scalar &other) const { return cmp(other, "eq"); }
LR_FORCE_INLINE VecImpl operator!=(const Scalar &other) const { return cmp(other, "ne"); }

// Squared magnitude: the sum() of m_data multiplied by itself.
LR_NODISCARD("") LR_INLINE Scalar mag2() const { return (m_data * m_data).sum(); }
// Magnitude: the square root of mag2().
LR_NODISCARD("") LR_INLINE Scalar mag() const { return ::librapid::sqrt(mag2()); }
// Reciprocal of the magnitude, computed as 1.0 / mag() and returned as Scalar.
LR_NODISCARD("") LR_INLINE Scalar invMag() const { return 1.0 / mag(); }
Expand Down

0 comments on commit 746a09a

Please sign in to comment.