Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add keepdims keyword argument #701

Merged
merged 5 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions code/numpy/numerical.c
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ static mp_obj_t numerical_sum_mean_std_iterable(mp_obj_t oin, uint8_t optype, si
}
}

static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, uint8_t optype, size_t ddof) {
static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, mp_obj_t keepdims, uint8_t optype, size_t ddof) {
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
uint8_t *array = (uint8_t *)ndarray->array;
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);
Expand Down Expand Up @@ -372,15 +372,15 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
mp_float_t norm = (mp_float_t)_shape_strides.shape[0];
// re-wind the array here
farray = (mp_float_t *)results->array;
for(size_t i=0; i < results->len; i++) {
for(size_t i = 0; i < results->len; i++) {
*farray++ *= norm;
}
}
} else {
bool isStd = optype == NUMERICAL_STD ? 1 : 0;
results = ndarray_new_dense_ndarray(_shape_strides.ndim, _shape_strides.shape, NDARRAY_FLOAT);
farray = (mp_float_t *)results->array;
// we can return the 0 array here, if the degrees of freedom is larger than the length of the axis
// we can return the 0 array here, if the degrees of freedom are larger than the length of the axis
if((optype == NUMERICAL_STD) && (_shape_strides.shape[0] <= ddof)) {
return MP_OBJ_FROM_PTR(results);
}
Expand All @@ -397,11 +397,9 @@ static mp_obj_t numerical_sum_mean_std_ndarray(ndarray_obj_t *ndarray, mp_obj_t
RUN_MEAN_STD(mp_float_t, array, farray, _shape_strides, div, isStd);
}
}
if(results->ndim == 0) { // return a scalar here
return mp_binary_get_val_array(results->dtype, results->array, 0);
}
return MP_OBJ_FROM_PTR(results);
return ulab_tools_restore_dims(ndarray, results, keepdims, _shape_strides);
}
// we should never get to this point
return mp_const_none;
}
#endif
Expand Down Expand Up @@ -441,7 +439,7 @@ static mp_obj_t numerical_argmin_argmax_iterable(mp_obj_t oin, uint8_t optype) {
}
}

static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t axis, uint8_t optype) {
static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t keepdims, mp_obj_t axis, uint8_t optype) {
// TODO: treat the flattened array
if(ndarray->len == 0) {
mp_raise_ValueError(MP_ERROR_TEXT("attempt to get (arg)min/(arg)max of empty sequence"));
Expand Down Expand Up @@ -521,7 +519,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
int32_t *strides = m_new0(int32_t, ULAB_MAX_DIMS);

numerical_reduce_axes(ndarray, ax, shape, strides);
uint8_t index = ULAB_MAX_DIMS - ndarray->ndim + ax;
shape_strides _shape_strides = tools_reduce_axes(ndarray, axis);

uint8_t index = _shape_strides.axis;

ndarray_obj_t *results = NULL;

Expand Down Expand Up @@ -550,8 +550,9 @@ static mp_obj_t numerical_argmin_argmax_ndarray(ndarray_obj_t *ndarray, mp_obj_t
if(results->len == 1) {
return mp_binary_get_val_array(results->dtype, results->array, 0);
}
return MP_OBJ_FROM_PTR(results);
return ulab_tools_restore_dims(ndarray, results, keepdims, _shape_strides);
}
// we should never get to this point
return mp_const_none;
}
#endif
Expand All @@ -560,13 +561,16 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
static const mp_arg_t allowed_args[] = {
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE} } ,
{ MP_QSTR_axis, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_keepdims, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_FALSE } },
};

mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
mp_arg_parse_all(n_args, pos_args, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args);

mp_obj_t oin = args[0].u_obj;
mp_obj_t axis = args[1].u_obj;
mp_obj_t keepdims = args[2].u_obj;

if((axis != mp_const_none) && (!mp_obj_is_int(axis))) {
mp_raise_TypeError(MP_ERROR_TEXT("axis must be None, or an integer"));
}
Expand Down Expand Up @@ -598,11 +602,11 @@ static mp_obj_t numerical_function(size_t n_args, const mp_obj_t *pos_args, mp_m
case NUMERICAL_ARGMIN:
case NUMERICAL_ARGMAX:
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
return numerical_argmin_argmax_ndarray(ndarray, axis, optype);
return numerical_argmin_argmax_ndarray(ndarray, keepdims, axis, optype);
case NUMERICAL_SUM:
case NUMERICAL_MEAN:
COMPLEX_DTYPE_NOT_IMPLEMENTED(ndarray->dtype)
return numerical_sum_mean_std_ndarray(ndarray, axis, optype, 0);
return numerical_sum_mean_std_ndarray(ndarray, axis, keepdims, optype, 0);
default:
mp_raise_NotImplementedError(MP_ERROR_TEXT("operation is not implemented on ndarrays"));
}
Expand Down Expand Up @@ -1385,6 +1389,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
{ MP_QSTR_, MP_ARG_REQUIRED | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } } ,
{ MP_QSTR_axis, MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE } },
{ MP_QSTR_ddof, MP_ARG_KW_ONLY | MP_ARG_INT, {.u_int = 0} },
{ MP_QSTR_keepdims, MP_ARG_OBJ, { .u_rom_obj = MP_ROM_FALSE } },
};

mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)];
Expand All @@ -1393,6 +1398,8 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
mp_obj_t oin = args[0].u_obj;
mp_obj_t axis = args[1].u_obj;
size_t ddof = args[2].u_int;
mp_obj_t keepdims = args[2].u_obj;

if((axis != mp_const_none) && (mp_obj_get_int(axis) != 0) && (mp_obj_get_int(axis) != 1)) {
// this seems to pass with False, and True...
mp_raise_ValueError(MP_ERROR_TEXT("axis must be None, or an integer"));
Expand All @@ -1401,7 +1408,7 @@ mp_obj_t numerical_std(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_arg
return numerical_sum_mean_std_iterable(oin, NUMERICAL_STD, ddof);
} else if(mp_obj_is_type(oin, &ulab_ndarray_type)) {
ndarray_obj_t *ndarray = MP_OBJ_TO_PTR(oin);
return numerical_sum_mean_std_ndarray(ndarray, axis, NUMERICAL_STD, ddof);
return numerical_sum_mean_std_ndarray(ndarray, axis, keepdims, NUMERICAL_STD, ddof);
} else {
mp_raise_TypeError(MP_ERROR_TEXT("input must be tuple, list, range, or ndarray"));
}
Expand Down
2 changes: 1 addition & 1 deletion code/ulab.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"

#define ULAB_VERSION 6.7.0
#define ULAB_VERSION 6.7.1
#define xstr(s) str(s)
#define str(s) #s

Expand Down
76 changes: 52 additions & 24 deletions code/ulab_tools.c
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,15 @@ void *ndarray_set_float_function(uint8_t dtype) {
}
#endif /* NDARRAY_BINARY_USES_FUN_POINTER */

int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim) {
int8_t ax = mp_obj_get_int(axis);
if(ax < 0) ax += ndim;
if((ax < 0) || (ax > ndim - 1)) {
mp_raise_ValueError(MP_ERROR_TEXT("axis is out of bounds"));
}
return ax;
}

shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
// TODO: replace numerical_reduce_axes with this function, wherever applicable
// This function should be used, whenever a tensor is contracted;
Expand All @@ -172,38 +181,36 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
}
shape_strides _shape_strides;

size_t *shape = m_new(size_t, ULAB_MAX_DIMS + 1);
_shape_strides.shape = shape;
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS + 1);
_shape_strides.strides = strides;

_shape_strides.increment = 0;
// this is the contracted dimension (won't be overwritten for axis == None)
_shape_strides.ndim = 0;

memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);

if(axis == mp_const_none) {
_shape_strides.shape = ndarray->shape;
_shape_strides.strides = ndarray->strides;
return _shape_strides;
}

uint8_t index = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)
size_t *shape = m_new(size_t, ULAB_MAX_DIMS + 1);
_shape_strides.shape = shape;
int32_t *strides = m_new(int32_t, ULAB_MAX_DIMS + 1);
_shape_strides.strides = strides;

memcpy(_shape_strides.shape, ndarray->shape, sizeof(size_t) * ULAB_MAX_DIMS);
memcpy(_shape_strides.strides, ndarray->strides, sizeof(int32_t) * ULAB_MAX_DIMS);

_shape_strides.axis = ULAB_MAX_DIMS - 1; // value of index for axis == mp_const_none (won't be overwritten)

if(axis != mp_const_none) { // i.e., axis is an integer
int8_t ax = mp_obj_get_int(axis);
if(ax < 0) ax += ndarray->ndim;
if((ax < 0) || (ax > ndarray->ndim - 1)) {
mp_raise_ValueError(MP_ERROR_TEXT("index out of range"));
}
index = ULAB_MAX_DIMS - ndarray->ndim + ax;
int8_t ax = tools_get_axis(axis, ndarray->ndim);
_shape_strides.axis = ULAB_MAX_DIMS - ndarray->ndim + ax;
_shape_strides.ndim = ndarray->ndim - 1;
}

// move the value stored at index to the leftmost position, and align everything else to the right
_shape_strides.shape[0] = ndarray->shape[index];
_shape_strides.strides[0] = ndarray->strides[index];
for(uint8_t i = 0; i < index; i++) {
_shape_strides.shape[0] = ndarray->shape[_shape_strides.axis];
_shape_strides.strides[0] = ndarray->strides[_shape_strides.axis];
for(uint8_t i = 0; i < _shape_strides.axis; i++) {
// entries to the right of index must be shifted by one position to the left
_shape_strides.shape[i + 1] = ndarray->shape[i];
_shape_strides.strides[i + 1] = ndarray->strides[i];
Expand All @@ -213,16 +220,37 @@ shape_strides tools_reduce_axes(ndarray_obj_t *ndarray, mp_obj_t axis) {
_shape_strides.increment = 1;
}

if(_shape_strides.ndim == 0) {
_shape_strides.ndim = 1;
_shape_strides.shape[ULAB_MAX_DIMS - 1] = 1;
_shape_strides.strides[ULAB_MAX_DIMS - 1] = ndarray->itemsize;
}

return _shape_strides;
}

int8_t tools_get_axis(mp_obj_t axis, uint8_t ndim) {
int8_t ax = mp_obj_get_int(axis);
if(ax < 0) ax += ndim;
if((ax < 0) || (ax > ndim - 1)) {
mp_raise_ValueError(MP_ERROR_TEXT("axis is out of bounds"));
mp_obj_t ulab_tools_restore_dims(ndarray_obj_t *ndarray, ndarray_obj_t *results, mp_obj_t keepdims, shape_strides _shape_strides) {
// restores the contracted dimension, if keepdims is True
if((ndarray->ndim == 1) && (keepdims != mp_const_true)) {
// since the original array has already been contracted and
// we don't want to keep the dimensions here, we have to return a scalar
return mp_binary_get_val_array(results->dtype, results->array, 0);
}
return ax;

if(keepdims == mp_const_true) {
results->ndim += 1;
for(int8_t i = 0; i < ULAB_MAX_DIMS; i++) {
results->shape[i] = ndarray->shape[i];
}
results->shape[_shape_strides.axis] = 1;

results->strides[ULAB_MAX_DIMS - 1] = ndarray->itemsize;
for(uint8_t i = ULAB_MAX_DIMS; i > 1; i--) {
results->strides[i - 2] = results->strides[i - 1] * results->shape[i - 1];
}
}

return MP_OBJ_FROM_PTR(results);
}

#if ULAB_MAX_DIMS > 1
Expand Down
2 changes: 2 additions & 0 deletions code/ulab_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

typedef struct _shape_strides_t {
uint8_t increment;
uint8_t axis;
uint8_t ndim;
size_t *shape;
int32_t *strides;
Expand All @@ -34,6 +35,7 @@ void *ndarray_set_float_function(uint8_t );

shape_strides tools_reduce_axes(ndarray_obj_t *, mp_obj_t );
int8_t tools_get_axis(mp_obj_t , uint8_t );
mp_obj_t ulab_tools_restore_dims(ndarray_obj_t * , ndarray_obj_t * , mp_obj_t , shape_strides );
ndarray_obj_t *tools_object_is_square(mp_obj_t );

uint8_t ulab_binary_get_size(uint8_t );
Expand Down
12 changes: 12 additions & 0 deletions docs/ulab-change-log.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
Mon, 30 Dec 2024

version 6.7.1

add keepdims keyword argument to numerical functions

Sun, 15 Dec 2024

version 6.7.0

add scipy.integrate module

Sun, 24 Nov 2024

version 6.6.1
Expand Down
23 changes: 23 additions & 0 deletions tests/2d/numpy/sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
try:
from ulab import numpy as np
except ImportError:
import numpy as np

for dtype in (np.uint8, np.int8, np.uint16, np.int8, np.float):
a = np.array(range(12), dtype=dtype)
b = a.reshape((3, 4))

print(a)
print(b)
print()

print(np.sum(a))
print(np.sum(a, axis=0))
print(np.sum(a, axis=0, keepdims=True))

print()
print(np.sum(b))
print(np.sum(b, axis=0))
print(np.sum(b, axis=1))
print(np.sum(b, axis=0, keepdims=True))
print(np.sum(b, axis=1, keepdims=True))
80 changes: 80 additions & 0 deletions tests/2d/numpy/sum.py.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
array([0, 1, 2, ..., 9, 10, 11], dtype=uint8)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=uint8)

66
66
array([66], dtype=uint8)

66
array([12, 15, 18, 21], dtype=uint8)
array([6, 22, 38], dtype=uint8)
array([[12, 15, 18, 21]], dtype=uint8)
array([[6],
[22],
[38]], dtype=uint8)
array([0, 1, 2, ..., 9, 10, 11], dtype=int8)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=int8)

66
66
array([66], dtype=int8)

66
array([12, 15, 18, 21], dtype=int8)
array([6, 22, 38], dtype=int8)
array([[12, 15, 18, 21]], dtype=int8)
array([[6],
[22],
[38]], dtype=int8)
array([0, 1, 2, ..., 9, 10, 11], dtype=uint16)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=uint16)

66
66
array([66], dtype=uint16)

66
array([12, 15, 18, 21], dtype=uint16)
array([6, 22, 38], dtype=uint16)
array([[12, 15, 18, 21]], dtype=uint16)
array([[6],
[22],
[38]], dtype=uint16)
array([0, 1, 2, ..., 9, 10, 11], dtype=int8)
array([[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11]], dtype=int8)

66
66
array([66], dtype=int8)

66
array([12, 15, 18, 21], dtype=int8)
array([6, 22, 38], dtype=int8)
array([[12, 15, 18, 21]], dtype=int8)
array([[6],
[22],
[38]], dtype=int8)
array([0.0, 1.0, 2.0, ..., 9.0, 10.0, 11.0], dtype=float64)
array([[0.0, 1.0, 2.0, 3.0],
[4.0, 5.0, 6.0, 7.0],
[8.0, 9.0, 10.0, 11.0]], dtype=float64)

66.0
66.0
array([66.0], dtype=float64)

66.0
array([12.0, 15.0, 18.0, 21.0], dtype=float64)
array([6.0, 22.0, 38.0], dtype=float64)
array([[12.0, 15.0, 18.0, 21.0]], dtype=float64)
array([[6.0],
[22.0],
[38.0]], dtype=float64)
Loading