From 4c181ca70c47daebe07190781f5b98a2e566b329 Mon Sep 17 00:00:00 2001 From: lkarthee Date: Wed, 24 Jan 2024 20:55:41 +0530 Subject: [PATCH] Fix float macro, add tests (#837) --- native/explorer/src/series.rs | 52 +++++++++++++++++++---------------- test/explorer/series_test.exs | 10 +++++++ 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 83cee9e5b..ca2ff8eef 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -12,7 +12,7 @@ use polars::export::arrow::array::Utf8Array; use polars::prelude::*; use polars_ops::chunked_array::cov::{cov, pearson_corr}; use polars_ops::prelude::peaks::*; -use rustler::{Binary, Encoder, Env, ListIterator, Term, TermType}; +use rustler::{Binary, Encoder, Env, Error, ListIterator, NifResult, Term, TermType}; use std::{result::Result, slice}; pub mod log; @@ -53,32 +53,38 @@ from_list!(s_from_list_str, String); macro_rules! from_list_float { ($name:ident, $type:ty, $module:ident) => { #[rustler::nif(schedule = "DirtyCpu")] - pub fn $name(name: &str, val: Term) -> ExSeries { + pub fn $name(name: &str, val: Term) -> NifResult { let nan = atoms::nan(); let infinity = atoms::infinity(); let neg_infinity = atoms::neg_infinity(); - ExSeries::new(Series::new( - name, - val.decode::() - .unwrap() - .map(|item| match item.get_type() { - TermType::Number => Some(item.decode::<$type>().unwrap()), - TermType::Atom => { - if nan.eq(&item) { - Some($module::NAN) - } else if infinity.eq(&item) { - Some($module::INFINITY) - } else if neg_infinity.eq(&item) { - Some($module::NEG_INFINITY) - } else { - None - } - } - term_type => panic!("from_list/2 not implemented for {term_type:?}"), - }) - .collect::>>(), - )) + let values: NifResult>> = val + .decode::()? + .map(|item| match item.get_type() { + TermType::Number => item.decode::>(), + TermType::Atom => Ok(if nan.eq(&item) { + Some($module::NAN) + } else if infinity.eq(&item) { + Some($module::INFINITY) + } else if neg_infinity.eq(&item) { + Some($module::NEG_INFINITY) + } else { + None + }), + term_type => { + let message = format!("from_list/2 not implemented for {term_type:?}"); + Err(Error::RaiseTerm(Box::new(message))) + } + }) + .collect::>>>(); + + match (values) { + Ok(x) => { + let s = Series::new(name, x); + Ok(ExSeries::new(s)) + } + Err(x) => Err(x), + } } }; } diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 4756f5848..890dcdb47 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -64,6 +64,16 @@ defmodule Explorer.SeriesTest do assert Series.dtype(s) == {:f, 64} end + test "float 32 overflow" do + assert_raise ArgumentError, + "argument error", + fn -> + Series.from_list([1_055_028_234_663_852_885_981_170_418_348_451_692_544.0], + dtype: {:f, 32} + ) + end + end + test "with nan" do s = Series.from_list([:nan, :nan, :nan]) assert Series.to_list(s) === [:nan, :nan, :nan]