Skip to content

Commit

Permalink
Merge pull request #82 from MPI-Dortmund/write_positions_into_umap
Browse files Browse the repository at this point in the history
Write umap embeddings metadata + positions into umap attrs and remove label mask calculation
  • Loading branch information
thorstenwagner authored Dec 13, 2023
2 parents ba1c511 + 3311ff2 commit ee801ff
Showing 1 changed file with 6 additions and 34 deletions.
40 changes: 6 additions & 34 deletions tomotwin/modules/tools/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
except ImportError:
print("cuml can't be loaded")

import mrcfile
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
Expand Down Expand Up @@ -98,28 +97,6 @@ def calcuate_umap(

return embedding, reducer

def create_embedding_mask(self, embeddings: pd.DataFrame):
    """
    Build a label volume in which every sliding-window subvolume has its own ID.

    The volume shape is taken from ``embeddings.attrs["tomogram_input_shape"]``
    (first three entries, used as Z, Y, X). Row ``i`` of ``embeddings`` gets the
    label ``i + 1``. That label is written at the row's (Z, Y, X) position and at
    every offset ``0 .. stride - 1`` along each axis, where ``stride`` is the first
    entry of ``embeddings.attrs["stride"]``. Voxels that are never written stay 0.

    :param embeddings: DataFrame with "Z", "Y" and "X" columns and the attrs
        "tomogram_input_shape" and "stride".
    :return: float32 array of shape (Z, Y, X) holding the labels.
    """
    print("Create embedding mask")
    attrs = embeddings.attrs
    tomo_shape = attrs["tomogram_input_shape"]
    depth, height, width = (tomo_shape[axis] for axis in range(3))
    # Only the first stride entry is used; it is applied to all three axes.
    step = attrs["stride"][0]

    # NOTE(review): float32 represents integers exactly only up to 2**24;
    # confirm the number of embeddings stays below that.
    label_volume = np.zeros((depth, height, width), dtype=np.float32)

    z_pos, y_pos, x_pos = (
        np.array(embeddings[axis_name], dtype=int) for axis_name in ("Z", "Y", "X")
    )

    # Labels start at 1 so that 0 remains the "no subvolume" background.
    labels = np.arange(1, len(x_pos) + 1)

    # Loop order (x outermost, z innermost) is kept: where writes overlap,
    # the last assignment wins.
    for offset_x in tqdm(list(range(step))):
        shifted_x = x_pos + offset_x
        for offset_y in range(step):
            shifted_y = y_pos + offset_y
            for offset_z in range(step):
                label_volume[z_pos + offset_z, shifted_y, shifted_x] = labels

    return label_volume

def run(self, args):
print("Read data")
Expand All @@ -144,23 +121,18 @@ def run(self, args):
os.makedirs(out_pth,exist_ok=True)
fname = os.path.splitext(os.path.basename(args.input))[0]
df_embeddings = pd.DataFrame(umap_embeddings)
df_embeddings.reset_index(drop=True, inplace=True)
embeddings.reset_index(drop=True, inplace=True)

print("Write embeedings to disk")
df_embeddings.columns = [f"umap_{i}" for i in range(umap_embeddings.shape[1])]
df_embeddings = pd.concat([embeddings[['X', 'Y', 'Z']], df_embeddings], axis=1)
df_embeddings.attrs['embeddings_attrs'] = embeddings.attrs
df_embeddings.attrs['embeddings_path'] = os.path.realpath(args.input)

df_embeddings.to_pickle(os.path.join(out_pth,fname+".tumap"))

print("Write umap model to disk")
pickle.dump(fitted_umap, open(os.path.join(out_pth, fname + "_umap_model.pkl"), "wb"))

print("Calculate label mask and write it to disk")
embedding_mask = self.create_embedding_mask(embeddings)
with mrcfile.new(
os.path.join(
args.output,
fname + "_label_mask.mrci",
),
overwrite=True,
) as mrc:
mrc.set_data(embedding_mask)

print("Done")

0 comments on commit ee801ff

Please sign in to comment.