src/python.rs

changeset 3
fbdee8e4a78d
parent 0
e8f3b6c55ce7
child 6
aefacd832408
equal deleted inserted replaced
2:0fc131645bd8 3:fbdee8e4a78d
48 inner: DiscreteMeasure::from_iter(vec.into_iter()), 48 inner: DiscreteMeasure::from_iter(vec.into_iter()),
49 } 49 }
50 }*/ 50 }*/
51 51
52 fn __iter__(slf: Bound<'_, Self>) -> PyResult<Py<$iter>> { 52 fn __iter__(slf: Bound<'_, Self>) -> PyResult<Py<$iter>> {
53 Py::new(slf.py(), $iter { measure_ref: slf.unbind(), next: 0 }) 53 Py::new(slf.py(), $iter {
54 measure_ref: slf.unbind(),
55 next: 0,
56 pad: false,
57 })
58 }
59
60 /// Same as iterating the object, but pads (or cuts) location vectors to 3 dimensions.
61 fn iter_padded(slf: Bound<'_, Self>) -> PyResult<Py<$iter>> {
62 Py::new(slf.py(), $iter {
63 measure_ref: slf.unbind(),
64 next: 0,
65 pad: true,
66 })
54 } 67 }
55 } 68 }
56 69
57 #[allow(non_camel_case_types)] 70 #[allow(non_camel_case_types)]
58 #[pyclass(module = "pointsource_algs")] 71 #[pyclass(module = "pointsource_algs")]
59 /// Python-side iterator for [`DiscreteMeasure<Loc<$N, $F>, $F>,`] 72 /// Python-side iterator for [`DiscreteMeasure<Loc<$N, $F>, $F>,`]
60 /// Returns tuples (weight, coords) 73 /// Returns tuples (weight, coords)
61 pub struct $iter { 74 pub struct $iter {
62 measure_ref: Py<$name>, 75 measure_ref: Py<$name>,
63 next: usize, 76 next: usize,
77 pad: bool,
64 } 78 }
65 79
66 #[pymethods] 80 #[pymethods]
67 impl $iter { 81 impl $iter {
68 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { 82 fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
75 mut slf: PyRefMut<'_, Self>, 89 mut slf: PyRefMut<'_, Self>,
76 ) -> PyResult<Option<(Bound<'_, PyArray<$F, Ix1>>, $F)>> { 90 ) -> PyResult<Option<(Bound<'_, PyArray<$F, Ix1>>, $F)>> {
77 let py = slf.py(); 91 let py = slf.py();
78 let meas_: PyRef<'_, $name> = slf.measure_ref.extract(py)?; 92 let meas_: PyRef<'_, $name> = slf.measure_ref.extract(py)?;
79 let meas = &(meas_.inner); 93 let meas = &(meas_.inner);
94 let pad = slf.pad;
80 let next = &mut slf.next; 95 let next = &mut slf.next;
81 Ok((*next < meas.len()).then(|| { 96 Ok((*next < meas.len()).then(|| {
82 let δ = meas[*next]; 97 let δ = meas[*next];
83 *next += 1; 98 *next += 1;
84 (PyArray::from_slice(py, &δ.x.0), δ.α) 99 if pad {
100 (
101 PyArray::from_iter(
102 py,
103 δ.x.iter().copied().chain(std::iter::repeat(0.0)).take(3),
104 ),
105 δ.α,
106 )
107 } else {
108 (PyArray::from_slice(py, &δ.x.0), δ.α)
109 }
85 })) 110 }))
86 } 111 }
87 } 112 }
88 113
89 // Direct access without passing through the wrappers 114 // Direct access without passing through the wrappers

mercurial