Merge pull request #6 from TravisWheelerLab/update-pose-labelers
Draft: Branch for making changes to the pose labelers
daphnedemekas authored Mar 5, 2024
2 parents dd406d1 + 20e32c0 commit efd7b39
Showing 8 changed files with 187 additions and 10 deletions.
2 changes: 1 addition & 1 deletion diplomat/__init__.py
@@ -2,7 +2,7 @@
A tool providing multi-animal tracking capabilities on top of other Deep learning based tracking software.
"""

__version__ = "0.1.0"
__version__ = "0.1.1"
# Can be used by functions to determine if diplomat was invoked through its CLI interface.
CLI_RUN = False

150 changes: 150 additions & 0 deletions diplomat/predictors/fpe/frame_passes/mit_viterbi.py
@@ -72,6 +72,9 @@ def __init__(
self._enter_exit_prob = to_log_space(enter_exit_prob)
self._enter_stay_prob = to_log_space(enter_stay_prob)

"""The ViterbiTransitionTable class is used to manage transition probabilities,
including those modified by the dominance relationship and the "flat-topped" Gaussian distribution."""

@staticmethod
def _is_enter_state(coords: Coords) -> bool:
return len(coords[0]) == 1 and np.isneginf(coords[0][0])
@@ -132,6 +135,7 @@ def _init_gaussian_table(self, metadata: AttributeDict):
else:
self._scaled_std = (std if (std != "auto") else 1) / metadata.down_scaling

# Flat-topped Gaussian: optionally plateau the peak of the transition distribution.
self._flatten_std = None if (conf.gaussian_plateau is None) else self._scaled_std * conf.gaussian_plateau
self._gaussian_table = norm(fpe_math.gaussian_table(
self.height, self.width, self._scaled_std, conf.amplitude,
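The `gaussian_plateau` setting above turns the usual Gaussian transition curve into a "flat-topped" one: every transition closer than the plateau radius receives the peak weight. A minimal sketch of the idea, assuming a hypothetical `flat_topped_gaussian` helper (illustrative only, not DIPLOMAT's `fpe_math.gaussian_table`):

```python
import numpy as np

def flat_topped_gaussian(dist, std, amplitude=1.0, plateau_std=None):
    """Gaussian of `dist` with an optional flat plateau of radius `plateau_std`."""
    dist = np.asarray(dist, dtype=float)
    if plateau_std is not None:
        # Inside the plateau every location is treated as distance 0, i.e. gets the peak value.
        dist = np.maximum(dist - plateau_std, 0.0)
    return amplitude * np.exp(-(dist ** 2) / (2 * std ** 2))

# Small 2D table of transition weights around a center pixel.
h, w, std = 5, 5, 1.0
ys, xs = np.mgrid[0:h, 0:w]
dists = np.sqrt((ys - h // 2) ** 2 + (xs - w // 2) ** 2)
table = flat_topped_gaussian(dists, std, plateau_std=1.0)
```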
@@ -153,6 +157,27 @@ def _init_gaussian_table(self, metadata: AttributeDict):
metadata.include_soft_domination = self.config.include_soft_domination

def _init_skeleton(self, data: ForwardBackwardData):
"""If skeleton data is available, this function initializes the skeleton tables,
which are used to enhance tracking by considering the structural
relationships between different body parts.
The _skeleton_tables is a StorageGraph object that stores the relationship between different body parts
as defined in the skeleton data from the metadata. Each entry in this table represents a connection
between two body parts (nodes) and contains the statistical data (bin_val, freq, avg) related to that connection.
This data is used to enhance tracking accuracy by considering the structural relationships between body parts.
Specifically, it stores:
# - The names of the nodes (body parts) involved in the skeleton structure.
# - A matrix for each pair of connected nodes, which is computed based on the skeleton formula. This matrix
# represents the likelihood of transitioning from one body part to another, taking into account the average
# distance and frequency of such transitions as observed in the training data.
# - The configuration parameters used for calculating these matrices, which include adjustments for log space
# calculations and other statistical considerations.
# This structure is crucial for the Viterbi algorithm to accurately model the movement and relationships
# between different parts of the body during tracking.
"""

if("skeleton" in data.metadata):
meta = data.metadata
self._skeleton_tables = StorageGraph(meta.skeleton.node_names())
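As a rough illustration of how per-pair statistics like `(bin_val, freq, avg)` can become a log-space pairwise score, here is a hedged sketch; `skeleton_log_score` is a hypothetical stand-in, not the actual skeleton formula used by DIPLOMAT:

```python
import numpy as np

def skeleton_log_score(distance, avg, freq, bin_val):
    """Log-space score that peaks when `distance` matches the learned average pair distance."""
    distance = np.asarray(distance, dtype=float)
    spread = max(bin_val, 1e-6)  # use the histogram bin size as the spread
    score = freq * np.exp(-((distance - avg) ** 2) / (2 * spread ** 2))
    return np.log(np.maximum(score, 1e-12))  # stay in log space and avoid log(0)

# One entry per connected pair of body parts, mirroring an edge in the StorageGraph.
pair_stats = {("nose", "left_ear"): {"avg": 12.0, "freq": 0.9, "bin_val": 2.0}}
bonus = skeleton_log_score(np.array([10.0, 12.0, 30.0]), **pair_stats[("nose", "left_ear")])
```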
@@ -188,6 +213,10 @@ def run_pass(
in_place: bool = True,
reset_bar: bool = True
) -> ForwardBackwardData:
"""
This is the main function that orchestrates the forward and backward passes of the Viterbi algorithm.
It initializes the necessary tables and states, then runs the forward pass to calculate probabilities,
followed by a backtrace to determine the most probable paths."""
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
if("fixed_frame_index" not in fb_data.metadata):
@@ -343,6 +372,9 @@ def _run_backtrace(
@staticmethod
def _get_pool():
# Check globals for a pool...
"""This function sets up a multiprocessing pool for parallel processing,
improving the efficiency of the algorithm by allowing it to process
multiple parts of the frame or multiple frames simultaneously."""
if(FramePass.GLOBAL_POOL is not None):
return FramePass.GLOBAL_POOL
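The pattern described here — reuse the globally stored pool when it exists, otherwise create one lazily — reduces to roughly the following sketch (simplified; names are illustrative):

```python
import multiprocessing

_GLOBAL_POOL = None  # stands in for FramePass.GLOBAL_POOL

def get_pool(processes=4):
    """Return the shared pool, creating it lazily on first use."""
    global _GLOBAL_POOL
    if _GLOBAL_POOL is None:
        _GLOBAL_POOL = multiprocessing.Pool(processes)
    return _GLOBAL_POOL
```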

@@ -431,6 +463,42 @@ def _compute_backtrace_step(
soft_dom_weight: float = 0,
skeleton_weight: float = 0
) -> List[np.ndarray]:
"""This method is responsible for computing the transition probabilities from the prior maximum locations
(highest probability states) of all body parts in the prior frame to the current frame's states.
It's where the algorithm determines the most probable path that leads to each pixel
in the current frame based on the accumulated probabilities from previous frames.
Parameters
prior: A list of lists containing tuples.
Each tuple represents the probability and coordinates (x, y) of the prior maximum locations
for all body parts in the prior frame.
This data structure allows the method to consider multiple potential origins for each body part's current position.
current: A list of tuples containing the probability and coordinates (x, y) of the current frame's states
This represents the possible current positions and their associated probabilities.
bp_idx: The index of the body part being processed. This is used to identify which part of the data corresponds to the current body part in multi-body part tracking scenarios.
metadata: The metadata from the ForwardBackwardData object.
An AttributeDict containing metadata that might be necessary for the computation, such as configuration parameters or additional data needed for probability calculations.
transition_function: A function or callable object that calculates the transition probabilities between states. This is crucial for determining how likely it is to move from one state to another.
resist_transition_function: A function or callable object that calculates the resistance to transitioning between states.
Similar to transition_function, but used for calculating resistive transitions, which might be part of handling interactions between different tracked objects or body parts.
skeleton_table: A StorageGraph object that stores the relationship between different body parts as defined in the skeleton data from the metadata.
An optional parameter that, if provided, contains skeleton information that can be used to enhance the tracking by considering the structural relationships between different body parts.
soft_dom_weight: A float representing the weight of the soft domination factor.
skeleton_weight: A float representing the weight of the skeleton factor.
"""

# If skeleton information is available, the method first computes the influence of skeletal connections
# on the transition probabilities.
# This involves considering the structural relationships between body parts and adjusting probabilities accordingly.
skel_res = cls._compute_from_skeleton(
prior,
current,
@@ -439,6 +507,11 @@
skeleton_table
)

# The method then calculates the effect of soft domination, a technique for handling the
# dominance relationship between different paths. This step adjusts the probabilities to
# favor more likely paths and suppress less likely ones, based on the configured soft
# domination weight.

from_soft_dom = cls._compute_soft_domination(
prior,
current,
@@ -447,12 +520,22 @@
resist_transition_function,
)

# The core of the method calculates the transition probabilities from the prior states to
# the current states. This uses the transition_function, which takes into account the
# distances between states and other factors to determine how likely a transition from one
# state to another is.
trans_res = cls.log_viterbi_between(
current,
prior[bp_idx],
transition_function
)

# The probabilities from the skeleton influence, soft domination, and direct transitions are
# then combined to get the overall probability of transitioning to each current state from
# the prior states: each component is weighted according to its configured weight and the
# results are summed.

# Normalization: finally, the probabilities are normalized so they stay within a valid range
# and so that different paths can be compared.
return norm_together([
t + s * skeleton_weight + d * soft_dom_weight for t, s, d in zip(trans_res, skel_res, from_soft_dom)
])
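Conceptually, the combination above is a per-element weighted sum of three lists of log-space scores followed by a joint normalization. A small self-contained sketch, where `log_norm_together` is a stand-in for DIPLOMAT's `norm_together` (whose exact normalization may differ):

```python
import numpy as np
from scipy.special import logsumexp

def log_norm_together(arrays):
    """Normalize a group of log-probability arrays against their joint total."""
    total = logsumexp(np.concatenate([np.ravel(a) for a in arrays]))
    return [a - total for a in arrays]

skeleton_weight, soft_dom_weight = 0.5, 0.1
trans_res = [np.log([0.5, 0.3]), np.log([0.2, 0.8])]
skel_res = [np.log([0.6, 0.6]), np.log([0.7, 0.4])]
dom_res = [np.log([0.9, 0.2]), np.log([0.5, 0.5])]

combined = log_norm_together([
    t + s * skeleton_weight + d * soft_dom_weight
    for t, s, d in zip(trans_res, skel_res, dom_res)
])
```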
@@ -527,6 +610,8 @@ def _compute_from_skeleton(
merge_internal: Callable[[np.ndarray, int], np.ndarray] = np.max,
merge_results: bool = True
) -> Union[List[Tuple[int, List[NumericArray]]], List[NumericArray]]:

# TODO: Add docstring and notes in code.
if(skeleton_table is None):
return [0] * len(current_data) if(merge_results) else []

@@ -597,6 +682,45 @@ def _compute_soft_domination(
merge_internal: Callable[[np.ndarray, int], np.ndarray] = np.max,
merge_results: bool = True
) -> Union[List[Tuple[int, List[NumericArray]]], List[NumericArray]]:
"""
Computes the soft domination for a given body part across frames, considering prior and current data.
This method calculates the soft domination values by comparing the probabilities of a body part being in
a certain state in the current frame against its probabilities in the prior frames. It uses a transition
function to determine the likelihood of transitioning from each state in the prior frames to each state in
the current frame. The results are merged using specified merging functions to find the most probable state
transitions. This method can optionally merge the results across all body parts to find the overall most
probable states.
Parameters:
- prior: Union[List[ForwardBackwardFrame], List[List[Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]]]]
The prior frame data or computed probabilities and coordinates for each body part.
- current_data: List[Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]]
The current frame data including probabilities and coordinates for each body part.
- bp_idx: int
The index of the body part being processed.
- metadata: AttributeDict
Metadata containing configuration and state information for the current processing.
- transition_func: Optional[TransitionFunction]
The function used to compute the transition probabilities between states.
- merge_arrays: Callable[[Iterable[np.ndarray]], np.ndarray]
A function to merge arrays of probabilities from different transitions.
- merge_internal: Callable[[np.ndarray, int], np.ndarray]
A function to merge probabilities within a single transition.
- merge_results: bool
A flag indicating whether to merge the results across all body parts.
Returns:
- Union[List[Tuple[int, List[NumericArray]]], List[NumericArray]]:
The computed soft domination values for the specified body part, either as a list of numeric
arrays (if merge_results is True) or as a list of tuples containing the body part index and
the list of numeric arrays (if merge_results is False).
This method is crucial for optimizing the Viterbi path selection: it considers not only how
probable each path is in isolation, but also how the paths compare once potential transitions
from prior states are taken into account across the sequence of frames being analyzed.
"""
if(transition_func is None or metadata.num_outputs <= 1):
return [0] * len(current_data) if(merge_results) else []
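A heavily hedged sketch of the soft-domination idea (not DIPLOMAT's actual formula): when several outputs of the same body part compete for the same locations, each candidate's log-probability is penalized in proportion to how much more confident its strongest competitor is at that location.

```python
import numpy as np

def soft_dominate(own_log_probs, competitor_log_probs, weight):
    """Penalize locations where some competitor is more confident than this candidate."""
    if not competitor_log_probs:
        return own_log_probs
    best_rival = np.maximum.reduce(competitor_log_probs)
    penalty = np.maximum(best_rival - own_log_probs, 0.0)  # only where the rival is stronger
    return own_log_probs - weight * penalty

own = np.log(np.array([0.6, 0.2, 0.1]))
rivals = [np.log(np.array([0.1, 0.7, 0.1])), np.log(np.array([0.2, 0.1, 0.6]))]
adjusted = soft_dominate(own, rivals, weight=0.5)
```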

@@ -669,6 +793,12 @@ def _compute_normal_frame(
soft_dom_weight: float = 0,
skeleton_weight: float = 0
) -> List[ForwardBackwardFrame]:

"""processes a single frame in the context of tracking multiple body parts or individuals,
calculating the probabilities of each body part being in each position based on prior information,
current observations, and various transition models.
It integrates several key concepts, including handling occlusions, leveraging skeleton information,
and applying soft domination to refine the tracking process."""
group_range = range(
bp_group * metadata.num_outputs,
(bp_group + 1) * metadata.num_outputs
@@ -814,6 +944,21 @@ def log_viterbi_between(
merge_arrays: Callable[[Iterable[np.ndarray]], np.ndarray] = np.maximum.reduce,
merge_internal: Callable[[np.ndarray, int], np.ndarray] = np.nanmax
) -> List[np.ndarray]:
"""
This method calculates the transition probabilities between the prior and current data points for each body part.
It utilizes a transition function to compute the probabilities of moving from each prior state to each current state.
The method then merges these probabilities across all body parts to determine the most likely transitions.
Parameters:
- current_data: A sequence of tuples containing the current probabilities and coordinates for each body part.
- prior_data: A sequence of tuples containing the prior probabilities and coordinates for each body part.
- transition_function: A callable that computes the transition probabilities between prior and current states.
- merge_arrays: A callable that merges arrays of probabilities across all body parts.
- merge_internal: A callable that merges probabilities within each body part.
Returns:
A list of numpy arrays representing the merged transition probabilities for each body part.
"""
return [
merge_arrays([
merge_internal(
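For each current location, the computation boils down to adding the best predecessor score — prior log-probability plus log transition cost, maximized over prior locations. A simplified dense sketch (hypothetical helper; the real method works on sparse per-body-part data with pluggable merge functions):

```python
import numpy as np

def log_viterbi_between_dense(current_log_probs, prior_log_probs, transition_log_probs):
    """For each current cell, add the best (prior + transition) score over all prior cells.

    transition_log_probs has shape [n_prior, n_current].
    """
    scores = prior_log_probs[:, None] + transition_log_probs  # broadcast priors over currents
    return current_log_probs + np.nanmax(scores, axis=0)      # keep the best predecessor per cell

prior = np.log(np.array([0.7, 0.3]))
current = np.log(np.array([0.5, 0.4, 0.1]))
trans = np.log(np.array([[0.6, 0.3, 0.1],
                         [0.2, 0.2, 0.6]]))
best = log_viterbi_between_dense(current, prior, trans)
```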
@@ -851,6 +996,11 @@ def generate_occluded(
@classmethod
def get_config_options(cls) -> ConfigSpec:
# Class to enforce that probabilities are between 0 and 1....
"""This function returns a dictionary of configuration options that can be adjusted to
customize the behavior of the algorithm.
These options include parameters for the Gaussian distribution,
probabilities for obscured and edge states,
and weights for the dominance relationship and skeleton data."""
return {
"standard_deviation": (
"auto", tc.Union(float, tc.Literal("auto")),
13 changes: 10 additions & 3 deletions diplomat/predictors/supervised_fpe/labelers.py
@@ -57,6 +57,7 @@ def predict_location(
frame = self._frame_engine.frame_data.frames[frame_idx][bp_idx]

if(x is None):
# TODO: should we return this prob value, or the probability value passed in?
x, y, prob = self._frame_engine.scmap_to_video_coord(
*self._frame_engine.get_maximum_with_defaults(frame),
meta.down_scaling
@@ -206,7 +207,7 @@ def predict_location(
bp_idx: int,
x: float,
y: float,
probability: float
probability: float,
) -> Tuple[Any, Tuple[float, float, float]]:
info = self._settings.get_values()
user_amp = info.user_input_strength / 1000
@@ -293,7 +294,10 @@ def pose_change(self, new_state: Any) -> Any:
)
new_data.pack(*[np.array([item]) for item in [y, x, prob, off_x, off_y]])
else:
new_data = suggested_frame.src_data
y, x, prob, x_offset, y_offset = suggested_frame.src_data.unpack()
max_prob_idx = np.argmax(prob)
new_data = SparseTrackingData()
new_data.pack(*[np.array([item]) for item in [y[max_prob_idx], x[max_prob_idx], 1, x_offset[max_prob_idx], y_offset[max_prob_idx]]])

new_frame = ForwardBackwardFrame()
new_frame.orig_data = new_data
@@ -481,7 +485,10 @@ def pose_change(self, new_state: Any) -> Any:
)
new_data.pack(*[np.array([item]) for item in [y, x, prob, off_x, off_y]])
else:
new_data = suggested_frame.src_data
y, x, prob, x_offset, y_offset = suggested_frame.src_data.unpack()
max_prob_idx = np.argmax(prob)
new_data = SparseTrackingData()
new_data.pack(*[np.array([item]) for item in [y[max_prob_idx], x[max_prob_idx], 1, x_offset[max_prob_idx], y_offset[max_prob_idx]]])

new_frame = ForwardBackwardFrame()
new_frame.orig_data = new_data
@@ -67,7 +67,7 @@ def on_end(self, progress_bar: ProgressBar) -> Union[None, Pose]:
self._get_names(),
self.video_metadata,
self._get_crop_box(),
[Approximate(self), ApproximateSourceOnly(self), Point(self), NearestPeakInSource(self)],
[Approximate(self), Point(self), NearestPeakInSource(self), ApproximateSourceOnly(self)],
[EntropyOfTransitions(self), MaximumJumpInStandardDeviations(self)],
None,
list(range(1, self.num_outputs + 1)) * (self._num_total_bp // self.num_outputs)
@@ -516,6 +516,15 @@ def _partial_rerun(
old_poses: Pose,
progress_bar: ProgressBar
) -> Tuple[Pose, Iterable[int]]:

# TODO: delete the commented-out lines below; they do not behave as expected.

# # For each changed frame and each body part, take the maximum probability coordinates and set them to one
# for (frame_idx, bp_idx), frame in changed_frames.items():
# max_prob_coord = np.unravel_index(frame.frame_probs.argmax(), frame.frame_probs.shape)
# new_frame_probs = np.zeros_like(frame.frame_probs) #copy because this is read only
# new_frame_probs[max_prob_coord] = 1
# frame.frame_probs = new_frame_probs
# Determine what segments have been manipulated...
segment_indexes = sorted({np.searchsorted(self._segments[:, 1], f_i, "right") for f_i, b_i in changed_frames})

@@ -531,7 +540,8 @@
for (s_i, e_i, f_i), seg_ord in zip(self._segments, self._segment_bp_order):
poses[s_i:e_i, :] = poses[s_i:e_i, seg_ord]
old_poses.get_all()[:] = poses.reshape(old_poses.get_frame_count(), old_poses.get_bodypart_count() * 3)



return (
self.get_maximums(
self._frame_holder,
@@ -672,7 +682,7 @@ def _on_end(self, progress_bar: ProgressBar) -> Optional[Pose]:
self._get_names(),
self.video_metadata,
self._get_crop_box(),
[Approximate(self), ApproximateSourceOnly(self), Point(self), NearestPeakInSource(self)],
[Approximate(self), Point(self), NearestPeakInSource(self), ApproximateSourceOnly(self)],
[EntropyOfTransitions(self), MaximumJumpInStandardDeviations(self)],
None,
list(range(1, self.num_outputs + 1)) * (self._num_total_bp // self.num_outputs),
2 changes: 2 additions & 0 deletions diplomat/wx_gui/fpe_editor.py
@@ -288,6 +288,7 @@ def __init__(
bp_names=names,
labeling_modes=labeling_modes,
group_list=part_groups,
# skeleton_info = self.skeleton_info
**ps
)
self.video_controls = VideoController(self._sub_panel, video_player=self.video_player.video_viewer)
@@ -336,6 +337,7 @@ def __init__(

self.video_controls.Bind(PointViewNEdit.EVT_FRAME_CHANGE, self._on_frame_chg)


def _on_close_caller(self, event: wx.CloseEvent):
self._on_close(event, self._was_save_button_flag)
self._was_save_button_flag = False