diff --git a/dask_distance/__init__.py b/dask_distance/__init__.py index f055f5c..fd7c4bf 100644 --- a/dask_distance/__init__.py +++ b/dask_distance/__init__.py @@ -216,27 +216,23 @@ def squareform(X, force="no"): X_tri = [] j1 = 0 - for j2 in _pycompat.irange(d - 1, -1, -1): + for j2 in _pycompat.irange(d - 1, 0, -1): X_tri.append(X[j1:j1 + j2]) j1 += j2 - z = dask.array.zeros((1,), dtype=X.dtype, chunks=(1,)) - - result = [] - for i in range(d): - col_i = [] - - for j in range(i): - i_j = i - j - col_i.append(X_tri[j][i_j - 1:i_j]) - col_i.append(z) - col_i.append(X_tri[i]) - - result.append(dask.array.concatenate([ - a for a in col_i if a.size - ])) - - result = dask.array.stack(result) + z = dask.array.zeros((1, 1), dtype=X.dtype, chunks=(1, 1)) + + result = z + for i in _pycompat.irange(d - 2, -1, -1): + X_tri_i = X_tri[i] + result = result.rechunk(2 * X_tri_i.chunks) + result = dask.array.concatenate( + [ + dask.array.concatenate([z, X_tri_i[None]], axis=1), + dask.array.concatenate([X_tri_i[:, None], result], axis=1) + ], + axis=0 + ) elif conv == "tovec": result = [ X[i, i + 1:] for i in range(0, len(X) - 1)