diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
index f5f1eb2fe9a384..85b82e3acd6ef8 100644
--- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
+++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
@@ -1,5 +1,4 @@
 #include
-#include
 using namespace metal;
 
 template
@@ -32,271 +31,6 @@ kernel void naive_matmul(
   outputData[x * strides[2].x + y * strides[2].y] = rc;
 }
 
-inline float blockReduceSum(
-    threadgroup float* sharedScratch,
-    float val,
-    uint tid,
-    uint tpg) {
-  sharedScratch[tid] = val;
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  for (uint offset = tpg >> 1; offset > 0; offset >>= 1) {
-    if (tid < offset) {
-      sharedScratch[tid] += sharedScratch[tid + offset];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-  }
-
-  return sharedScratch[0];
-}
-
-kernel void factorDiagonalBlock(
-    device float* A [[buffer(0)]],
-    device int* success [[buffer(1)]],
-    constant uint& N [[buffer(2)]],
-    constant uint& NB [[buffer(3)]],
-    constant uint& k [[buffer(4)]],
-    uint tid [[thread_position_in_threadgroup]],
-    uint bid [[threadgroup_position_in_grid]],
-    uint tpg [[threads_per_threadgroup]]) {
-  const uint actSize = min(N - k * NB, NB); // uint64 before NB
-  const uint batch_offset = bid * N * N;
-
-  const uint row0 = k * NB;
-  const uint col0 = k * NB;
-
-  threadgroup float tile[32][33];
-  threadgroup float reduceScratch[256];
-  const uint tileSize = actSize * actSize;
-
-  for (uint i = tid; i < tileSize; i += tpg) {
-    uint r = i / actSize;
-    uint c = i % actSize;
-    tile[r][c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  for (uint kk = 0; kk < actSize; kk++) {
-    float diagElt = 0.0f;
-    if (kk > 0) {
-      float partialSum = 0.0f;
-      for (uint i = tid; i < kk; i += tpg) {
-        float val = tile[kk][i];
-        partialSum = fma(val, val, partialSum);
-      }
-      diagElt = blockReduceSum(reduceScratch, partialSum, tid, tpg);
-    }
-
-    if (tid == 0) {
-      float diagVal = tile[kk][kk] - diagElt;
-      // Check for positive definiteness
-      if (diagVal <= 0.0f) {
-        success[bid] = 0; // matrix is not positive definite
-        return;
-      }
-      tile[kk][kk] = sqrt(diagVal);
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    float pivot = tile[kk][kk];
-
-    for (uint j = kk + 1 + tid; j < actSize; j += tpg) {
-      float partialSum = 0.0f;
-      for (uint i = 0; i < kk; i++) {
-        partialSum = fma(tile[j][i], tile[kk][i], partialSum);
-      }
-
-      float val = tile[j][kk];
-      val -= partialSum;
-      val /= pivot;
-      tile[j][kk] = val;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-  }
-
-  for (uint i = tid; i < tileSize; i += tpg) {
-    uint r = i / actSize;
-    uint c = i % actSize;
-    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r][c];
-  }
-}
-
-kernel void applyTRSM(
-    device float* A [[buffer(0)]],
-    constant uint& N [[buffer(2)]],
-    constant uint& NB [[buffer(3)]],
-    constant uint& k [[buffer(4)]],
-    uint3 tid [[thread_position_in_threadgroup]],
-    uint3 tgid [[threadgroup_position_in_grid]],
-    uint3 tpg [[threads_per_threadgroup]]) {
-  uint b = tgid.x;
-  uint idxJ = tgid.y;
-
-  const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
-  const uint batch_offset = b * N * N;
-  const uint j = (k + 1) + idxJ;
-
-  uint row0 = j * NB;
-  uint col0 = k * NB;
-
-  uint actSize_j = (uint)min((int)(N - row0), (int)NB);
-  if (actSize_k == 0 || actSize_j == 0) {
-    return;
-  }
-  if (j == k) {
-    return;
-  }
-
-  threadgroup float diag[32 * 32];
-  threadgroup float target[32 * 32];
-
-  for (uint i = tid.x; i < actSize_k * actSize_k; i += tpg.x) {
-    uint r = i / actSize_k;
-    uint c = i % actSize_k;
-    diag[i] = A[batch_offset + (k * NB + r) * N + (k * NB + c)];
-  }
-  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
-    uint r = i / actSize_k;
-    uint c = i % actSize_k;
-    target[i] = A[batch_offset + (row0 + r) * N + (col0 + c)];
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  for (uint col = 0; col < actSize_k; col++) {
-    float diag_val = diag[col * actSize_k + col];
-    if (abs(diag_val) < 1e-6f) {
-      diag_val = (diag_val < 0.0f) ? -1e-6f : 1e-6f;
-    }
-
-    for (uint row = tid.x; row < actSize_j; row += tpg.x) {
-      float sum = target[row * actSize_k + col];
-
-      // kahan sum
-      float c = 0.0f;
-      for (uint p = 0; p < col; p++) {
-        float y = -target[row * actSize_k + p] * diag[col * actSize_k + p] - c;
-        float t = sum + y;
-        c = (t - sum) - y;
-        sum = t;
-      }
-
-      target[row * actSize_k + col] = sum / diag_val;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-  }
-
-  for (uint i = tid.x; i < actSize_j * actSize_k; i += tpg.x) {
-    uint r = i / actSize_k;
-    uint c = i % actSize_k;
-    A[batch_offset + (row0 + r) * N + (col0 + c)] = target[i];
-  }
-}
-
-kernel void applySYRK(
-    device float* A [[buffer(0)]],
-    constant uint& N [[buffer(2)]],
-    constant uint& NB [[buffer(3)]],
-    constant uint& k [[buffer(4)]],
-    uint3 tid [[thread_position_in_threadgroup]],
-    uint3 tgid [[threadgroup_position_in_grid]],
-    uint3 tpg [[threads_per_threadgroup]]) {
-  uint b = tgid.x;
-  uint pairID = tgid.y;
-
-  uint jRel = (-1 + sqrt(1 + 8 * float(pairID))) / 2;
-  uint hRel = pairID - (jRel * (jRel + 1) >> 1);
-
-  const uint startJ = (k + 1);
-  uint j = startJ + jRel;
-  uint h = startJ + hRel;
-  uint row0 = j * NB;
-  uint col0 = h * NB;
-
-  const uint actSize_k = uint(min(int64_t(N - k * NB), int64_t(NB)));
-  const uint actSize_j = min((uint)(N - row0), NB);
-  const uint actSize_h = min((uint)(N - col0), NB);
-  const uint batch_offset = b * N * N;
-
-  if (actSize_j == 0 || actSize_h == 0 || actSize_k == 0)
-    return;
-
-  threadgroup float left[32 * 33];
-  threadgroup float right_t[32 * 33];
-  threadgroup float tile[32 * 33];
-
-  const uint threads = min(tpg.x, actSize_j * actSize_k);
-
-  for (uint i = tid.x; i < actSize_j * actSize_k; i += threads) {
-    uint r = i / actSize_k;
-    uint c = i % actSize_k;
-    left[r * actSize_k + c] = A[batch_offset + (j * NB + r) * N + (k * NB + c)];
-  }
-
-  for (uint i = tid.x; i < actSize_h * actSize_k; i += threads) {
-    uint r = i / actSize_k;
-    uint c = i % actSize_k;
-    right_t[c * actSize_h + r] =
-        A[batch_offset + (h * NB + r) * N + (k * NB + c)];
-  }
-
-  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
-    uint r = i / actSize_h;
-    uint c = i % actSize_h;
-    tile[r * actSize_h + c] = A[batch_offset + (row0 + r) * N + (col0 + c)];
-  }
-
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  for (uint idx = tid.x; idx < actSize_j * actSize_h; idx += threads) {
-    uint r = idx / actSize_h;
-    uint c = idx % actSize_h;
-
-    if ((j == h) && (r < c))
-      continue;
-
-    uint tile_idx = r * actSize_h + c;
-    float sum = tile[tile_idx];
-
-    uint left_row = r * actSize_k;
-    uint right_col = c;
-
-    uint k = 0;
-    float4 sum4 = {0.0f, 0.0f, 0.0f, 0.0f};
-
-    for (; k + 4 <= actSize_k; k += 4) {
-      float4 left4 = {
-          left[left_row + k],
-          left[left_row + k + 1],
-          left[left_row + k + 2],
-          left[left_row + k + 3]};
-
-      float4 right4 = {
-          right_t[(k + 0) * actSize_h + right_col],
-          right_t[(k + 1) * actSize_h + right_col],
-          right_t[(k + 2) * actSize_h + right_col],
-          right_t[(k + 3) * actSize_h + right_col]};
-
-      sum4 = fma(left4, right4, sum4);
-    }
-
-    sum -= dot(sum4, 1.0);
-
-    for (; k < actSize_k; k++) {
-      sum = fma(-left[left_row + k], right_t[k * actSize_h + right_col], sum);
-    }
-
-    tile[tile_idx] = sum;
-  }
-
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  for (uint i = tid.x; i < actSize_j * actSize_h; i += threads) {
-    uint r = i / actSize_h;
-    uint c = i % actSize_h;
-    A[batch_offset + (row0 + r) * N + (col0 + c)] = tile[r * actSize_h + c];
-  }
-}
-
 #define INSTANTIATE_NAIVE_MM(DTYPE)                          \
   template [[host_name("naive_matmul_" #DTYPE)]] kernel void \
   naive_matmul(                                              \
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 32e983238dccde..fe77c1936a21af 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -18,8 +18,6 @@
 #include
 #include
 #include
-#include
-#include
 #include
 #include
 #include
@@ -782,83 +780,6 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L
   return out;
 }
 
-static Tensor& linalg_cholesky_mps_impl(const Tensor& input, bool upper, Tensor& out) {
-  using namespace mps;
-
-  TORCH_CHECK(out.is_mps());
-  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float, "linalg.cholesky: Input tensor must be float32");
-  TORCH_CHECK(input.dim() >= 2, "linalg.cholesky: Input tensor must be at least 2D");
-  TORCH_CHECK(input.size(-2) == input.size(-1), "linalg.cholesky: Input tensor must be square");
-
-  if (input.numel() == 0 || out.numel() == 0) {
-    out.zero_();
-    return out;
-  }
-  resize_output(out, input.sizes());
-  out.copy_(input);
-
-  int64_t ndim = out.dim();
-  int64_t N = out.size(-1);
-  int64_t B = 1;
-  for (int64_t i = 0; i < ndim - 2; i++) {
-    B *= out.size(i);
-  }
-
-  auto stream = getCurrentMPSStream();
-  auto device = MPSDevice::getInstance()->device();
-
-  auto factorDiagonalPSO = lib.getPipelineStateForFunc("factorDiagonalBlock");
-  auto applyTRSMPSO = lib.getPipelineStateForFunc("applyTRSM");
-  auto applySYRKPSO = lib.getPipelineStateForFunc("applySYRK");
-
-  int64_t NB = std::min(32, N);
-  int64_t numBlocks = (N + NB - 1) / NB;
-
-  Tensor success = at::empty({B}, input.options().dtype(kInt)).fill_(1);
-  id successBuffer = getMTLBufferStorage(success);
-
-  MTLSize threadGroupSize = MTLSizeMake(256, 1, 1);
-  id outBuffer = getMTLBufferStorage(out);
-  id computeEncoder = stream->commandEncoder();
-  [computeEncoder setBuffer:outBuffer offset:0 atIndex:0];
-  [computeEncoder setBytes:&N length:sizeof(int64_t) atIndex:2];
-  [computeEncoder setBytes:&NB length:sizeof(int64_t) atIndex:3];
-
-  @autoreleasepool {
-    dispatch_sync_with_rethrow(stream->queue(), ^() {
-      for (int64_t k = 0; k < numBlocks; k++) {
-        [computeEncoder setComputePipelineState:factorDiagonalPSO];
-        [computeEncoder setBuffer:successBuffer offset:0 atIndex:1];
-        [computeEncoder setBytes:&k length:sizeof(int64_t) atIndex:4];
-        MTLSize gridSize = MTLSizeMake(B, 1, 1);
-        [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
-
-        // process all remaining blocks in this row/column in parallel
-        if (k < numBlocks - 1) {
-          int64_t startJ = k + 1;
-          int64_t nBlocksJ = (numBlocks - startJ);
-
-          if (nBlocksJ > 0) {
-            // TRSM for all blocks in parallel
-            MTLSize trsmGridSize = MTLSizeMake(B, nBlocksJ, 1);
-            [computeEncoder setComputePipelineState:applyTRSMPSO];
-            [computeEncoder dispatchThreadgroups:trsmGridSize threadsPerThreadgroup:threadGroupSize];
-
-            // SYRK for all independent block pairs in parallel
-            uint32_t nPairs = nBlocksJ * (nBlocksJ + 1) / 2;
-            MTLSize syrkGridSize = MTLSizeMake(B, nPairs, 1);
-            [computeEncoder setComputePipelineState:applySYRKPSO];
-            [computeEncoder dispatchThreadgroups:syrkGridSize threadsPerThreadgroup:threadGroupSize];
-          }
-        }
-      }
-    });
-  }
-
-  TORCH_CHECK(success.all().item(), "linalg.cholesky: Input matrix is not positive definite");
-  out.tril_(); //
-  return upper ? out.transpose_(ndim - 2, ndim - 1) : out;
-}
 } // namespace mps
 
 Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
@@ -1019,25 +940,6 @@ Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, cons
   return result;
 }
 
-Tensor cholesky_mps(const Tensor& self, bool upper) {
-  auto out = at::empty_like(self, MemoryFormat::Contiguous);
-  mps::linalg_cholesky_mps_impl(self, upper, out);
-  return out;
-}
-
-Tensor& cholesky_mps_out(const Tensor& self, bool upper, Tensor& out) {
-  return mps::linalg_cholesky_mps_impl(self, upper, out);
-}
-
-Tensor& linalg_cholesky_out_mps(const Tensor& self, bool upper, Tensor& out) {
-  return mps::linalg_cholesky_mps_impl(self, upper, out);
-}
-
-Tensor linalg_cholesky_mps(const Tensor& self, bool upper) {
-  auto out = at::empty_like(self, MemoryFormat::Contiguous);
-  return mps::linalg_cholesky_mps_impl(self, upper, out);
-}
-
 Tensor addbmm_mps(const Tensor& self,
                   const Tensor& batch1,
                   const Tensor& batch2,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b06bc10dceadfe..0280119f1cc08a 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9470,13 +9470,11 @@
 - func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU, CUDA: cholesky_out
-    MPS: cholesky_mps_out
 
 - func: cholesky(Tensor self, bool upper=False) -> Tensor
   variants: method, function
   dispatch:
     CPU, CUDA: cholesky
-    MPS: cholesky_mps
 
 - func: cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
@@ -13938,15 +13936,9 @@
 
 - func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
   python_module: linalg
-  dispatch:
-    CompositeImplicitAutograd: linalg_cholesky
-    MPS: linalg_cholesky_mps
 
 - func: linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
-  dispatch:
-    CompositeImplicitAutograd: linalg_cholesky_out
-    MPS: linalg_cholesky_out_mps
 
 - func: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor
   python_module: linalg
diff --git a/test/test_mps.py b/test/test_mps.py
index f5dab3d1f84b8e..5b53169edcb48e 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -677,6 +677,7 @@ def mps_ops_modifier(ops):
         '__rsub__': None,
         'cauchy_': None,
         'cauchy': None,
+        'cholesky': None,
         'cholesky_inverse': None,
         'cholesky_solve': None,
         'cummax': None,
@@ -696,6 +697,7 @@ def mps_ops_modifier(ops):
         'index_reduceamin': None,
         'kthvalue': None,
         'lcm': None,
+        'linalg.cholesky': None,
         'linalg.cholesky_ex': None,
         'linalg.cond': None,
         'linalg.detsingular': None,
@@ -6391,30 +6393,6 @@ def test_sort(self):
                          atol=0, rtol=0
                          )
 
-    def test_cholesky(self):
-        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
-
-        def run_cholesky_test(size, *batch_dims, upper):
-            input_cpu = random_hermitian_pd_matrix(size, *batch_dims, dtype=torch.float32, device="cpu")
-            input_mps = input_cpu.to('mps')
-            output_cpu = torch.linalg.cholesky(input_cpu, upper=upper)
-            output_mps = torch.linalg.cholesky(input_mps, upper=upper)
-            self.assertEqual(output_cpu, output_mps, atol=2e-5, rtol=1e-6)
-
-        # test with different even/odd matrix sizes
-        matrix_sizes = [1, 2, 3, 4, 8, 17, 64, 128, 154]
-        # even/odd batch sizes
-        batch_sizes = [1, 2, 4, 8, 16, 17]
-
-        for upper in [True, False]:
-            for size in matrix_sizes:
-                for batch_size in batch_sizes:
-                    run_cholesky_test(size, batch_size, upper=upper)
-
-        # test >3D matrices
-        run_cholesky_test(128, 10, 10, upper=False)
-        run_cholesky_test(128, 2, 2, 2, 2, 10, 10, upper=True)
-
     def test_upsample_nearest2d(self):
         def helper(N, C, H, W, memory_format):
             inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 595f2757d5cfe2..fa77b906b1b457 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -410,11 +410,6 @@
   self: cholesky_backward(grad, upper, L)
   L: cholesky_jvp(self_t, L, upper)
 
-# temporarily here before linalg_cholesky dispatches to linalg_cholesky_ex on MPS device
-- name: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
-  self: cholesky_backward(grad, upper, result)
-  result: cholesky_jvp(self_t, result, upper)
-
 - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
   self, input2: cholesky_solve_backward(grad, self, input2, result, upper, grad_input_mask)
   result: cholesky_solve_jvp(result, input2_p, input2_t, self_t, upper)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index c4bc4848079853..8a9a00528a6c77 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2646,7 +2646,6 @@ def is_aligned(x):
 make_fallback(aten.linalg_pinv.atol_rtol_tensor)
 make_fallback(aten._linalg_eigh)
 make_fallback(aten.triangular_solve)
-make_fallback(aten.linalg_cholesky)
 make_fallback(aten.linalg_cholesky_ex)
 make_fallback(aten.cholesky_inverse)
 make_fallback(aten.cholesky_solve)