Skip to content

Commit

Permalink
Fix float macro, add tests (#837)
Browse files Browse the repository at this point in the history
  • Loading branch information
lkarthee authored Jan 24, 2024
1 parent 74eaa17 commit 4c181ca
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 23 deletions.
52 changes: 29 additions & 23 deletions native/explorer/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use polars::export::arrow::array::Utf8Array;
use polars::prelude::*;
use polars_ops::chunked_array::cov::{cov, pearson_corr};
use polars_ops::prelude::peaks::*;
use rustler::{Binary, Encoder, Env, ListIterator, Term, TermType};
use rustler::{Binary, Encoder, Env, Error, ListIterator, NifResult, Term, TermType};
use std::{result::Result, slice};

pub mod log;
Expand Down Expand Up @@ -53,32 +53,38 @@ from_list!(s_from_list_str, String);
macro_rules! from_list_float {
($name:ident, $type:ty, $module:ident) => {
#[rustler::nif(schedule = "DirtyCpu")]
pub fn $name(name: &str, val: Term) -> ExSeries {
pub fn $name(name: &str, val: Term) -> NifResult<ExSeries> {
let nan = atoms::nan();
let infinity = atoms::infinity();
let neg_infinity = atoms::neg_infinity();

ExSeries::new(Series::new(
name,
val.decode::<ListIterator>()
.unwrap()
.map(|item| match item.get_type() {
TermType::Number => Some(item.decode::<$type>().unwrap()),
TermType::Atom => {
if nan.eq(&item) {
Some($module::NAN)
} else if infinity.eq(&item) {
Some($module::INFINITY)
} else if neg_infinity.eq(&item) {
Some($module::NEG_INFINITY)
} else {
None
}
}
term_type => panic!("from_list/2 not implemented for {term_type:?}"),
})
.collect::<Vec<Option<$type>>>(),
))
let values: NifResult<Vec<Option<$type>>> = val
.decode::<ListIterator>()?
.map(|item| match item.get_type() {
TermType::Number => item.decode::<Option<$type>>(),
TermType::Atom => Ok(if nan.eq(&item) {
Some($module::NAN)
} else if infinity.eq(&item) {
Some($module::INFINITY)
} else if neg_infinity.eq(&item) {
Some($module::NEG_INFINITY)
} else {
None
}),
term_type => {
let message = format!("from_list/2 not implemented for {term_type:?}");
Err(Error::RaiseTerm(Box::new(message)))
}
})
.collect::<NifResult<Vec<Option<$type>>>>();

match (values) {
Ok(x) => {
let s = Series::new(name, x);
Ok(ExSeries::new(s))
}
Err(x) => Err(x),
}
}
};
}
Expand Down
10 changes: 10 additions & 0 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ defmodule Explorer.SeriesTest do
assert Series.dtype(s) == {:f, 64}
end

test "float 32 overflow" do
assert_raise ArgumentError,
"argument error",
fn ->
Series.from_list([1_055_028_234_663_852_885_981_170_418_348_451_692_544.0],
dtype: {:f, 32}
)
end
end

test "with nan" do
s = Series.from_list([:nan, :nan, :nan])
assert Series.to_list(s) === [:nan, :nan, :nan]
Expand Down

0 comments on commit 4c181ca

Please sign in to comment.