
Commit b004a0e

WIP

aliberts committed Nov 25, 2024
1 parent 975c1c2 commit b004a0e
Showing 1 changed file with 92 additions and 0 deletions.
lerobot/common/datasets/compute_stats.py (92 additions, 0 deletions)
@@ -211,3 +211,95 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
)
)
return stats


# TODO(aliberts): refactor stats in save_episodes
# import numpy as np
# from lerobot.common.datasets.utils import load_image_as_numpy
# def aggregate_stats_v2(stats_list: list) -> dict:
#     """Aggregate stats from multiple compute_stats outputs into a single set of stats.

#     The final stats will have the union of all data keys from each of the stats dicts.

#     For instance:
#     - new_min = min(min_dataset_0, min_dataset_1, ...)
#     - new_max = max(max_dataset_0, max_dataset_1, ...)
#     - new_mean = (mean of all data, weighted by counts)
#     - new_std = (std of all data)
#     """
#     data_keys = set(key for stats in stats_list for key in stats.keys())
#     aggregated_stats = {key: {} for key in data_keys}

#     for key in data_keys:
#         # Collect stats for the current key from all datasets where it exists
#         stats_with_key = [stats[key] for stats in stats_list if key in stats]

#         # Aggregate 'min' and 'max' using np.minimum and np.maximum
#         aggregated_stats[key]['min'] = np.minimum.reduce([s['min'] for s in stats_with_key])
#         aggregated_stats[key]['max'] = np.maximum.reduce([s['max'] for s in stats_with_key])

#         # Extract means, variances (std^2), and counts
#         means = np.array([s['mean'] for s in stats_with_key])
#         variances = np.array([s['std']**2 for s in stats_with_key])
#         counts = np.array([s['count'] for s in stats_with_key])

#         # Ensure counts can broadcast with means/variances if they have additional dimensions
#         counts = counts.reshape(-1, *[1]*(means.ndim - 1))

#         # Compute total counts
#         total_count = counts.sum(axis=0)

#         # Compute the weighted mean
#         weighted_means = means * counts
#         total_mean = weighted_means.sum(axis=0) / total_count

#         # Compute the variance using the parallel algorithm
#         delta_means = means - total_mean
#         weighted_variances = (variances + delta_means**2) * counts
#         total_variance = weighted_variances.sum(axis=0) / total_count

#         # Store the aggregated stats
#         aggregated_stats[key]['mean'] = total_mean
#         aggregated_stats[key]['std'] = np.sqrt(total_variance)
#         aggregated_stats[key]['count'] = total_count

#     return aggregated_stats
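# # Quick sanity check of the weighted aggregation above, on made-up scalar stats:
# # stats_a = {"x": {"min": 0.0, "max": 2.0, "mean": 1.0, "std": 1.0, "count": 2}}
# # stats_b = {"x": {"min": 2.0, "max": 4.0, "mean": 3.0, "std": 1.0, "count": 2}}
# # aggregate_stats_v2([stats_a, stats_b]) gives min 0.0, max 4.0, count 4,
# # mean (1.0 * 2 + 3.0 * 2) / 4 = 2.0 and std sqrt(((1 + 1) * 2 + (1 + 1) * 2) / 4) = sqrt(2),
# # matching np.mean/np.std over the pooled data [0, 2, 2, 4].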


# def compute_episode_stats(episode_buffer: dict, features: dict, episode_length: int, image_sampling: int = 10) -> dict:
#     stats = {}
#     for key, data in episode_buffer.items():
#         if features[key]["dtype"] in ["image", "video"]:
#             stats[key] = compute_image_stats(data, sampling=image_sampling)
#         else:
#             axes_to_reduce = 0  # Compute stats over the first axis
#             stats[key] = {
#                 "min": np.min(data, axis=axes_to_reduce),
#                 "max": np.max(data, axis=axes_to_reduce),
#                 "mean": np.mean(data, axis=axes_to_reduce),
#                 "std": np.std(data, axis=axes_to_reduce),
#                 "count": episode_length,
#             }
#     return stats
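# # Note: as sketched here, image/video keys get only min/max/mean/std from
# # compute_image_stats, while aggregate_stats_v2 above also reads a "count" entry
# # per key; a count (e.g. the number of sampled frames) would need to be added
# # before these per-episode stats can be aggregated.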


# def compute_image_stats(image_paths: list[str], sampling: int = 10) -> dict:
#     images = []
#     samples = range(0, len(image_paths), sampling)
#     for idx in samples:
#         path = image_paths[idx]
#         img = load_image_as_numpy(path, channel_first=True)
#         images.append(img)

#     images = np.stack(images)
#     axes_to_reduce = (0, 2, 3)  # keep channel dim
#     image_stats = {
#         "min": np.min(images, axis=axes_to_reduce, keepdims=True),
#         "max": np.max(images, axis=axes_to_reduce, keepdims=True),
#         "mean": np.mean(images, axis=axes_to_reduce, keepdims=True),
#         "std": np.std(images, axis=axes_to_reduce, keepdims=True),
#     }
#     for key in image_stats:  # squeeze batch dim
#         image_stats[key] = np.squeeze(image_stats[key], axis=0)

#     return image_stats
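# # Hypothetical call sketch (key names and shapes are illustrative, not from this file):
# # episode_buffer = {"action": np.zeros((50, 6)), "observation.images.cam": frame_paths}
# # features = {"action": {"dtype": "float32"}, "observation.images.cam": {"dtype": "video"}}
# # episode_stats = compute_episode_stats(episode_buffer, features, episode_length=50)
# # -> per-key dicts of min/max/mean/std (plus count for the non-image key).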
