Skip to content

Commit

Permalink
ENH: Add partition/rpartition ufunc for string dtypes
Browse files Browse the repository at this point in the history
Closes numpy#25993.
  • Loading branch information
lysnikolaou committed Mar 21, 2024
1 parent 7f1c8cb commit ff7c91d
Show file tree
Hide file tree
Showing 7 changed files with 465 additions and 45 deletions.
10 changes: 10 additions & 0 deletions numpy/_core/code_generators/generate_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,16 @@ def english_upper(s):
docstrings.get('numpy._core.umath._zfill'),
None,
),
'_partition':
Ufunc(2, 3, None,
docstrings.get('numpy._core.umath._partition'),
None,
),
'_rpartition':
Ufunc(2, 3, None,
docstrings.get('numpy._core.umath._rpartition'),
None,
),
}

def indent(st, spaces):
Expand Down
82 changes: 82 additions & 0 deletions numpy/_core/code_generators/ufunc_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5028,3 +5028,85 @@ def add_newdoc(place, name, doc):
array(['001', '-01', '+01'], dtype='<U3')
""")

add_newdoc('numpy._core.umath', '_partition',
"""
Partition each element in ``x1`` around ``x2``.
For each element in ``x1``, split the element at the first
occurrence of ``x2``, and return a 3-tuple containing the part
before the separator, a boolean signifying whether the separator
was found, and the part after the separator. If the separator is
not found, the first part will contain the whole string,
the boolean will be false, and the third part will be the empty
string.
Parameters
----------
x1 : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Input array
x2 : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Separator to split each string element in ``x1``.
Returns
-------
out : 3-tuple:
- ``StringDType``, ``bytes_`` or ``str_`` dtype string with the part
before the separator
- ``bool_`` dtype, whether the separator was found
- ``StringDType``, ``bytes_`` or ``str_`` dtype string with the part
after the separator
See Also
--------
str.partition
Examples
--------
>>> x = np.array(["Numpy is nice!"])
>>> np.strings.partition(x, " ")
array([['Numpy', ' ', 'is nice!']], dtype='<U8')
""")

add_newdoc('numpy._core.umath', '_rpartition',
"""
Partition (split) each element around the right-most separator.
For each element in ``x1``, split the element at the first
occurrence of ``x2``, and return a 3-tuple containing the part
before the separator, a boolean signifying whether the separator
was found, and the part after the separator. If the separator is
not found, the first part will contain the whole string,
the boolean will be false, and the third part will be the empty
string.
Parameters
----------
x1 : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Input array
x2 : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Separator to split each string element in ``x1``.
Returns
-------
out : 3-tuple:
- ``StringDType``, ``bytes_`` or ``str_`` dtype string with the part
before the separator
- ``bool_`` dtype, whether the separator was found
- ``StringDType``, ``bytes_`` or ``str_`` dtype string with the part
after the separator
See Also
--------
str.rpartition
Examples
--------
>>> a = np.array(['aAaAaA', ' aA ', 'abBABba'])
>>> np.strings.rpartition(a, 'A')
array([['aAaAa', 'A', ''],
[' a', 'A', ' '],
['abB', 'A', 'Bba']], dtype='<U5')
""")
60 changes: 60 additions & 0 deletions numpy/_core/src/umath/string_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -1593,4 +1593,64 @@ string_zfill(Buffer<enc> buf, npy_int64 width, Buffer<enc> out)
}


template <ENCODING enc>
static inline npy_bool
string_partition(Buffer<enc> buf1, Buffer<enc> buf2,
Buffer<enc> out1, Buffer<enc> out2,
npy_intp *final_len1, npy_intp *final_len2,
STARTPOSITION pos)
{
size_t len1 = buf1.num_codepoints();
size_t len2 = buf2.num_codepoints();

if (len2 == 0) {
npy_gil_error(PyExc_ValueError, "empty separator");
*final_len1 = *final_len2 = -1;
return false;
}

if (len1 < len2) {
buf1.buffer_memcpy(out1, len1);
*final_len1 = len1;
*final_len2 = 0;
return false;
}

npy_intp idx;
switch(enc) {
case ENCODING::UTF8:
assert(0); // TODO
break;
case ENCODING::ASCII:
idx = fastsearch(buf1.buf, len1, buf2.buf, len2, -1,
pos == STARTPOSITION::FRONT ? FAST_SEARCH : FAST_RSEARCH);
break;
case ENCODING::UTF32:
idx = fastsearch((npy_ucs4 *)buf1.buf, len1, (npy_ucs4 *)buf2.buf, len2, -1,
pos == STARTPOSITION::FRONT ? FAST_SEARCH : FAST_RSEARCH);
break;
}

if (idx < 0) {
if (pos == STARTPOSITION::FRONT) {
buf1.buffer_memcpy(out1, len1);
*final_len1 = len1;
*final_len2 = 0;
}
else {
buf1.buffer_memcpy(out2, len1);
*final_len1 = 0;
*final_len2 = len1;
}
return false;
}

buf1.buffer_memcpy(out1, idx);
*final_len1 = idx;
(buf1 + idx + len2).buffer_memcpy(out2, len1 - idx - len2);
*final_len2 = len1 - idx - len2;
return true;
}


#endif /* _NPY_CORE_SRC_UMATH_STRING_BUFFER_H_ */
150 changes: 150 additions & 0 deletions numpy/_core/src/umath/string_ufuncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,52 @@ string_zfill_loop(PyArrayMethod_Context *context,
}


template <ENCODING enc>
static int
string_partition_loop(PyArrayMethod_Context *context,
char *const data[], npy_intp const dimensions[],
npy_intp const strides[], NpyAuxData *NPY_UNUSED(auxdata))
{
STARTPOSITION startposition = *(STARTPOSITION *)(context->method->static_data);
int elsize1 = context->descriptors[0]->elsize;
int elsize2 = context->descriptors[1]->elsize;
int outsize1 = context->descriptors[2]->elsize;
int outsize2 = context->descriptors[4]->elsize;

char *in1 = data[0];
char *in2 = data[1];
char *out1 = data[2];
char *out2 = data[3];
char *out3 = data[4];

npy_intp N = dimensions[0];

while (N--) {
Buffer<enc> buf1(in1, elsize1);
Buffer<enc> buf2(in2, elsize2);
Buffer<enc> outbuf1(out1, outsize1);
Buffer<enc> outbuf2(out3, outsize2);

npy_intp final_len1, final_len2;
*(npy_bool *) out2 = string_partition(buf1, buf2, outbuf1, outbuf2,
&final_len1, &final_len2, startposition);
if (final_len1 < 0 || final_len2 < 0) {
return -1;
}
outbuf1.buffer_fill_with_zeros_after_index(final_len1);
outbuf2.buffer_fill_with_zeros_after_index(final_len2);

in1 += strides[0];
in2 += strides[1];
out1 += strides[2];
out2 += strides[3];
out3 += strides[4];
}

return 0;
}


/* Resolve descriptors & promoter functions */

static NPY_CASTING
Expand Down Expand Up @@ -947,6 +993,82 @@ string_zfill_resolve_descriptors(
}


static int
string_partition_promoter(PyObject *NPY_UNUSED(ufunc),
PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *new_op_dtypes[])
{
Py_INCREF(op_dtypes[0]);
new_op_dtypes[0] = op_dtypes[0];
Py_INCREF(op_dtypes[1]);
new_op_dtypes[1] = op_dtypes[1];


Py_INCREF(op_dtypes[0]);
new_op_dtypes[2] = op_dtypes[0];
new_op_dtypes[3] = NPY_DT_NewRef(&PyArray_BoolDType);
Py_INCREF(op_dtypes[4]);
new_op_dtypes[4] = op_dtypes[0];
return 0;
}


static NPY_CASTING
string_partition_resolve_descriptors(
PyArrayMethodObject *NPY_UNUSED(self),
PyArray_DTypeMeta *NPY_UNUSED(dtypes[3]),
PyArray_Descr *given_descrs[3],
PyArray_Descr *loop_descrs[3],
npy_intp *NPY_UNUSED(view_offset))
{
if (given_descrs[2] == NULL) {
PyErr_SetString(
PyExc_TypeError,
"The 'out' kwarg is necessary. Use numpy.strings without it.");
return _NPY_ERROR_OCCURRED_IN_CAST;
}

if (given_descrs[4] == NULL) {
PyErr_SetString(
PyExc_TypeError,
"The 'out' kwarg is necessary. Use numpy.strings without it.");
return _NPY_ERROR_OCCURRED_IN_CAST;
}

loop_descrs[0] = NPY_DT_CALL_ensure_canonical(given_descrs[0]);
if (loop_descrs[0] == NULL) {
return _NPY_ERROR_OCCURRED_IN_CAST;
}

loop_descrs[1] = NPY_DT_CALL_ensure_canonical(given_descrs[1]);
if (loop_descrs[1] == NULL) {
return _NPY_ERROR_OCCURRED_IN_CAST;
}

loop_descrs[2] = NPY_DT_CALL_ensure_canonical(given_descrs[2]);
if (loop_descrs[2] == NULL) {
return _NPY_ERROR_OCCURRED_IN_CAST;
}

if (loop_descrs[3] == NULL) {
loop_descrs[3] = PyArray_DescrFromType(NPY_BOOL);
}
else {
loop_descrs[3] = NPY_DT_CALL_ensure_canonical(given_descrs[3]);
}
if (loop_descrs[3] == NULL) {
return _NPY_ERROR_OCCURRED_IN_CAST;
}

loop_descrs[4] = NPY_DT_CALL_ensure_canonical(given_descrs[4]);
if (loop_descrs[4] == NULL) {
return _NPY_ERROR_OCCURRED_IN_CAST;
}

return NPY_NO_CASTING;
}


/*
* Machinery to add the string loops to the existing ufuncs.
*/
Expand Down Expand Up @@ -1599,6 +1721,34 @@ init_string_ufuncs(PyObject *umath)
return -1;
}

dtypes[0] = dtypes[1] = dtypes[2] = dtypes[4] = NPY_OBJECT;
dtypes[3] = NPY_BOOL;

const char *partition_names[] = {"_partition", "_rpartition"};

static STARTPOSITION partition_startpositions[] = {
STARTPOSITION::FRONT, STARTPOSITION::BACK
};

for (int i = 0; i < 2; i++) {
if (init_ufunc(
umath, partition_names[i], 2, 3, dtypes, ENCODING::ASCII,
string_partition_loop<ENCODING::ASCII>,
string_partition_resolve_descriptors, &partition_startpositions[i]) < 0) {
return -1;
}
if (init_ufunc(
umath, partition_names[i], 2, 3, dtypes, ENCODING::UTF32,
string_partition_loop<ENCODING::UTF32>,
string_partition_resolve_descriptors, &partition_startpositions[i]) < 0) {
return -1;
}
if (init_promoter(umath, partition_names[i], 2, 3,
string_partition_promoter) < 0) {
return -1;
}
}

return 0;
}

Expand Down
Loading

0 comments on commit ff7c91d

Please sign in to comment.