Fix a bug that caused low-precision intermediates not to work.
liuliu committed Sep 14, 2024
1 parent 6737683 commit c67441f
Showing 7 changed files with 29 additions and 13 deletions.
4 changes: 2 additions & 2 deletions lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
@@ -88,7 +88,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   }
   AttentionDescriptor attentionDesc;
   attentionDesc.lowPrecisionInputs = (params.data_type == MTL::DataTypeHalf) ? true : false;
-  attentionDesc.lowPrecisionIntermediates = false;
+  attentionDesc.lowPrecisionIntermediates = (params.data_type == MTL::DataTypeHalf && !hash.upcast) ? true : false;
   attentionDesc.matrixDimensions[0] = hash.R;
   attentionDesc.matrixDimensions[1] = hash.C;
   attentionDesc.matrixDimensions[2] = hash.D;
@@ -146,7 +146,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   encoder->setBuffer(tensors[2], tensor_offsets[2], AttentionOperand(AttentionOperand::V).bufferIndex());
   if (attentionDesc.lowPrecisionInputs) {
     encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::O).bufferIndex());
-    encoder->setBuffer(scratch, hash.R * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::L).bufferIndex());
+    encoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::L).bufferIndex());
   } else {
     encoder->setBuffer(tensors[3], tensor_offsets[3], AttentionOperand(AttentionOperand::O).bufferIndex());
     encoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::L).bufferIndex());
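Note on the second hunk: MTL::ComputeCommandEncoder::setBuffer takes a byte offset, but the old code passed an element count, so the logsumexp operand L landed at the wrong place inside scratch whenever low-precision inputs routed O through the scratch buffer. A minimal sketch of the corrected offset math, reusing names from the hunk (the meanings of the hash fields are inferred from this diff):

    // FP32 elements of the O accumulator that precede L in scratch:
    // rows x head dimension x query heads x batch.
    const size_t oElements = (size_t)hash.R * hash.D * hash.Hq * attentionDesc.batchDimension;
    // setBuffer() offsets are in bytes, hence the sizeof(float) factor.
    encoder->setBuffer(scratch, sizeof(float) * oElements,
                       AttentionOperand(AttentionOperand::L).bufferIndex());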
12 changes: 6 additions & 6 deletions lib/nnc/mfa/v2/AttentionDescriptor.cpp
@@ -8,6 +8,7 @@ bool AttentionDescriptor::operator==(const AttentionDescriptor& rhs) const {
   return
     batchDimension == rhs.batchDimension &&
     Hq == rhs.Hq &&
+    Hk == rhs.Hk &&
     (lowPrecisionInputs == rhs.lowPrecisionInputs) &&
     (lowPrecisionIntermediates == rhs.lowPrecisionIntermediates) &&
     simd_all(leadingDimensions.value_or(simd::uint4(UINT32_MAX)) == rhs.leadingDimensions.value_or(simd::uint4(UINT32_MAX))) &&
@@ -21,6 +22,7 @@ std::size_t std::hash<AttentionDescriptor>::operator()(const AttentionDescriptor
   using namespace ccv::nnc::mfa::hash;
   combine_32(seed, hash.batchDimension);
   combine_32(seed, hash.Hq);
+  combine_32(seed, hash.Hk);
   combine_32(seed, hash.matrixDimensions[0]);
   combine_32(seed, hash.matrixDimensions[1]);
   combine_32(seed, hash.matrixDimensions[2]);
@@ -116,9 +118,9 @@ AttentionKernelDescriptor AttentionDescriptor::kernelDescriptor(MTL::Device *con
   };
 
   if (device->supportsFamily(MTL::GPUFamily(1009))) {
-    return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), true, false, createRegisterPrecisions(device), createTransposeState(), createLeadingDimensions(), type, scale);
+    return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), Hq, Hk, createMemoryPrecisions(), true, false, createRegisterPrecisions(device), createTransposeState(), createLeadingDimensions(), type, scale);
   } else {
-    return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), createMemoryPrecisions(), false, true, createRegisterPrecisions(device), createTransposeState(), createLeadingDimensions(), type, scale);
+    return AttentionKernelDescriptor(createBlockDimensions(), createCacheState(), createHeadDimension(), Hq, Hk, createMemoryPrecisions(), false, true, createRegisterPrecisions(device), createTransposeState(), createLeadingDimensions(), type, scale);
   }
 }

@@ -130,10 +132,8 @@ std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> Attention
     (MTL::FunctionConstantValues::alloc()->init());
   uint32_t rowDimension = matrixDimensions[0];
   uint32_t columnDimension = matrixDimensions[1];
-  uint32_t Hq = this->Hq;
   constants->setConstantValue(&rowDimension, MTL::DataTypeUInt, NS::Integer(0));
   constants->setConstantValue(&columnDimension, MTL::DataTypeUInt, 1);
-  constants->setConstantValue(&Hq, MTL::DataTypeUInt, 2);
   std::vector<AttentionOperand> operands;
   switch (type.value) {
   case AttentionKernelType::forward:
@@ -148,7 +148,7 @@ std::pair<AttentionKernelDescriptor, PipelineValue<AttentionKernel> *> Attention
   }
   for (const auto& operand : operands) {
     uint32_t batchStride = batchStrides[operand].value_or(0);
-    constants->setConstantValue(&batchStride, MTL::DataTypeUInt, 3 + operand.bufferIndex());
+    constants->setConstantValue(&batchStride, MTL::DataTypeUInt, 2 + operand.bufferIndex());
   }
 
   NS::String* swiftName = NS::String::string("attention", NS::UTF8StringEncoding);
@@ -399,7 +399,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
 // MARK: - AttentionDescriptor+Parameters
 
 std::vector<AttentionParameterRow> AttentionDescriptor::parameterFile(AttentionKernelType type, MTL::Device *const device) const noexcept {
-  if (lowPrecisionInputs && lowPrecisionIntermediates) {
+  if (lowPrecisionInputs || lowPrecisionIntermediates) {
     switch (type.value) {
     case AttentionKernelType::forward:
       return forwardMixed(device);
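Relaxing && to || is what lets the new lowPrecisionIntermediates setting take effect: the mixed-precision parameter tables are now selected whenever any stage of the pipeline is low precision, not only when both inputs and intermediates are. A sketch of the resulting dispatch for the forward case (the forward(device) fallback is an assumption; the hunk is truncated before the else branch):

    if (lowPrecisionInputs || lowPrecisionIntermediates) {
      // FP16 inputs and/or low-precision intermediates: mixed tables.
      return forwardMixed(device);
    } else {
      return forward(device);  // assumed full-precision path, not shown above
    }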
3 changes: 3 additions & 0 deletions lib/nnc/mfa/v2/AttentionDescriptor.hpp
@@ -29,6 +29,9 @@ struct AttentionDescriptor {
   /// The number of query heads per sequence that run in parallel.
   unsigned short Hq = 1;
 
+  /// The number of key / value heads per sequence that run in parallel.
+  unsigned short Hk = 1;
+
   /// Q, K, V, dO
   bool lowPrecisionInputs;

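Carrying Hk separately from Hq presumably enables grouped-query attention, where several query heads share one key/value head. A hypothetical illustration of the usual head mapping, not code from this commit (assumes Hq is a multiple of Hk):

    // Each group of Hq / Hk consecutive query heads reads the same KV head.
    unsigned short groupSize = Hq / Hk;            // query heads per KV head
    unsigned short kvHead = queryHead / groupSize; // queryHead in [0, Hq)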
7 changes: 4 additions & 3 deletions lib/nnc/mfa/v2/AttentionKernel.cpp
@@ -17,6 +17,8 @@ AttentionKernel::AttentionKernel(AttentionKernelDescriptor descriptor, MTL::Devi

   blockDimensions = descriptor.blockDimensions;
   headDimension = descriptor.headDimension;
+  Hq = descriptor.Hq;
+  Hk = descriptor.Hk;
   leadingDimensions = descriptor.leadingDimensions;
 
   scale = descriptor.scale;
@@ -458,10 +460,10 @@ std::string AttentionKernel::createConstants() const noexcept {
     operands = {AttentionOperand::Q, AttentionOperand::K, AttentionOperand::V, AttentionOperand::O, AttentionOperand::dO, AttentionOperand::dV, AttentionOperand::dK};
     break;
   }
-  std::string output = "";
+  std::string output = "#define Hq (" + std::to_string(Hq) + ")\n";
   for (const auto& operand : operands) {
     output += " constant uint " + operand.name() + "_batch_stride [[function_constant(";
-    output += std::to_string(operand.bufferIndex() + 3) + ")]];\n";
+    output += std::to_string(operand.bufferIndex() + 2) + ")]];\n";
   }
   return R"(
@@ -470,7 +472,6 @@ std::string AttentionKernel::createConstants() const noexcept {
 // Hq = number of query heads.
 constant uint R [[function_constant(0)]];
 constant uint C [[function_constant(1)]];
-constant uint Hq [[function_constant(2)]];
 )" + output;
 }
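Together with the AttentionDescriptor.cpp hunks, this moves Hq from function constant slot 2 into the generated Metal source itself: createConstants() now emits a #define for Hq, slot 2 is freed, and every per-operand batch stride shifts from 3 + bufferIndex() down to 2 + bufferIndex() on both the host and shader sides. A sketch of the constants preamble the kernel would now generate, under assumed values (Hq = 8, a single Q operand at buffer index 0):

    // Hq = number of query heads.
    constant uint R [[function_constant(0)]];
    constant uint C [[function_constant(1)]];
    #define Hq (8)
     constant uint Q_batch_stride [[function_constant(2)]];

Slot 2, which used to hold constant uint Hq [[function_constant(2)]], is now the first batch stride, matching the host-side setConstantValue(..., 2 + operand.bufferIndex()) call.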
4 changes: 4 additions & 0 deletions lib/nnc/mfa/v2/AttentionKernel.hpp
@@ -39,6 +39,10 @@ struct AttentionKernel {

   unsigned short headDimension;
 
+  unsigned short Hq;
+
+  unsigned short Hk;
+
   unsigned short threadgroupMemoryAllocation;
 
   /// The number of threads per group.
6 changes: 5 additions & 1 deletion lib/nnc/mfa/v2/AttentionKernelDescriptor.cpp
@@ -9,6 +9,7 @@ bool AttentionKernelDescriptor::operator==(const AttentionKernelDescriptor& rhs)
     simd_all(blockDimensions == rhs.blockDimensions) &&
     cacheState == rhs.cacheState &&
     headDimension == rhs.headDimension &&
+    Hq == rhs.Hq && Hk == rhs.Hk &&
     memoryPrecisions == rhs.memoryPrecisions &&
     (preferAsyncCache == rhs.preferAsyncCache) &&
     (preferAsyncLoad == rhs.preferAsyncLoad) &&
@@ -23,16 +24,19 @@ std::size_t std::hash<AttentionKernelDescriptor>::operator()(const AttentionKern
   using namespace ccv::nnc::mfa::hash;
   combine_64(seed, pack_64(simd_make_ushort4(hash.blockDimensions, 0)));
   combine_32(seed, pack_32(simd::ushort2 { hash.headDimension, hash.type.value }));
+  combine_32(seed, pack_32(simd::ushort2 { hash.Hq, hash.Hk }));
   combine_32(seed, pack_32(simd::uchar4 { hash.preferAsyncCache, hash.preferAsyncLoad, 0, 0 }));
   return 0;
 }
 
 // MARK: - Initializer
 
-AttentionKernelDescriptor::AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands<bool> cacheState, unsigned short headDimension, AttentionOperands<GEMMOperandPrecision> memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands<GEMMOperandPrecision> registerPrecisions, AttentionOperands<bool> transposeState, AttentionOperands<unsigned short> leadingDimensions, AttentionKernelType type, float scale) noexcept {
+AttentionKernelDescriptor::AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands<bool> cacheState, unsigned short headDimension, unsigned short Hq, unsigned short Hk, AttentionOperands<GEMMOperandPrecision> memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands<GEMMOperandPrecision> registerPrecisions, AttentionOperands<bool> transposeState, AttentionOperands<unsigned short> leadingDimensions, AttentionKernelType type, float scale) noexcept {
   this->blockDimensions = blockDimensions;
   this->cacheState = cacheState;
   this->headDimension = headDimension;
+  this->Hq = Hq;
+  this->Hk = Hk;
   this->memoryPrecisions = memoryPrecisions;
   this->preferAsyncCache = preferAsyncCache;
   this->preferAsyncLoad = preferAsyncLoad;
6 changes: 5 additions & 1 deletion lib/nnc/mfa/v2/AttentionKernelDescriptor.hpp
@@ -20,6 +20,10 @@ struct AttentionKernelDescriptor {
   /// Required. The problem size along the head dimension.
   unsigned short headDimension;
 
+  unsigned short Hq;
+
+  unsigned short Hk;
+
   AttentionOperands<GEMMOperandPrecision> memoryPrecisions;
 
   /// Reads with a one-to-one mapping to threads (like GEMM store) and writes.
@@ -58,7 +62,7 @@ struct AttentionKernelDescriptor {
   AttentionKernelDescriptor() = delete;
 
   /// Initialize the kernel descriptor.
-  AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands<bool> cacheState, unsigned short headDimension, AttentionOperands<GEMMOperandPrecision> memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands<GEMMOperandPrecision> registerPrecisions, AttentionOperands<bool> transposeState, AttentionOperands<unsigned short> leadingDimensions, AttentionKernelType type, float scale) noexcept;
+  AttentionKernelDescriptor(simd::ushort3 blockDimensions, AttentionOperands<bool> cacheState, unsigned short headDimension, unsigned short Hq, unsigned short Hk, AttentionOperands<GEMMOperandPrecision> memoryPrecisions, bool preferAsyncCache, bool preferAsyncLoad, AttentionOperands<GEMMOperandPrecision> registerPrecisions, AttentionOperands<bool> transposeState, AttentionOperands<unsigned short> leadingDimensions, AttentionKernelType type, float scale) noexcept;
 
   bool operator==(const AttentionKernelDescriptor& rhs) const;
 };
