Skip to content

Commit

Permalink
Merge pull request #43 from FL33TW00D/feature/inplace
Browse files Browse the repository at this point in the history
feature: inplace support
  • Loading branch information
FL33TW00D authored Jan 26, 2024
2 parents 116c91a + 55c3ba0 commit cc32e02
Show file tree
Hide file tree
Showing 17 changed files with 550 additions and 118 deletions.
1 change: 0 additions & 1 deletion .github/workflows/ratbot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ jobs:
${{ steps.scc.outputs.scc }}
```
\`\`\`
</details>
`;
github.rest.issues.createComment({
issue_number: context.issue.number,
Expand Down
4 changes: 1 addition & 3 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ jobs:
include:
- name: Windows x86_64
os: windows-2022
- name: Linux x86_64
os: ubuntu-latest
#TODO: android?
steps:
- uses: actions/checkout@v4
Expand Down Expand Up @@ -84,7 +82,7 @@ jobs:
shell: bash
run: |
set -e
cargo nextest run --features pyo3 --no-fail-fast
cargo nextest run -j 1 --features pyo3 --no-fail-fast
# - name: Install wasm-pack
# run: |
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ strip = true
#debug = 2

[workspace.dependencies]
wgpu = { version = "0.18.0", features = ["fragile-send-sync-non-atomic-wasm", "expose-ids"] }
wgpu = { version = "0.19.1", features = ["fragile-send-sync-non-atomic-wasm"] }
anyhow = "1.0.40"
bytemuck = { version = "1.14.0", features=["wasm_simd", "aarch64_simd", "extern_crate_alloc"] }
num-traits = "0.2.17"
Expand Down
1 change: 1 addition & 0 deletions crates/ratchet-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ rand = "0.8.4"
pyo3 = { version = "0.20.2", features=["auto-initialize"] }
numpy = { version = "0.20.0" }
ndarray = { version = "0.15.6" }
regex = "1.10.3"

86 changes: 86 additions & 0 deletions crates/ratchet-core/kernels/softmax_scalar.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf
@group(0) @binding(0)
var<storage, read_write> X: array<f32>;

struct Meta {
M: u32,
N: u32,
ND4: u32,
}

@group(1) @binding(0)
var<uniform> metadata: Meta;

var<workgroup> smem: array<f32, 128>; // max size is 16kb
var<workgroup> maximum: f32;
var<workgroup> sum: f32;

const BLOCK_SIZE = 128u;
const minFloat: f32 = -3.402823e+38f;

fn block_sum(index: u32, stride: u32) {
if index < stride {
smem[index] += smem[index + stride];
}
workgroupBarrier();
}

fn block_max(index: u32, stride: u32) {
if index < stride {
smem[index] = max(smem[index], smem[index + stride]);
}
workgroupBarrier();
}

@compute @workgroup_size(128, 1, 1)
fn main(
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let batch_stride = group_id.y * metadata.M * metadata.N;
let row_start = batch_stride + group_id.x * metadata.N;
let index = local_id.x;

smem[index] = minFloat;
for (var i: u32 = index; i < metadata.N; i += BLOCK_SIZE) {
smem[index] = max(smem[index], X[row_start + i]);
}
workgroupBarrier();

block_max(index, 64u);
block_max(index, 32u);
block_max(index, 16u);
block_max(index, 8u);
block_max(index, 4u);
block_max(index, 2u);
block_max(index, 1u);

if index == 0u{
maximum = smem[0];
}
workgroupBarrier();

smem[index] = 0.0;
for (var i: u32 = index; i < metadata.N; i += BLOCK_SIZE) {
smem[index] += exp(X[row_start + i] - maximum);
}

workgroupBarrier();
block_sum(index, 64u);
block_sum(index, 32u);
block_sum(index, 16u);
block_sum(index, 8u);
block_sum(index, 4u);
block_sum(index, 2u);
block_sum(index, 1u);

if index == 0u {
sum = smem[0];
}
workgroupBarrier();

for(var i: u32 = index; i < metadata.N; i += BLOCK_SIZE) {
var val = X[row_start + i];
X[row_start + i] = exp(val - maximum) / sum;
}
}
86 changes: 86 additions & 0 deletions crates/ratchet-core/kernels/softmax_vec4.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//https://developer.download.nvidia.com/assets/cuda/files/reduction.pdf
@group(0) @binding(0)
var<storage, read_write> X: array<vec4<f32>>;

struct Meta {
M: u32,
N: u32,
ND4: u32,
}

@group(1) @binding(0)
var<uniform> metadata: Meta;

var<workgroup> smem: array<vec4<f32>, 128>; // max size is 16kb
var<workgroup> maximum: f32;
var<workgroup> sum: f32;

const BLOCK_SIZE = 128u;
const minFloat: f32 = -3.402823e+38f;

fn block_sum(index: u32, stride: u32) {
if index < stride {
smem[index] += smem[index + stride];
}
workgroupBarrier();
}

fn block_max(index: u32, stride: u32) {
if index < stride {
smem[index] = max(smem[index], smem[index + stride]);
}
workgroupBarrier();
}

@compute @workgroup_size(128, 1, 1)
fn main(
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
) {
let batch_stride = group_id.y * metadata.M * metadata.ND4;
let row_start = batch_stride + group_id.x * metadata.ND4;
let index = local_id.x;

smem[index] = vec4<f32>(minFloat);
for (var i: u32 = index; i < metadata.ND4; i += BLOCK_SIZE) {
smem[index] = max(smem[index], X[row_start + i]);
}
workgroupBarrier();

block_max(index, 64u);
block_max(index, 32u);
block_max(index, 16u);
block_max(index, 8u);
block_max(index, 4u);
block_max(index, 2u);
block_max(index, 1u);

if index == 0u{
maximum = max(smem[0].x, max(smem[0].y, max(smem[0].z, smem[0].w)));
}
workgroupBarrier();

smem[index] = vec4<f32>(0.0);
for (var i: u32 = index; i < metadata.ND4; i += BLOCK_SIZE) {
smem[index] += exp(X[row_start + i] - maximum);
}

workgroupBarrier();
block_sum(index, 64u);
block_sum(index, 32u);
block_sum(index, 16u);
block_sum(index, 8u);
block_sum(index, 4u);
block_sum(index, 2u);
block_sum(index, 1u);

if index == 0u {
sum = dot(smem[0], vec4<f32>(1.0));
}
workgroupBarrier();

for(var i: u32 = index; i < metadata.ND4; i += BLOCK_SIZE) {
var val = X[row_start + i];
X[row_start + i] = exp(val - maximum) / sum;
}
}
8 changes: 7 additions & 1 deletion crates/ratchet-core/src/compiled_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,18 @@ impl CompiledOp {
dst: &Tensor,
bind_group_layouts: RVec<BindGroupLayoutHandle>,
device: &WgpuDevice,
inplace: bool,
) -> RVec<GpuBindGroup> {
let mut bind_group_entries = drvec![];
for tensor in srcs.iter().chain(std::iter::once(&dst)) {

for tensor in srcs.iter() {
bind_group_entries.append(&mut tensor.bindings());
}

if !inplace {
bind_group_entries.append(&mut dst.bindings());
}

let mut storage_groups = rvec![];
for (group_index, bind_group_layout) in bind_group_layouts.iter().enumerate() {
let group_range = Self::group_range(group_index, bind_group_entries.len());
Expand Down
82 changes: 66 additions & 16 deletions crates/ratchet-core/src/gpu/buffer_allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,40 @@ impl BufferAllocator {
}
}

/// # Inplace operations
///
/// If an operation supports inplace, we need to "lease" the buffer
/// from the actual source (i.e the first non-inplace operation)
///
/// On what conditions do we terminate the upward traversal?
/// 1. We reach a constant
/// 2. We reach an operation that does not support inplace
/// 3. We reach an operation that has more than one consumer
/// 4. We reach an operation that has more than one source
fn traverse_upwards_for_inplace(source: &Tensor) -> &Tensor {
let mut true_source = source;
loop {
let is_const = true_source.op().is_const();
let cant_inplace = !true_source.op().supports_inplace();
let multiple_sources = true_source.op().srcs().len() > 1;
let multiple_consumers = false; //TODO: implement
if cant_inplace || multiple_sources || multiple_consumers || is_const {
break;
}

true_source = true_source.op().srcs()[0];
}
true_source
}

/// # Graph memory allocation
///
/// Simple greedy algorithm for allocating all required buffers to store
/// activations during an inference pass.
/// Simple greedy algorithm
/// 1. Iterate over all tensors in reverse order
/// 2. For each tensor, loop through it's sources
/// 3. Each source is inserted into the assignments.
/// 4. Release my output buffer (because we traverse in reverse order, when I arrive at myself,
/// my output buffer is no longer needed)
pub fn allocate_cfg(
&self,
execution_order: &[Tensor],
Expand All @@ -125,7 +155,7 @@ impl BufferAllocator {
let mut free = Vec::new(); //TODO: switch to BTreeMap
let mut assignments = FxHashMap::default();

for t in execution_order {
for t in execution_order.iter().rev() {
if t.resolved() {
assignments.insert(
t.id(),
Expand All @@ -134,12 +164,16 @@ impl BufferAllocator {
continue;
}

// I need all of my sources to be allocated in order to compute my output value.
// We "lease" the buffer, and it is released when we reach it in the execution order.
// If the current tensor is an inplace operation,
// we traverse upwards until we find a non-inplace operation.
for source in t.op().srcs() {
//Add support for inplace here when added
assignments.entry(source.id()).or_insert_with(|| {
let true_source = Self::traverse_upwards_for_inplace(source);
assignments.entry(true_source.id()).or_insert_with(|| {
self.graph_allocate(
BufferDescriptor::new(
source.num_bytes() as _,
true_source.num_bytes() as _,
BufferUsages::standard(),
false,
),
Expand All @@ -149,21 +183,37 @@ impl BufferAllocator {
});
}

//release my buffer
//My buffer is no longer needed, since we traverse in reverse order
//Earlier tensors can use my buffer
if let Some(buf) = assignments.get(&t.id()) {
free.push(buf.clone());
}
}
//Allocate for CFG output

//The output never gets allocated in the above loop, because it is not a source.
//We know we need an allocation for the output.
//We traverse upwards until we find the first non-inplace operation, and use it's buffer.
let output = execution_order.last().unwrap();
assignments.insert(
output.id(),
device.get_or_create_buffer(&BufferDescriptor {
size: output.num_bytes() as _,
usage: BufferUsages::standard(),
mapped_at_creation: false,
})?,
);
let output_source = Self::traverse_upwards_for_inplace(output);

//If output source is allocated, we can use it's buffer
//Otherwise, we need to allocate a new buffer
let output_buffer = assignments
.get(&output_source.id())
.cloned()
.unwrap_or_else(|| {
self.graph_allocate(
BufferDescriptor::new(
output_source.num_bytes() as _,
BufferUsages::standard(),
false,
),
&mut free,
device,
)
});
assignments.insert(output.id(), output_buffer);

Ok(assignments)
}
}
9 changes: 5 additions & 4 deletions crates/ratchet-core/src/gpu/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,16 @@ impl WgpuDevice {
let adapter = Self::select_adapter()?;

#[allow(unused_mut)]
let mut features = wgpu::Features::default();
let mut required_features = wgpu::Features::default();
#[cfg(feature = "gpu-profiling")]
{
features |= wgpu::Features::TIMESTAMP_QUERY;
}

let mut device_descriptor = wgpu::DeviceDescriptor {
label: Some("ratchet"),
features,
limits: Limits {
required_features,
required_limits: Limits {
max_buffer_size: MAX_BUFFER_SIZE,
max_storage_buffer_binding_size: MAX_BUFFER_SIZE as u32,
..Default::default()
Expand All @@ -77,7 +77,7 @@ impl WgpuDevice {
"Failed to acq. device, trying again with reduced limits: {:?}",
e
);
device_descriptor.limits = adapter.limits();
device_descriptor.required_limits = adapter.limits();
adapter.request_device(&device_descriptor, None).await
} else {
device_request
Expand Down Expand Up @@ -131,6 +131,7 @@ impl WgpuDevice {
let backends = wgpu::util::backend_bits_from_env().unwrap_or(wgpu::Backends::PRIMARY);
let adapter = instance
.enumerate_adapters(backends)
.into_iter()
.max_by_key(|adapter| match adapter.get_info().device_type {
DeviceType::DiscreteGpu => 5,
DeviceType::Other => 4,
Expand Down
Loading

0 comments on commit cc32e02

Please sign in to comment.