Skip to content

Commit

Permalink
feat: implement DROPLETGeneratorImpl (#687)
Browse files Browse the repository at this point in the history
  • Loading branch information
wiryls authored Feb 2, 2024
1 parent 9c2b622 commit 0eaa761
Showing 1 changed file with 34 additions and 9 deletions.
43 changes: 34 additions & 9 deletions dipu/torch_dipu/csrc_dipu/vendor/droplet/DropletGeneratorImpl.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <cstring>

#include <ATen/Functions.h>
#include <ATen/Utils.h>

Expand All @@ -7,21 +9,44 @@

namespace dipu {

// Discriminate floating device type.
// static bool is_floating_device = true;

// just an example
// not implemented now
class DROPLETGeneratorImpl : public dipu::DIPUGeneratorImpl {
private:
static constexpr std::size_t seed_size = sizeof(uint64_t);
static constexpr std::size_t offset_size = sizeof(int64_t);
static constexpr std::size_t total_size = seed_size + offset_size;

public:
DROPLETGeneratorImpl(at::DeviceIndex device_index)
explicit DROPLETGeneratorImpl(at::DeviceIndex device_index)
: dipu::DIPUGeneratorImpl(device_index) {}

void set_state(const c10::TensorImpl& state) override {}

void update_state() const override {}
void set_state(const c10::TensorImpl& state) override {
at::detail::check_rng_state(state);
auto state_size = state.numel();
TORCH_CHECK(
state_size == total_size || state_size == total_size - offset_size,
"RNG state size is invalid");

state_ = at::Tensor(
state.shallow_copy_and_detach(state.version_counter(), true));
state_need_reset_ = false;
}

void update_state() const override {
if (state_need_reset_) {
state_ = at::detail::empty_cpu({static_cast<int64_t>(total_size)},
c10::ScalarType::Byte, c10::nullopt,
c10::nullopt, c10::nullopt, c10::nullopt);
auto rng_state = state_.data_ptr<uint8_t>();
uint64_t seed = this->current_seed();

std::memcpy(rng_state, &seed, seed_size);
std::memset(rng_state + seed_size, 0, offset_size);
state_need_reset_ = false;
}
}
};

// NOLINTNEXTLINE(readability-const-return-type)
const at::Generator vendorMakeGenerator(at::DeviceIndex device_index) {
return at::make_generator<DROPLETGeneratorImpl>(device_index);
}
Expand Down

0 comments on commit 0eaa761

Please sign in to comment.