Skip to content

Commit

Permalink
ergonomics: more readible math code through Mul operator and Borrow t…
Browse files Browse the repository at this point in the history
…rait (#46)

- add Mul operator for group multiplication (on vales and &)
- make use of Borrow trait for generic types (which are not Copy in
general)

**examples:**

```
let true_cam1_from_cam0 =
     true_world_from_cameras[1].inverse() * true_world_from_cameras[0];
 ```
 
vs
 
 ```
let true_cam2_from_cam0 = 
    true_world_from_cameras[2].inverse().group_mul(&true_world_from_cameras[0]);
```

and

```
res_fn(
    Isometry3::<DualScalar, 1>::exp(x) * isometry.to_dual_c()
    Isometry3::from_params(DualVector::from_real_vector(isometry_prior.0.params())),
)
````

vs.

```
res_fn(
    Isometry3::<DualScalar, 1>::exp(&x).group_mul(&isometry.to_dual_c()),
    Isometry3::from_params(&DualVector::from_real_vector(*isometry_prior.0.params())),
)
````
  • Loading branch information
strasdat authored Dec 15, 2024
1 parent 11b2883 commit 82a26b6
Show file tree
Hide file tree
Showing 60 changed files with 1,700 additions and 882 deletions.
4 changes: 2 additions & 2 deletions crates/sophus/src/examples/viewer_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub fn make_distorted_frame() -> ImageFrame {
let cy = 240.0;

let unified_cam = DynCameraF64::new_unified(
&VecF64::from_array([focal_length, focal_length, cx, cy, 0.629, 1.22]),
VecF64::from_array([focal_length, focal_length, cx, cy, 0.629, 1.22]),
image_size,
);

Expand All @@ -47,7 +47,7 @@ pub fn make_distorted_frame() -> ImageFrame {
for v in 0..image_size.height {
for u in 0..image_size.width {
let uv = VecF64::<2>::new(u as f64, v as f64);
let p_on_z1 = unified_cam.cam_unproj(&uv);
let p_on_z1 = unified_cam.cam_unproj(uv);

if p_on_z1[0].abs() < 0.5 {
*img.mut_pixel(u, v) = SVec::<u8, 4>::new(255, 0, 0, 255);
Expand Down
73 changes: 50 additions & 23 deletions crates/sophus_core/src/calculus/dual/dual_batch_matrix.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::calculus::dual::dual_batch_scalar::DualBatchScalar;
use crate::calculus::dual::dual_matrix::DijPairM;
use crate::calculus::dual::dual_matrix::DijPairMV;
use crate::calculus::dual::DualBatchVector;
use crate::linalg::BatchMatF64;
Expand All @@ -10,6 +11,7 @@ use crate::tensor::mut_tensor::MutTensorDDR;
use crate::tensor::mut_tensor::MutTensorDDRC;
use approx::AbsDiffEq;
use approx::RelativeEq;
use core::borrow::Borrow;
use core::fmt::Debug;
use core::ops::Add;
use core::ops::Mul;
Expand All @@ -20,8 +22,6 @@ use core::simd::Mask;
use core::simd::SupportedLaneCount;
use num_traits::Zero;

use crate::calculus::dual::dual_matrix::DijPairM;

/// DualScalarLike matrix
#[derive(Clone)]
pub struct DualBatchMatrix<const ROWS: usize, const COLS: usize, const BATCH: usize>
Expand Down Expand Up @@ -307,40 +307,53 @@ where
}
}

fn from_scalar(val: DualBatchScalar<BATCH>) -> Self {
fn from_scalar<S>(val: S) -> Self
where
S: Borrow<DualBatchScalar<BATCH>>,
{
let val = val.borrow();

DualBatchMatrix {
real_part: BatchMatF64::<ROWS, COLS, BATCH>::from_scalar(val.real_part),
dij_part: val.dij_part.map(|dij_val| {
dij_part: val.dij_part.clone().map(|dij_val| {
MutTensorDDRC::from_map(&dij_val.view(), |v| {
BatchMatF64::<ROWS, COLS, BATCH>::from_scalar(*v)
})
}),
}
}

fn mat_mul<const COLS2: usize>(
&self,
rhs: DualBatchMatrix<COLS, COLS2, BATCH>,
) -> DualBatchMatrix<ROWS, COLS2, BATCH> {
fn mat_mul<const COLS2: usize, M>(&self, rhs: M) -> DualBatchMatrix<ROWS, COLS2, BATCH>
where
M: Borrow<DualBatchMatrix<COLS, COLS2, BATCH>>,
{
DualBatchMatrix {
real_part: self.real_part * rhs.real_part,
real_part: self.real_part.mat_mul(rhs.borrow().real_part),
dij_part: DualBatchMatrix::<ROWS, COLS2, BATCH>::binary_mm_dij(
&self.dij_part,
&rhs.dij_part,
|l_dij| l_dij * rhs.real_part,
|r_dij| self.real_part * r_dij,
&rhs.borrow().dij_part,
|l_dij| l_dij.mat_mul(rhs.borrow().real_part),
|r_dij| self.real_part.mat_mul(r_dij),
),
}
}

fn from_real_matrix(val: BatchMatF64<ROWS, COLS, BATCH>) -> Self {
fn from_real_matrix<A>(val: A) -> Self
where
A: Borrow<BatchMatF64<ROWS, COLS, BATCH>>,
{
let val = val.borrow();
Self {
real_part: val,
real_part: *val,
dij_part: None,
}
}

fn scaled(&self, s: DualBatchScalar<BATCH>) -> Self {
fn scaled<Q>(&self, s: Q) -> Self
where
Q: Borrow<DualBatchScalar<BATCH>>,
{
let s = s.borrow();
DualBatchMatrix {
real_part: self.real_part * s.real_part,
dij_part: DualBatchMatrix::<ROWS, COLS, BATCH>::binary_ms_dij(
Expand All @@ -366,7 +379,11 @@ where
}
}

fn from_array2(duals: [[DualBatchScalar<BATCH>; COLS]; ROWS]) -> Self {
fn from_array2<A>(duals: A) -> Self
where
A: Borrow<[[DualBatchScalar<BATCH>; COLS]; ROWS]>,
{
let duals = duals.borrow();
let mut shape = None;
let mut val_mat = BatchMatF64::<ROWS, COLS, BATCH>::zeros();
for i in 0..duals.len() {
Expand Down Expand Up @@ -532,14 +549,20 @@ where
}
}

fn from_real_scalar_array2(vals: [[BatchScalarF64<BATCH>; COLS]; ROWS]) -> Self {
fn from_real_scalar_array2<A>(vals: A) -> Self
where
A: Borrow<[[BatchScalarF64<BATCH>; COLS]; ROWS]>,
{
DualBatchMatrix {
real_part: BatchMatF64::from_real_scalar_array2(vals),
dij_part: None,
}
}

fn from_f64_array2(vals: [[f64; COLS]; ROWS]) -> Self {
fn from_f64_array2<A>(vals: A) -> Self
where
A: Borrow<[[f64; COLS]; ROWS]>,
{
DualBatchMatrix {
real_part: BatchMatF64::from_f64_array2(vals),
dij_part: None,
Expand All @@ -555,12 +578,16 @@ where
todo!();
}

fn to_dual(self) -> <DualBatchScalar<BATCH> as IsScalar<BATCH>>::DualMatrix<ROWS, COLS> {
self
fn to_dual(&self) -> <DualBatchScalar<BATCH> as IsScalar<BATCH>>::DualMatrix<ROWS, COLS> {
self.clone()
}

fn select(self, mask: &Mask<i64, BATCH>, other: Self) -> Self {
let maybe_dij = Self::two_dx(self.dij_part, other.dij_part);
fn select<Q>(&self, mask: &Mask<i64, BATCH>, other: Q) -> Self
where
Q: Borrow<Self>,
{
let other = other.borrow();
let maybe_dij = Self::two_dx(self.dij_part.clone(), other.dij_part.clone());

DualBatchMatrix {
real_part: self.real_part.select(mask, other.real_part),
Expand All @@ -581,7 +608,7 @@ where
}
}

fn transposed(self) -> DualBatchMatrix<COLS, ROWS, BATCH> {
fn transposed(&self) -> DualBatchMatrix<COLS, ROWS, BATCH> {
DualBatchMatrix {
real_part: self.real_part.transpose(),
dij_part: self.dij_part.clone().map(|_dij_val| todo!()),
Expand Down
41 changes: 23 additions & 18 deletions crates/sophus_core/src/calculus/dual/dual_batch_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use alloc::vec::Vec;
use approx::assert_abs_diff_eq;
use approx::AbsDiffEq;
use approx::RelativeEq;
use core::borrow::Borrow;
use core::fmt::Debug;
use core::ops::Add;
use core::ops::AddAssign;
Expand Down Expand Up @@ -340,7 +341,7 @@ where
}
}

fn cos(self) -> DualBatchScalar<BATCH>
fn cos(&self) -> DualBatchScalar<BATCH>
where
BatchScalarF64<BATCH>: IsCoreScalar,
LaneCount<BATCH>: SupportedLaneCount,
Expand All @@ -367,7 +368,7 @@ where
}
}

fn sin(self) -> DualBatchScalar<BATCH>
fn sin(&self) -> DualBatchScalar<BATCH>
where
BatchScalarF64<BATCH>: IsCoreScalar,
LaneCount<BATCH>: SupportedLaneCount,
Expand All @@ -387,7 +388,7 @@ where
}
}

fn abs(self) -> Self {
fn abs(&self) -> Self {
Self {
real_part: self.real_part.abs(),
dij_part: match self.dij_part.clone() {
Expand All @@ -403,7 +404,11 @@ where
}
}

fn atan2(self, rhs: Self) -> Self {
fn atan2<S>(&self, rhs: S) -> Self
where
S: Borrow<Self>,
{
let rhs = rhs.borrow();
let inv_sq_nrm: BatchScalarF64<BATCH> = BatchScalarF64::ones()
/ (self.real_part * self.real_part + rhs.real_part * rhs.real_part);
Self {
Expand All @@ -421,11 +426,11 @@ where
self.real_part
}

fn sqrt(self) -> Self {
fn sqrt(&self) -> Self {
let sqrt = self.real_part.sqrt();
Self {
real_part: sqrt,
dij_part: match self.dij_part {
dij_part: match &self.dij_part {
Some(dij) => {
let out_dij =
MutTensorDD::from_map(&dij.view(), |dij: &BatchScalarF64<BATCH>| {
Expand All @@ -439,20 +444,20 @@ where
}
}

fn to_vec(self) -> DualBatchVector<1, BATCH> {
fn to_vec(&self) -> DualBatchVector<1, BATCH> {
DualBatchVector::<1, BATCH> {
real_part: self.real_part.real_part().to_vec(),
dij_part: match self.dij_part {
dij_part: match &self.dij_part {
Some(dij) => {
let tmp = dij.inner_scalar_to_vec();
let tmp = dij.clone().inner_scalar_to_vec();
Some(tmp)
}
None => None,
},
}
}

fn tan(self) -> Self {
fn tan(&self) -> Self {
Self {
real_part: self.real_part.tan(),
dij_part: match self.dij_part.clone() {
Expand All @@ -470,7 +475,7 @@ where
}
}

fn acos(self) -> Self {
fn acos(&self) -> Self {
Self {
real_part: self.real_part.acos(),
dij_part: match self.dij_part.clone() {
Expand All @@ -489,7 +494,7 @@ where
}
}

fn asin(self) -> Self {
fn asin(&self) -> Self {
Self {
real_part: self.real_part.asin(),
dij_part: match self.dij_part.clone() {
Expand All @@ -508,7 +513,7 @@ where
}
}

fn atan(self) -> Self {
fn atan(&self) -> Self {
Self {
real_part: self.real_part.atan(),
dij_part: match self.dij_part.clone() {
Expand All @@ -526,7 +531,7 @@ where
}
}

fn fract(self) -> Self {
fn fract(&self) -> Self {
Self {
real_part: self.real_part.fract(),
dij_part: match self.dij_part.clone() {
Expand Down Expand Up @@ -576,14 +581,14 @@ where
self.real_part.less_equal(&rhs.real_part)
}

fn to_dual(self) -> Self::DualScalar {
self
fn to_dual(&self) -> Self::DualScalar {
self.clone()
}

fn select(self, mask: &Self::Mask, other: Self) -> Self {
fn select(&self, mask: &Self::Mask, other: Self) -> Self {
Self {
real_part: self.real_part.select(mask, other.real_part),
dij_part: match (self.dij_part, other.dij_part) {
dij_part: match (self.dij_part.clone(), other.dij_part) {
(Some(lhs), Some(rhs)) => {
let dyn_mat = MutTensorDD::from_map2(
&lhs.view(),
Expand Down
Loading

0 comments on commit 82a26b6

Please sign in to comment.