Skip to content

Commit

Permalink
Merge pull request #5 from TravisWheelerLab/memory-options
Browse files Browse the repository at this point in the history
Draft: Hybrid storage option
  • Loading branch information
isaacrobinson2000 authored Jan 23, 2024
2 parents ab19010 + c6440cb commit dd406d1
Show file tree
Hide file tree
Showing 17 changed files with 384 additions and 261 deletions.
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,11 @@ docs/source/api
# Virtual environment for docs if being used...
docs/.venv
docs/venv/
docs/ENV/
docs/ENV/

# Ignore pycache files...
**/__pycache__/

.DS_Store

*.zip
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ DIPLOMAT has documentation on ReadTheDocs at [https://diplomat.readthedocs.io/en
## Development

DIPLOMAT is written entirely in python. To set up an environment for developing DIPLOMAT, you can simply pull down this repository and install its
requirements using poetry. For a further description of how to set up DIPLOMAT for development, see the
requirements using pip. For a further description of how to set up DIPLOMAT for development, see the
[Development Usage](https://diplomat.readthedocs.io/en/latest/advanced_usage.html#development-usage) section in the documentation.

## Contributing
Expand Down
9 changes: 5 additions & 4 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
Changes for this version of DIPLOMAT:
- Fixed installation process for DIPLOMAT with SLEAP on windows by adding keras dependency.
- Added support for changing some UI appearance settings to DIPLOMAT's supervised and tweak UI.
- Make SLEAP frontend error out unless a user explicitly passes a number of outputs parameter.
- Improved point rendering in the UI (proper alpha transparency support and smoother point rendering).
- Added new memory mode hybrid and set it to the default
- Added a save to disk button to the UI which saves frames to disk when in memory mode
- New configuration file with fixes for the CPU conda environment
- Rename CLI commands from supervised / unsupervised / track / restore to track_and_interact / track / track_with / interact
- Minor UI bug fixes
10 changes: 1 addition & 9 deletions conda-environments/DIPLOMAT-SLEAP-CPU.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ name: DIPLOMAT-SLEAP

channels:
- conda-forge
- nvidia
- anaconda

dependencies:
Expand Down Expand Up @@ -35,15 +34,8 @@ dependencies:
- conda-forge::scikit-learn ==1.0
- conda-forge::scikit-video
- conda-forge::seaborn
- tensorflow >=2.6.0,<2.11 # No windows GPU support for >2.10
- tensorflow-hub # Pinned in meta.yml, but no problems here... yet
- keras

# Packages required by tensorflow to find/use GPUs
- conda-forge::cudatoolkit ==11.3.1
# "==" results in package not found
- conda-forge::cudnn=8.2.1
- nvidia::cuda-nvcc=11.3
- pip:
- sleap==1.3.3
- diplomat-track
- diplomat-track
2 changes: 1 addition & 1 deletion diplomat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
A tool providing multi-animal tracking capabilities on top of other Deep learning based tracking software.
"""

__version__ = "0.0.9"
__version__ = "0.1.0"
# Can be used by functions to determine if diplomat was invoked through it's CLI interface.
CLI_RUN = False

Expand Down
1 change: 1 addition & 0 deletions diplomat/core_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,3 +545,4 @@ def interact(
start_time=start_time,
end_time=end_time
)

2 changes: 1 addition & 1 deletion diplomat/predictors/fpe/arr_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def to_keys(_x, _y):
# We perform a 3x3 max-convolution to find peaks.
for i in range(-1, 2):
for j in range(-1, 2):
if (i == 0 and j == 0):
if(i == 0 and j == 0):
continue
neighbor = lookup_table[to_keys(x + j, y + i)]
below_to_right = (i >= 0) & (j >= 0)
Expand Down
2 changes: 1 addition & 1 deletion diplomat/predictors/fpe/fpe_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def print_test_set():

for i, s in enumerate(TEST_FRAME_SEQUENCES):
print(f"Test Frame Sequence {i}: ")
for j, frm in enumerate(extract_frames.unpack_frame_string(s, 1)):
for j, frm in enumerate(extract_frames.unpack_frame_string(s, 1)[1]):
for bp_idx in range(frm.get_bodypart_count()):
print(f"Frame {j} Body Part {bp_idx}")
extract_frames.pretty_print_frame(frm, 0, bp_idx)
Expand Down
57 changes: 31 additions & 26 deletions diplomat/predictors/sfpe/segmented_frame_pass_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,33 @@ def __init__(

self._current_frame = 0

def get_frame_holder(self):
output_path = Path(self.video_metadata["output-file-path"]).resolve()
video_path = Path(self.video_metadata["orig-video-path"]).resolve()
disk_path = output_path.parent / (output_path.stem + ".dipui")

self._file_obj = disk_path.open("w+b")

with video_path.open("rb") as f:
shutil.copyfileobj(f, self._file_obj)

ctx = PoolWithProgress.get_optimal_ctx()
self._manager = ctx.Manager()
self._shared_memory = allocate_shared_memory(
ctx, DiskBackedForwardBackwardData.get_shared_memory_size(self.num_frames, self._num_total_bp)
)

_frame_holder = DiskBackedForwardBackwardData(
self.num_frames,
self._num_total_bp,
self._file_obj,
self.settings.memory_cache_size,
lock=self._manager.RLock(),
memory_backing=self._shared_memory
)

return _frame_holder

def _open(self):
if(self._restore_path is not None):
# Ignore everything else,
Expand Down Expand Up @@ -479,32 +506,10 @@ def _open(self):

self._segments = np.array(self._frame_holder.metadata["segments"], dtype=np.int64)
self._segment_scores = np.array(self._frame_holder.metadata["segment_scores"], dtype=np.float32)
elif(self.settings.storage_mode == "memory"):
elif(self.settings.storage_mode in ["memory","hybrid"]):
self._frame_holder = ForwardBackwardData(self.num_frames, self._num_total_bp)
else:
output_path = Path(self.video_metadata["output-file-path"]).resolve()
video_path = Path(self.video_metadata["orig-video-path"]).resolve()
disk_path = output_path.parent / (output_path.stem + ".dipui")

self._file_obj = disk_path.open("w+b")

with video_path.open("rb") as f:
shutil.copyfileobj(f, self._file_obj)

ctx = PoolWithProgress.get_optimal_ctx()
self._manager = ctx.Manager()
self._shared_memory = allocate_shared_memory(
ctx, DiskBackedForwardBackwardData.get_shared_memory_size(self.num_frames, self._num_total_bp)
)

self._frame_holder = DiskBackedForwardBackwardData(
self.num_frames,
self._num_total_bp,
self._file_obj,
self.settings.memory_cache_size,
lock=self._manager.RLock(),
memory_backing=self._shared_memory
)
self._frame_holder = self.get_frame_holder()

self._frame_holder.metadata.settings = dict(self.settings)
self._frame_holder.metadata.video_metadata = dict(self.video_metadata)
Expand Down Expand Up @@ -1466,8 +1471,8 @@ def get_settings(cls) -> ConfigSpec:
"Greedy is faster/simpler, hungarian provides better results."
),
"storage_mode": (
"disk",
type_casters.Literal("disk", "memory"),
"hybrid",
type_casters.Literal("disk", "hybrid", "memory"),
"Location to store frames while the algorithm is running."
),
"memory_cache_size": (
Expand Down
5 changes: 3 additions & 2 deletions diplomat/predictors/supervised_fpe/labelers.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,8 @@ def __init__(self, frame_engine: EditableFramePassEngine):
self._frame_engine = frame_engine
self._settings = labeler_lib.SettingCollection(
minimum_peak_value=labeler_lib.FloatSpin(0, 1, 0.05, 0.001, 4),
selected_peak_value=labeler_lib.FloatSpin(0, 1, 0.95, 0.001, 4),
unselected_peak_value=labeler_lib.FloatSpin(0, 1, 0.05, 0.001, 4)
selected_peak_value=labeler_lib.FloatSpin(0.5, 1, 0.95, 0.001, 4),
unselected_peak_value=labeler_lib.FloatSpin(0, 0.5, 0.05, 0.001, 4)
)

def predict_location(
Expand Down Expand Up @@ -415,6 +415,7 @@ def predict_location(

peak_locs = find_peaks(xs, ys, probs, meta.width)
peak_locs = peak_locs[probs[peak_locs] >= config.minimum_peak_value]
print(peak_locs)
if(len(peak_locs) <= 1):
# No peaks, or only one peak, perform basically a no-op, return prior frame state...
x, y, prob = self._frame_engine.scmap_to_video_coord(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import shutil
import traceback

from collections import UserList
from pathlib import Path

from diplomat.predictors.sfpe.disk_sparse_storage import DiskBackedForwardBackwardData
from diplomat.wx_gui.progress_dialog import FBProgressDialog
from diplomat.predictors.supervised_fpe.labelers import Approximate, Point, NearestPeakInSource, ApproximateSourceOnly
from diplomat.predictors.supervised_fpe.scorers import EntropyOfTransitions, MaximumJumpInStandardDeviations
Expand Down Expand Up @@ -544,8 +546,8 @@ def _partial_rerun(

def _on_run_fb(self, submit_evt: bool = True) -> bool:
"""
PRIVATE: Method is run whenever the Frame Pass Engine is rerun on the data. Runs the Frame Passes only in chunks the user has modified and
then updates the UI to display the changed data.
PRIVATE: Method is run whenever the Frame Pass Engine is rerun on the data. Runs the Frame Passes only in
chunks the user has modified and then updates the UI to display the changed data.
:param submit_evt: A boolean, determines if this should submit a new history event. Undo/Redo actions call
this method with this parameter set to false, otherwise is defaults to true.
Expand Down Expand Up @@ -589,6 +591,43 @@ def _on_run_fb(self, submit_evt: bool = True) -> bool:
# Return false to not clear the history....
return False

def _copy_to_disk(self, progress_bar: ProgressBar, new_frame_holder: ForwardBackwardData):
progress_bar.message("Saving to Disk")
progress_bar.reset(self._frame_holder.num_frames * self._frame_holder.num_bodyparts)

new_frame_holder.metadata = self._frame_holder.metadata
for frame_idx in range(len(self._frame_holder.frames)):
for bodypart_idx in range(len(self._frame_holder.frames[frame_idx])):
new_frame_holder.frames[frame_idx][bodypart_idx] = self._frame_holder.frames[frame_idx][
bodypart_idx]
progress_bar.update()

def _on_manual_save(self):
output_path = Path(self.video_metadata["output-file-path"]).resolve()
video_path = Path(self.video_metadata["orig-video-path"]).resolve()
disk_path = output_path.parent / (output_path.stem + ".dipui")

with disk_path.open("w+b") as disk_ui_file:
with video_path.open("rb") as f:
shutil.copyfileobj(f, disk_ui_file)

with DiskBackedForwardBackwardData(
self.num_frames,
self._num_total_bp,
disk_ui_file,
self.settings.memory_cache_size
) as disk_frame_holder:
with FBProgressDialog(self._fb_editor, title="Save to Disk") as dialog:
dialog.Show()
self._fb_editor.Enable(False)
self._copy_to_disk(dialog.progress_bar, disk_frame_holder)
self._fb_editor.Enable(True)

def _on_visual_settings_change(self, data):
old_data = self._frame_holder.metadata["video_metadata"]
old_data.update(data)
self._frame_holder.metadata["video_metadata"] = old_data

def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]:
if(self._restore_path is None):
self._run_frame_passes(progress_bar)
Expand Down Expand Up @@ -616,6 +655,12 @@ def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]:
relaxed_radius=self.settings.relaxed_maximum_radius
)

if(self._restore_path is None and self.settings.storage_mode == "hybrid"):
new_frame_holder = self.get_frame_holder()
self._copy_to_disk(progress_bar, new_frame_holder)
self._frame_holder = new_frame_holder
self._frame_holder._frames.flush()

self._video_hdl = cv2.VideoCapture(self._video_path)

app = wx.App()
Expand All @@ -630,7 +675,8 @@ def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]:
[Approximate(self), ApproximateSourceOnly(self), Point(self), NearestPeakInSource(self)],
[EntropyOfTransitions(self), MaximumJumpInStandardDeviations(self)],
None,
list(range(1, self.num_outputs + 1)) * (self._num_total_bp // self.num_outputs)
list(range(1, self.num_outputs + 1)) * (self._num_total_bp // self.num_outputs),
self._on_manual_save if(self.settings.storage_mode == "memory") else None
)

for s in self._fb_editor.score_displays:
Expand All @@ -643,6 +689,7 @@ def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]:
self._fb_editor.history.register_redoer(self.RERUN_HIST_EVT, self._on_hist_fb)
self._fb_editor.history.register_confirmer(self.RERUN_HIST_EVT, self._confirm_action)
self._fb_editor.set_fb_runner(self._on_run_fb)
self._fb_editor.set_plot_settings_changer(self._on_visual_settings_change)

self._fb_editor.Show()

Expand Down
24 changes: 16 additions & 8 deletions diplomat/utils/extract_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Provides utility functions for quickly extracting frames from diplomat frame store files, and also printing frame data to the terminal
for debugging and display purposes.
"""
from typing import BinaryIO, Sequence, Callable, Optional, Generator, Union, Tuple, NamedTuple
from typing import BinaryIO, Sequence, Callable, Optional, Generator, Union, Tuple, NamedTuple, List
from diplomat.processing import TrackingData
from diplomat.utils import frame_store_fmt
from io import BytesIO
Expand Down Expand Up @@ -243,7 +243,10 @@ def extract_n_pack(
return base64.encodebytes(out.getvalue())


def unpack_frame_string(frame_string: bytes, frames_per_iter: int = 0) -> Union[TrackingData, Generator[TrackingData, None, None]]:
def unpack_frame_string(
frame_string: bytes,
frames_per_iter: int = 0
) -> Tuple[List[str], Union[TrackingData, Generator[TrackingData, None, None]]]:
"""
Unpack a frame store string into a tracking data object for access to the original probability frame data.
Expand All @@ -252,22 +255,27 @@ def unpack_frame_string(frame_string: bytes, frames_per_iter: int = 0) -> Union[
0 or less, this function returns a single TrackingData object storing all frames instead of
returning a generator.
:returns: A single TrackingData object if frames_per_iter <= 0, a Generator of TrackingData objects if
frames_per_iter > 0.
:returns: A tuple containing:
- A list of strings (body parts) and,
- A single TrackingData object if frames_per_iter <= 0,
or a Generator of TrackingData objects if frames_per_iter > 0.
"""
f = BytesIO(base64.decodebytes(frame_string))

reader = frame_store_fmt.DLFSReader(f)

if (frames_per_iter <= 0):
yield reader.read_frames(reader.get_header().number_of_frames)
return
if(frames_per_iter <= 0):
return (reader.get_header().bodypart_names, reader.read_frames(reader.get_header().number_of_frames))
else:
return (reader.get_header().bodypart_names, _unpack_frame_string_gen(reader, frames_per_iter))


def _unpack_frame_string_gen(reader: frame_store_fmt.DLFSReader, frames_per_iter: int = 0):
while(reader.has_next(frames_per_iter)):
yield reader.read_frames(frames_per_iter)

extra = reader.get_header().number_of_frames - (reader.tell_frame() + 1)
if(extra > 0):
yield reader.read_frames(extra)

return
return None
Loading

0 comments on commit dd406d1

Please sign in to comment.