Skip to content

Commit

Permalink
removed indices_to_onehot function
Browse files Browse the repository at this point in the history
  • Loading branch information
picciama committed Apr 17, 2024
1 parent f545fd3 commit 7b7b6a1
Showing 1 changed file with 0 additions and 33 deletions.
33 changes: 0 additions & 33 deletions tests/unit_tests/test_charge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import unittest
from typing import List, Optional, Union

import numpy as np

Expand All @@ -18,38 +17,6 @@ def test_indices_to_one_hot_with_classes(self):

# Inside charge.py


def indices_to_one_hot(labels: Union[int, List[int], np.ndarray], classes: Optional[int] = None) -> np.ndarray:
"""
Convert a single or a list of labels to one-hot encoding.
:param labels: The labels to be one-hot encoding. Must be one-based.
:param classes: The number of classes, i.e. the length of the encoding. If omitted, set to the max label + 1.
:raises TypeError: If the type of labels is not understood
:raises ValueError: If the highest label in labels is larger or equal to the number of classes.
:return: np.ndarray with the one-hot encoded labels.
"""
if isinstance(labels, int):
labels = np.array([labels])
elif isinstance(labels, (list, np.ndarray)):
labels = np.array(labels)
else:
raise TypeError(
f"Type of labels not understood. Only int, List[int] and np.ndarray are supported. Given: {type(labels)}."
)

if classes is None:
classes = np.max(labels) + 1
elif classes < np.max(labels) + 1:
raise ValueError(
f"Number of classes must be greater than or equal to the maximum label in labels. "
f"Given classes: {classes}, maximum label: {np.max(labels)}."
)

one_hot = np.zeros((labels.size, classes), dtype=int)
one_hot[np.arange(labels.size), labels - 1] = 1
return one_hot

def test_indices_to_one_hot_with_int_and_class(self):
"""Test indices_to_one_hot with a single integer and given number of classes."""
labels = 1
Expand Down

0 comments on commit 7b7b6a1

Please sign in to comment.