Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Improve DataFrame.sort().limit/top_k performance #19731

Merged
merged 4 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ pub fn _arg_bottom_k(
_broadcast_bools(by_column.len(), &mut sort_options.descending);
_broadcast_bools(by_column.len(), &mut sort_options.nulls_last);

// Don't go into row encoding.
if by_column.len() == 1 && sort_options.limit.is_some() && !sort_options.maintain_order {
return Ok(NoNull::new(by_column[0].arg_sort((&*sort_options).into())));
}

let encoded = _get_rows_encoded(
by_column,
&sort_options.descending,
Expand Down
64 changes: 58 additions & 6 deletions crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub(super) fn arg_sort<I, J, T>(
iters: I,
options: SortOptions,
null_count: usize,
len: usize,
mut len: usize,
) -> IdxCa
where
I: IntoIterator<Item = J>,
Expand Down Expand Up @@ -49,14 +49,46 @@ where
vals.extend(iter);
}

sort_impl(vals.as_mut_slice(), options);
let vals = if let Some((limit, desc)) = options.limit {
let limit = limit as usize;
// Overwrite output len.
len = limit;
let out = if limit >= vals.len() {
vals.as_mut_slice()
} else if desc {
let (lower, _el, _upper) = vals
.as_mut_slice()
.select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1));
lower
} else {
let (lower, _el, _upper) = vals
.as_mut_slice()
.select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1));
lower
};

sort_impl(out, options);
out
} else {
sort_impl(vals.as_mut_slice(), options);
vals.as_slice()
};

let iter = vals.into_iter().map(|(idx, _v)| idx);
let iter = vals.iter().map(|(idx, _v)| idx).copied();
let idx = if nulls_last {
let mut idx = Vec::with_capacity(len);
idx.extend(iter);
idx.extend(nulls_idx);

let nulls_idx = if options.limit.is_some() {
&nulls_idx[..len - idx.len()]
} else {
&nulls_idx
};
idx.extend_from_slice(nulls_idx);
idx
} else if options.limit.is_some() {
nulls_idx.extend(iter.take(len - nulls_idx.len()));
nulls_idx
} else {
let ptr = nulls_idx.as_ptr() as usize;
nulls_idx.extend(iter);
Expand Down Expand Up @@ -90,9 +122,29 @@ where
}));
}

sort_impl(vals.as_mut_slice(), options);
let vals = if let Some((limit, desc)) = options.limit {
let limit = limit as usize;
let out = if limit >= vals.len() {
vals.as_mut_slice()
} else if desc {
let (lower, _el, _upper) = vals
.as_mut_slice()
.select_nth_unstable_by(limit, |a, b| b.1.tot_cmp(&a.1));
lower
} else {
let (lower, _el, _upper) = vals
.as_mut_slice()
.select_nth_unstable_by(limit, |a, b| a.1.tot_cmp(&b.1));
lower
};
sort_impl(out, options);
out
} else {
sort_impl(vals.as_mut_slice(), options);
vals.as_slice()
};

let iter = vals.into_iter().map(|(idx, _v)| idx);
let iter = vals.iter().map(|(idx, _v)| idx).copied();
let idx: Vec<_> = iter.collect_trusted();

ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ impl CategoricalChunked {
descending,
multithreaded: true,
maintain_order: false,
limit: None,
})
}

Expand Down
10 changes: 10 additions & 0 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ impl ChunkSort<StringType> for StringChunked {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
})
}

Expand Down Expand Up @@ -406,6 +407,7 @@ impl ChunkSort<BinaryType> for BinaryChunked {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
})
}

Expand Down Expand Up @@ -536,6 +538,7 @@ impl ChunkSort<BinaryOffsetType> for BinaryOffsetChunked {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
})
}

Expand Down Expand Up @@ -672,6 +675,7 @@ impl ChunkSort<BooleanType> for BooleanChunked {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
})
}

Expand Down Expand Up @@ -797,6 +801,7 @@ mod test {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
});
assert_eq!(
Vec::from(&out),
Expand All @@ -816,6 +821,7 @@ mod test {
nulls_last: true,
multithreaded: true,
maintain_order: false,
limit: None,
});
assert_eq!(
Vec::from(&out),
Expand Down Expand Up @@ -925,6 +931,7 @@ mod test {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
});
let expected = &[None, None, Some("a"), Some("b"), Some("c")];
assert_eq!(Vec::from(&out), expected);
Expand All @@ -934,6 +941,7 @@ mod test {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
});

let expected = &[None, None, Some("c"), Some("b"), Some("a")];
Expand All @@ -944,6 +952,7 @@ mod test {
nulls_last: true,
multithreaded: true,
maintain_order: false,
limit: None,
});
let expected = &[Some("a"), Some("b"), Some("c"), None, None];
assert_eq!(Vec::from(&out), expected);
Expand All @@ -953,6 +962,7 @@ mod test {
nulls_last: true,
multithreaded: true,
maintain_order: false,
limit: None,
});
let expected = &[Some("c"), Some("b"), Some("a"), None, None];
assert_eq!(Vec::from(&out), expected);
Expand Down
12 changes: 12 additions & 0 deletions crates/polars-core/src/chunked_array/ops/sort/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ pub struct SortOptions {
/// If true maintain the order of equal elements.
/// Default `false`.
pub maintain_order: bool,
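/// Limit the sort output. This is for optimization purposes only and might be ignored.
/// The tuple holds, in order:
/// - Len: the number of rows requested from the start of the sorted output.
/// - Descending: whether the limited selection is ordered descending.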
pub limit: Option<(IdxSize, bool)>,
}

/// Sort options for multi-series sorting.
Expand Down Expand Up @@ -96,6 +100,10 @@ pub struct SortMultipleOptions {
pub multithreaded: bool,
/// Whether maintain the order of equal elements. Default `false`.
pub maintain_order: bool,
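/// Limit the sort output. This is for optimization purposes only and might be ignored.
/// The tuple holds, in order:
/// - Len: the number of rows requested from the start of the sorted output.
/// - Descending: whether the limited selection is ordered descending.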
pub limit: Option<(IdxSize, bool)>,
}

impl Default for SortOptions {
Expand All @@ -105,6 +113,7 @@ impl Default for SortOptions {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
}
}
}
Expand All @@ -116,6 +125,7 @@ impl Default for SortMultipleOptions {
nulls_last: vec![false],
multithreaded: true,
maintain_order: false,
limit: None,
}
}
}
Expand Down Expand Up @@ -224,6 +234,7 @@ impl From<&SortOptions> for SortMultipleOptions {
nulls_last: vec![value.nulls_last],
multithreaded: value.multithreaded,
maintain_order: value.maintain_order,
limit: value.limit,
}
}
}
Expand All @@ -235,6 +246,7 @@ impl From<&SortMultipleOptions> for SortOptions {
nulls_last: value.nulls_last.first().copied().unwrap_or(false),
multithreaded: value.multithreaded,
maintain_order: value.maintain_order,
limit: value.limit,
}
}
}
7 changes: 7 additions & 0 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,12 @@ impl DataFrame {
return Ok(out);
}
if let Some((0, k)) = slice {
let desc = if sort_options.descending.len() == 1 {
sort_options.descending[0]
} else {
false
};
sort_options.limit = Some((k as IdxSize, desc));
return self.bottom_k_impl(k, by_column, sort_options);
}

Expand All @@ -2012,6 +2018,7 @@ impl DataFrame {
nulls_last: sort_options.nulls_last[0],
multithreaded: sort_options.multithreaded,
maintain_order: sort_options.maintain_order,
limit: sort_options.limit,
};
// fast path for a frame with a single series
// no need to compute the sort indices and then take by these indices
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-expr/src/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ fn sort_by_groups_multiple_by(
nulls_last: nulls_last.to_owned(),
multithreaded,
maintain_order,
limit: None,
};

let sorted_idx = groups[0]
Expand All @@ -180,6 +181,7 @@ fn sort_by_groups_multiple_by(
nulls_last: nulls_last.to_owned(),
multithreaded,
maintain_order,
limit: None,
};
let sorted_idx = groups[0]
.as_materialized_series()
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-lazy/src/tests/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ fn take_aggregations() -> PolarsResult<()> {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
})
.head(Some(2)),
)
Expand Down Expand Up @@ -489,6 +490,7 @@ fn test_take_consistency() -> PolarsResult<()> {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
})
.get(lit(0))])
.collect()?;
Expand All @@ -507,6 +509,7 @@ fn test_take_consistency() -> PolarsResult<()> {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
})
.get(lit(0))])
.collect()?;
Expand All @@ -526,6 +529,7 @@ fn test_take_consistency() -> PolarsResult<()> {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
})
.get(lit(0))
.alias("1"),
Expand All @@ -537,6 +541,7 @@ fn test_take_consistency() -> PolarsResult<()> {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
})
.get(lit(0)),
)
Expand Down
1 change: 1 addition & 0 deletions crates/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1666,6 +1666,7 @@ fn test_single_group_result() -> PolarsResult<()> {
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
})
.over([col("a")])])
.collect()?;
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/src/chunked_array/top_k.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ fn top_k_by_impl(
nulls_last: vec![true; by.len()],
multithreaded,
maintain_order: false,
limit: None,
};

let idx = _arg_bottom_k(k, by, &mut sort_options)?;
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-ops/src/frame/join/hash_join/sort_merge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ pub(crate) fn _sort_or_hash_inner(
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
});
let s_right = unsafe { s_right.take_unchecked(&sort_idx) };
let ids = par_sorted_merge_inner_no_nulls(s_left, &s_right);
Expand Down Expand Up @@ -252,6 +253,7 @@ pub(crate) fn _sort_or_hash_inner(
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
});
let s_left = unsafe { s_left.take_unchecked(&sort_idx) };
let ids = par_sorted_merge_inner_no_nulls(&s_left, s_right);
Expand Down Expand Up @@ -323,6 +325,7 @@ pub(crate) fn sort_or_hash_left(
nulls_last: false,
multithreaded: true,
maintain_order: false,
limit: None,
});
let s_right = unsafe { s_right.take_unchecked(&sort_idx) };

Expand Down
2 changes: 2 additions & 0 deletions crates/polars-pipe/src/executors/sinks/sort/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ impl SortSource {
nulls_last: self.nulls_last,
multithreaded: true,
maintain_order: false,
limit: None,
},
),
Some((offset, len)) => {
Expand All @@ -119,6 +120,7 @@ impl SortSource {
nulls_last: self.nulls_last,
multithreaded: true,
maintain_order: false,
limit: None,
},
);
*len = len.saturating_sub(df_len);
Expand Down
Loading
Loading