Skip to content

Commit

Permalink
chore: dbg
Browse files Browse the repository at this point in the history
  • Loading branch information
FL33TW00D committed Jan 22, 2024
1 parent c33dee9 commit 20a75ec
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,15 +347,14 @@ impl Tensor {
}

#[cfg(feature = "pyo3")]
pub fn to_py<'p, T: TensorDType + numpy::Element>(
&'p self,
py: pyo3::Python<'p>,
pub fn to_py<'s, 'p: 's, T: TensorDType + numpy::Element>(
&'s self,
py: &'p pyo3::Python<'p>,
) -> &PyArrayDyn<T> {
use numpy::PyArray;
//You must deep_copy here
let cloned = self.deep_clone();
println!("Deep cloned tensor");
PyArray::from_owned_array(py, cloned.into_ndarray::<T>())
PyArray::from_owned_array(*py, cloned.into_ndarray::<T>())
}

pub fn deep_clone(&self) -> Tensor {
Expand Down Expand Up @@ -449,9 +448,10 @@ mod tests {
"x",
)?;

let py_a = a.to_py::<f32>(py);
let py_b = b.to_py::<f32>(py);
println!("Successfully converted tensors to numpy arrays");
let py_a = a.to_py::<f32>(&py);
println!("Converted A to pyo3 array");
let py_b = b.to_py::<f32>(&py);
println!("Converted B to pyo3 array");

let result = prg
.getattr("matmul")?
Expand Down

0 comments on commit 20a75ec

Please sign in to comment.