Skip to content

Commit

Permalink
Merge pull request #7 from espdev/rustfmt
Browse files Browse the repository at this point in the history
Apply rustfmt
  • Loading branch information
espdev authored Oct 12, 2024
2 parents aa061c0 + ff7d5ec commit 614df20
Show file tree
Hide file tree
Showing 20 changed files with 497 additions and 550 deletions.
3 changes: 1 addition & 2 deletions src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use thiserror::Error;
use ndarray::ShapeError;

use thiserror::Error;

/// Enum provides error types
#[derive(Error, Debug)]
Expand Down
22 changes: 10 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,43 +120,41 @@
//!
mod errors;
mod traits;
mod ndarrayext;
mod ndg;
mod sprsext;
mod validate;
mod util;
mod traits;
mod umv;
mod ndg;
mod util;
mod validate;

use std::result;

/// Provides result type for `make` and `evaluate` methods
pub type Result<T> = result::Result<T, errors::CsapsError>;

pub use errors::CsapsError;
pub use ndg::{GridCubicSmoothingSpline, NdGridSpline};
pub use traits::{Real, RealRef};
pub use umv::{NdSpline, CubicSmoothingSpline};
pub use ndg::{NdGridSpline, GridCubicSmoothingSpline};

pub use umv::{CubicSmoothingSpline, NdSpline};

// #[cfg(test)]
// mod tests {
// use crate::CubicSmoothingSpline;
// use crate::CubicSmoothingSpline;
// use ndarray::prelude::*;

// #[test]
// fn test_new() {
// fn test_new() {

// let zeros = Array1::<f64>::zeros(1);
// let zeros = Array1::<f64>::zeros(1);

// let x = zeros.view();
// let zeros = Array2::<f64>::zeros((1,1));
// let y = zeros.view();


// let sp = CubicSmoothingSpline::new(x.view(), y.view())
// // .with_optional_weights(weights)
// // .with_optional_smooth(s)
// .make();
// }
// }
// }
189 changes: 103 additions & 86 deletions src/ndarrayext.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
use ndarray::{prelude::*, IntoDimension, Slice};
use itertools::Itertools;

use ndarray::{prelude::*, IntoDimension, Slice};

use crate::{
Result,
util::dim_from_vec,
CsapsError::{ReshapeFrom2d, ReshapeTo2d},
util::dim_from_vec, Real
Real, Result,
};


pub fn diff<'a, T: 'a, D, V>(data: V, axis: Option<Axis>) -> Array<T, D>
where
T: Real<T>,
D: Dimension,
V: AsArray<'a, T, D>
V: AsArray<'a, T, D>,
{
let data_view = data.into();
let axis = axis.unwrap_or_else(|| Axis(data_view.ndim() - 1));
Expand All @@ -24,11 +22,10 @@ where
&tail - &head
}


pub fn to_2d<'a, T: 'a, D, I>(data: I, axis: Axis) -> Result<ArrayView2<'a, T>>
where
D: Dimension,
I: AsArray<'a, T, D>,
where
D: Dimension,
I: AsArray<'a, T, D>,
{
let data_view = data.into();
let ndim = data_view.ndim();
Expand All @@ -48,45 +45,43 @@ pub fn to_2d<'a, T: 'a, D, I>(data: I, axis: Axis) -> Result<ArrayView2<'a, T>>

match data_view.permuted_axes(axes).into_shape(new_shape) {
Ok(view_2d) => Ok(view_2d),
Err(error) => Err(
ReshapeTo2d {
input_shape: shape,
output_shape: new_shape.to_vec(),
axis: axis.0,
source: error,
}
)
Err(error) => Err(ReshapeTo2d {
input_shape: shape,
output_shape: new_shape.to_vec(),
axis: axis.0,
source: error,
}),
}
}


pub fn to_2d_simple<'a, T: 'a, D>(data: ArrayView<'a, T, D>) -> Result<ArrayView2<'a, T>>
where
D: Dimension
where
D: Dimension,
{
let ndim = data.ndim();
let shape = data.shape().to_vec();
let new_shape = [shape[0..(ndim - 1)].iter().product(), shape[ndim - 1]];

match data.into_shape(new_shape) {
Ok(data_2d) => Ok(data_2d),
Err(error) => Err(
ReshapeTo2d {
input_shape: shape,
output_shape: new_shape.to_vec(),
axis: ndim - 1,
source: error,
}
)
Err(error) => Err(ReshapeTo2d {
input_shape: shape,
output_shape: new_shape.to_vec(),
axis: ndim - 1,
source: error,
}),
}
}


pub fn from_2d<'a, T: 'a, D, S, I>(data: I, shape: S, axis: Axis) -> Result<ArrayView<'a, T, S::Dim>>
where
D: Dimension,
S: IntoDimension<Dim = D>,
I: AsArray<'a, T, Ix2>,
pub fn from_2d<'a, T: 'a, D, S, I>(
data: I,
shape: S,
axis: Axis,
) -> Result<ArrayView<'a, T, S::Dim>>
where
D: Dimension,
S: IntoDimension<Dim = D>,
I: AsArray<'a, T, Ix2>,
{
let shape = shape.into_dimension();
let ndim = shape.ndim();
Expand All @@ -106,39 +101,37 @@ pub fn from_2d<'a, T: 'a, D, S, I>(data: I, shape: S, axis: Axis) -> Result<Arra

let axes: D = dim_from_vec(ndim, axes_tmp);
Ok(view_nd.permuted_axes(axes))
},
Err(error) => Err(
ReshapeFrom2d {
input_shape: data_view.shape().to_vec(),
output_shape: new_shape_vec,
axis: axis.0,
source: error,
}
)
}
Err(error) => Err(ReshapeFrom2d {
input_shape: data_view.shape().to_vec(),
output_shape: new_shape_vec,
axis: axis.0,
source: error,
}),
}
}


/// Returns the indices of the bins to which each value in input array belongs
///
/// This code works if `bins` is increasing
pub fn digitize<'a, T: 'a, A, B>(arr: A, bins: B) -> Array1<usize>
where
T: Real<T>,
// T: Clone + NdFloat + AlmostEqual,

A: AsArray<'a, T, Ix1>,
B: AsArray<'a, T, Ix1>,
where
T: Real<T>,
// T: Clone + NdFloat + AlmostEqual,
A: AsArray<'a, T, Ix1>,
B: AsArray<'a, T, Ix1>,
{
let arr_view = arr.into();
let bins_view = bins.into();

let mut indices = Array1::zeros((arr_view.len(),));
let mut kstart: usize = 0;

for (i, &a) in arr_view.iter().enumerate()
.sorted_by(|e1, e2| e1.1.partial_cmp(e2.1).unwrap()) {

for (i, &a) in arr_view
.iter()
.enumerate()
.sorted_by(|e1, e2| e1.1.partial_cmp(e2.1).unwrap())
{
let mut k = kstart;

for bins_win in bins_view.slice(s![kstart..]).windows(2) {
Expand All @@ -158,53 +151,55 @@ pub fn digitize<'a, T: 'a, A, B>(arr: A, bins: B) -> Array1<usize>
indices
}


#[cfg(test)]
mod tests {
use std::f64;
use ndarray::{array, Array1, Axis, Ix1, Ix2, Ix3};
use crate::ndarrayext::*;
use ndarray::{array, Array1, Axis, Ix1, Ix2, Ix3};
use std::f64;

#[test]
fn test_diff_1d() {
let a = array![1., 2., 3., 4., 5.];

assert_eq!(diff(&a, None),
array![1., 1., 1., 1.]);
assert_eq!(diff(&a, None), array![1., 1., 1., 1.]);

assert_eq!(diff(&a, Some(Axis(0))),
array![1., 1., 1., 1.]);
assert_eq!(diff(&a, Some(Axis(0))), array![1., 1., 1., 1.]);
}

#[test]
fn test_diff_2d() {
let a = array![[1., 2., 3., 4.], [1., 2., 3., 4.]];

assert_eq!(diff(&a, None),
array![[1., 1., 1.], [1., 1., 1.]]);
assert_eq!(diff(&a, None), array![[1., 1., 1.], [1., 1., 1.]]);

assert_eq!(diff(&a, Some(Axis(0))),
array![[0., 0., 0., 0.]]);
assert_eq!(diff(&a, Some(Axis(0))), array![[0., 0., 0., 0.]]);

assert_eq!(diff(&a, Some(Axis(1))),
array![[1., 1., 1.], [1., 1., 1.]]);
assert_eq!(diff(&a, Some(Axis(1))), array![[1., 1., 1.], [1., 1., 1.]]);
}

#[test]
fn test_diff_3d() {
let a = array![[[1., 2., 3.], [1., 2., 3.]], [[1., 2., 3.], [1., 2., 3.]]];

assert_eq!(diff(&a, None),
array![[[1., 1.], [1., 1.]], [[1., 1.], [1., 1.]]]);

assert_eq!(diff(&a, Some(Axis(0))),
array![[[0., 0., 0.], [0., 0., 0.]]]);

assert_eq!(diff(&a, Some(Axis(1))),
array![[[0., 0., 0.]], [[0., 0., 0.]]]);

assert_eq!(diff(&a, Some(Axis(2))),
array![[[1., 1.], [1., 1.]], [[1., 1.], [1., 1.]]]);
assert_eq!(
diff(&a, None),
array![[[1., 1.], [1., 1.]], [[1., 1.], [1., 1.]]]
);

assert_eq!(
diff(&a, Some(Axis(0))),
array![[[0., 0., 0.], [0., 0., 0.]]]
);

assert_eq!(
diff(&a, Some(Axis(1))),
array![[[0., 0., 0.]], [[0., 0., 0.]]]
);

assert_eq!(
diff(&a, Some(Axis(2))),
array![[[1., 1.], [1., 1.]], [[1., 1.], [1., 1.]]]
);
}

#[test]
Expand All @@ -218,8 +213,14 @@ mod tests {
fn test_to_2d_from_2d() {
let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];

assert_eq!(to_2d(&a, Axis(0)).unwrap(), array![[1, 5], [2, 6], [3, 7], [4, 8]]);
assert_eq!(to_2d(&a, Axis(1)).unwrap(), array![[1, 2, 3, 4], [5, 6, 7, 8]]);
assert_eq!(
to_2d(&a, Axis(0)).unwrap(),
array![[1, 5], [2, 6], [3, 7], [4, 8]]
);
assert_eq!(
to_2d(&a, Axis(1)).unwrap(),
array![[1, 2, 3, 4], [5, 6, 7, 8]]
);
}

#[test]
Expand All @@ -229,7 +230,10 @@ mod tests {
// FIXME: incompatible memory layout
// assert_eq!(to_2d(&a, Axis(0)).unwrap(), array![[1, 7], [2, 8], [3, 9], [4, 10], [5, 11], [6, 12]]);
// assert_eq!(to_2d(&a, Axis(1)).unwrap(), array![[1, 4], [2, 5], [3, 6], [7, 10], [8, 11], [9, 12]]);
assert_eq!(to_2d(&a, Axis(2)).unwrap(), array![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);
assert_eq!(
to_2d(&a, Axis(2)).unwrap(),
array![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
);
}

#[test]
Expand All @@ -241,13 +245,19 @@ mod tests {
#[test]
fn test_to_2d_simple_from_2d() {
let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
assert_eq!(to_2d_simple(a.view()).unwrap(), array![[1, 2, 3, 4], [5, 6, 7, 8]]);
assert_eq!(
to_2d_simple(a.view()).unwrap(),
array![[1, 2, 3, 4], [5, 6, 7, 8]]
);
}

#[test]
fn test_to_2d_simple_from_3d() {
let a = array![[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]];
assert_eq!(to_2d_simple(a.view()).unwrap(), array![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);
assert_eq!(
to_2d_simple(a.view()).unwrap(),
array![[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
);
}

#[test]
Expand All @@ -258,7 +268,8 @@ mod tests {

let r = from_2d(&a, s, Axis(2))
.unwrap()
.into_dimensionality::<Ix3>().unwrap();
.into_dimensionality::<Ix3>()
.unwrap();

assert_eq!(r, e);
}
Expand Down Expand Up @@ -369,7 +380,10 @@ mod tests {

let indices = digitize(&xi, &edges);

assert_eq!(indices, array![1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3])
assert_eq!(
indices,
array![1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]
)
}

#[test]
Expand All @@ -389,6 +403,9 @@ mod tests {

let indices = digitize(&xi, &edges);

assert_eq!(indices, array![0, 1, 0, 2, 2, 1, 0, 3, 4, 4, 3, 3, 2, 2, 1, 0])
assert_eq!(
indices,
array![0, 1, 0, 2, 2, 1, 0, 3, 4, 4, 3, 3, 2, 2, 1, 0]
)
}
}
Loading

0 comments on commit 614df20

Please sign in to comment.