# HG changeset patch # User Tuomo Valkonen # Date 1778874390 18000 # Node ID 3868555d135cbfcb43c467bffbcc374dd6406cd2 # Parent 1f19c6bbf07b1305db75b5fef0a129bad63ac67e# Parent 1f301affeae303baa66c845705432863143b21e5 Merge dev into default diff -r 1f19c6bbf07b -r 3868555d135c .gitignore --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/.gitignore Fri May 15 14:46:30 2026 -0500 @@ -0,0 +1,1 @@ +.hgignore \ No newline at end of file diff -r 1f19c6bbf07b -r 3868555d135c .hgignore --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/.hgignore Fri May 15 14:46:30 2026 -0500 @@ -0,0 +1,4 @@ +syntax: glob +target/ +Cargo.lock +**/*.orig diff -r 1f19c6bbf07b -r 3868555d135c Cargo.lock --- a/Cargo.lock Sun Apr 27 20:29:43 2025 -0500 +++ b/Cargo.lock Fri May 15 14:46:30 2026 -0500 @@ -4,7 +4,7 @@ [[package]] name = "alg_tools" -version = "0.3.2" +version = "0.4.1-dev" dependencies = [ "anyhow", "colored", @@ -15,6 +15,7 @@ "num", "num-traits", "numeric_literals", + "pyo3", "rayon", "rustc_version", "serde", @@ -122,6 +123,114 @@ checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] +name = "glam" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "333928d5eb103c5d4050533cec0384302db6be8ef7d3cebd30ec6a35350353da" + +[[package]] +name = "glam" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3abb554f8ee44336b72d522e0a7fe86a29e09f839a36022fa869a7dfe941a54b" + +[[package]] +name = "glam" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4126c0479ccf7e8664c36a2d719f5f2c140fbb4f9090008098d2c291fa5b3f16" + +[[package]] +name = "glam" +version = "0.17.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01732b97afd8508eee3333a541b9f7610f454bb818669e66e90f5f57c93a776" + +[[package]] +name = "glam" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525a3e490ba77b8e326fb67d4b44b4bd2f920f44d4cc73ccec50adc68e3bee34" + +[[package]] +name = "glam" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b8509e6791516e81c1a630d0bd7fbac36d2fa8712a9da8662e716b52d5051ca" + +[[package]] +name = "glam" +version = "0.20.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43e957e744be03f5801a55472f593d43fabdebf25a4585db250f04d86b1675f" + +[[package]] +name = "glam" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "518faa5064866338b013ff9b2350dc318e14cc4fcd6cb8206d7e7c9886c98815" + +[[package]] +name = "glam" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12f597d56c1bd55a811a1be189459e8fad2bbc272616375602443bdfb37fa774" + +[[package]] +name = "glam" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e4afd9ad95555081e109fe1d21f2a30c691b5f0919c67dfa690a2e1eb6bd51c" + +[[package]] +name = "glam" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5418c17512bdf42730f9032c74e1ae39afc408745ebb2acf72fbc4691c17945" + +[[package]] +name = "glam" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "151665d9be52f9bb40fc7966565d39666f2d1e69233571b71b87791c7e0528b3" + +[[package]] +name = "glam" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e05e7e6723e3455f4818c7b26e855439f7546cf617ef669d1adedb8669e5cb9" + +[[package]] +name = "glam" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "779ae4bf7e8421cf91c0b3b64e7e8b40b862fba4d393f59150042de7c4965a94" + +[[package]] +name = "glam" +version = "0.29.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8babf46d4c1c9d92deac9f7be466f76dfc4482b6452fc5024b5e8daf6ffeb3ee" + +[[package]] +name = "glam" +version = "0.30.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd47b05dddf0005d850e5644cae7f2b14ac3df487979dbfff3b56f20b1a6ae46" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indoc" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + +[[package]] name = "itertools" version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -165,12 +274,37 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] name = "nalgebra" -version = "0.33.2" +version = "0.34.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26aecdf64b707efd1310e3544d709c5c0ac61c13756046aaaba41be5c4f66a3b" +checksum = "c4d5b3eff5cd580f93da45e64715e8c20a3996342f1e466599cf7a267a0c2f5f" dependencies = [ "approx", + "glam 0.14.0", + "glam 0.15.2", + "glam 0.16.0", + "glam 0.17.3", + "glam 0.18.0", + "glam 0.19.0", + "glam 0.20.5", + "glam 0.21.3", + "glam 0.22.0", + "glam 0.23.0", + "glam 0.24.2", + "glam 0.25.0", + "glam 0.27.0", + "glam 0.28.0", + "glam 0.29.3", + "glam 0.30.9", "matrixmultiply", "nalgebra-macros", "num-complex", @@ -182,9 +316,9 @@ [[package]] name = "nalgebra-macros" -version = "0.2.2" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" +checksum = "973e7178a678cfd059ccec50887658d482ce16b0aa9da3888ddeab5cd5eb4889" dependencies = [ "proc-macro2", "quote", @@ -275,12 +409,24 @@ ] [[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] name = "paste" version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] +name = "portable-atomic" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + +[[package]] name = "proc-macro2" version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -290,6 +436,67 @@ ] [[package]] +name = "pyo3" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37a6df7eab65fc7bee654a421404947e10a0f7085b6951bf2ea395f4659fb0cf" +dependencies = [ + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f77d387774f6f6eec64a004eac0ed525aab7fa1966d94b42f743797b3e395afb" +dependencies = [ + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dd13844a4242793e02df3e2ec093f540d948299a6a77ea9ce7afd8623f542be" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaf8f9f1108270b90d3676b8679586385430e5c0bb78bb5f043f95499c821a71" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.90", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70a3b2274450ba5288bc9b8c1b69ff569d1d61189d4bff38f8d22e03d17f932b" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.90", +] + +[[package]] name = "quote" version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -422,6 +629,12 @@ ] [[package]] +name = "target-lexicon" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" + +[[package]] name = "typenum" version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -434,6 +647,12 @@ checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" [[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] name = "wide" version = "0.7.30" source = "registry+https://github.com/rust-lang/crates.io-index" diff -r 1f19c6bbf07b -r 3868555d135c Cargo.toml --- a/Cargo.toml Sun Apr 27 20:29:43 2025 -0500 +++ b/Cargo.toml Fri May 15 14:46:30 2026 -0500 @@ -1,6 +1,6 @@ [package] name = "alg_tools" -version = "0.3.2" +version = "0.4.1-dev" edition = "2021" rust-version = "1.85" authors = ["Tuomo Valkonen "] @@ -20,7 +20,7 @@ [dependencies] serde = { version = "1.0", features = ["derive"] } csv = "~1.3.1" -nalgebra = "~0.33.0" +nalgebra = "~0.34.0" num-traits = { version = "~0.2.14", features = ["std"] } colored = "~2.1.0" num = "~0.4.0" @@ -31,11 +31,11 @@ rayon = "1.5.3" simba = "0.9.0" anyhow = "1.0.95" +pyo3 = { version = "~0.27.0", optional = true } [package.metadata.docs.rs] rustdoc-args = ["--html-in-header", "katex-header.html"] - [profile.release] debug = true @@ -45,6 +45,7 @@ # The nightly feature enables some additional features. # Nightly-based optimisations are decided automatically by build.rs. nightly = [] +pyo3 = ["dep:pyo3"] [build-dependencies] rustc_version = "0.4" diff -r 1f19c6bbf07b -r 3868555d135c rustfmt.toml --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/rustfmt.toml Fri May 15 14:46:30 2026 -0500 @@ -0,0 +1,3 @@ +overflow_delimited_expr = true +struct_lit_width = 80 + diff -r 1f19c6bbf07b -r 3868555d135c src/bisection_tree.rs --- a/src/bisection_tree.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/bisection_tree.rs Fri May 15 14:46:30 2026 -0500 @@ -14,8 +14,10 @@ value of the sum, each $f_k$ also needs to implement [`Mapping`][crate::mapping::Mapping]. Moreover, the sum needs to be represented by a [`SupportGenerator`] that associates to a low-storage-requirement identifier (typically `usize`) an object of the type that represents -$f_k$. [`BTFN`]s support basic vector space operations, and [minimisation][BTFN::minimise] and -[maximisation][BTFN::maximise] via a [branch-and-bound strategy][BTSearch::search_and_refine]. +$f_k$. [`BTFN`]s support basic vector space operations, and +[minimisation][crate::bounds::MinMaxMapping::minimise] and +[maximisation][crate::bounds::MinMaxMapping::maximise] +via a [branch-and-bound strategy][BTSearch::search_and_refine]. The nodes of a bisection tree also store aggregate information about the objects stored in the tree via an [`Aggregator`]. This way, rough upper and lower [bound][Bounds] estimates on @@ -42,10 +44,12 @@ a [`SupportGenerator`]. They can be summed and multipliced by a schalar using standard arithmetic operations. The types of the objects in two summed `BTFN`s do not need to be the same. To find an approximate minimum of a `BTFN` using a branch-and-bound strategy, -use [`BTFN::minimise`]. [`Bounded::bounds`] provides a shortcut to [`GlobalAnalysis`] with the +use [`crate::bounds::MinMaxMapping::minimise`]. +[`crate::bounds::Bounded::bounds`] provides a shortcut to [`GlobalAnalysis`] with the [`Bounds`] aggregator. If the rough bounds so obtained do not indicate that the `BTFN` is in some given bounds, instead of doing a full minimisation and maximisation for higher quality bounds, -it is more efficient to use [`BTFN::has_upper_bound`] and [`BTFN::has_lower_bound`]. +it is more efficient to use [`crate::bounds::MinMaxMapping::has_upper_bound`] and +[`crate::bounds::MinMaxMapping::has_lower_bound`]. */ mod supportid; diff -r 1f19c6bbf07b -r 3868555d135c src/bisection_tree/aggregator.rs --- a/src/bisection_tree/aggregator.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/bisection_tree/aggregator.rs Fri May 15 14:46:30 2026 -0500 @@ -2,9 +2,8 @@ Aggregation / summarisation of information in branches of bisection trees. */ +pub use crate::bounds::Bounds; use crate::types::*; -use crate::sets::Set; -use crate::instance::Instance; /// Trait for aggregating information about a branch of a [bisection tree][super::BT]. /// @@ -19,180 +18,60 @@ /// of a function on a greater domain from bounds on subdomains /// (in practise [`Cube`][crate::sets::Cube]s). /// -pub trait Aggregator : Clone + Sync + Send + 'static + std::fmt::Debug { +pub trait Aggregator: Clone + Sync + Send + 'static + std::fmt::Debug { /// Aggregate a new data to current state. - fn aggregate(&mut self, aggregates : I) - where I : Iterator; + fn aggregate(&mut self, aggregates: I) + where + I: Iterator; /// Summarise several other aggregators, resetting current state. - fn summarise<'a, I>(&'a mut self, aggregates : I) - where I : Iterator; + fn summarise<'a, I>(&'a mut self, aggregates: I) + where + I: Iterator; /// Create a new “empty” aggregate data. fn new() -> Self; } /// An [`Aggregator`] that doesn't aggregate anything. -#[derive(Clone,Debug)] +#[derive(Clone, Debug)] pub struct NullAggregator; impl Aggregator for NullAggregator { - fn aggregate(&mut self, _aggregates : I) - where I : Iterator {} - - fn summarise<'a, I>(&'a mut self, _aggregates : I) - where I : Iterator {} - - fn new() -> Self { NullAggregator } -} - -/// Upper and lower bounds on an `F`-valued function. -#[derive(Copy,Clone,Debug)] -pub struct Bounds( - /// Lower bound - pub F, - /// Upper bound - pub F -); - -impl Bounds { - /// Returns the lower bound - #[inline] - pub fn lower(&self) -> F { self.0 } - - /// Returns the upper bound - #[inline] - pub fn upper(&self) -> F { self.1 } -} - -impl Bounds { - /// Returns a uniform bound. - /// - /// This is maximum over the absolute values of the upper and lower bound. - #[inline] - pub fn uniform(&self) -> F { - let &Bounds(lower, upper) = self; - lower.abs().max(upper.abs()) + fn aggregate(&mut self, _aggregates: I) + where + I: Iterator, + { } - /// Construct a bounds, making sure `lower` bound is less than `upper` - #[inline] - pub fn corrected(lower : F, upper : F) -> Self { - if lower <= upper { - Bounds(lower, upper) - } else { - Bounds(upper, lower) - } + fn summarise<'a, I>(&'a mut self, _aggregates: I) + where + I: Iterator, + { } - /// Refine the lower bound - #[inline] - pub fn refine_lower(&self, lower : F) -> Self { - let &Bounds(l, u) = self; - debug_assert!(l <= u); - Bounds(l.max(lower), u.max(lower)) - } - - /// Refine the lower bound - #[inline] - pub fn refine_upper(&self, upper : F) -> Self { - let &Bounds(l, u) = self; - debug_assert!(l <= u); - Bounds(l.min(upper), u.min(upper)) + fn new() -> Self { + NullAggregator } } -impl<'a, F : Float> std::ops::Add for Bounds { - type Output = Self; - #[inline] - fn add(self, Bounds(l2, u2) : Self) -> Self::Output { - let Bounds(l1, u1) = self; - debug_assert!(l1 <= u1 && l2 <= u2); - Bounds(l1 + l2, u1 + u2) - } -} - -impl<'a, F : Float> std::ops::Mul for Bounds { - type Output = Self; - #[inline] - fn mul(self, Bounds(l2, u2) : Self) -> Self::Output { - let Bounds(l1, u1) = self; - debug_assert!(l1 <= u1 && l2 <= u2); - let a = l1 * l2; - let b = u1 * u2; - // The order may flip when negative numbers are involved, so need min/max - Bounds(a.min(b), a.max(b)) - } -} - -impl std::iter::Product for Bounds { +impl Aggregator for Bounds { #[inline] - fn product(mut iter: I) -> Self - where I: Iterator { - match iter.next() { - None => Bounds(F::ZERO, F::ZERO), - Some(init) => iter.fold(init, |a, b| a*b) - } - } -} - -impl Set for Bounds { - fn contains>(&self, item : I) -> bool { - let v = item.own(); - let &Bounds(l, u) = self; - debug_assert!(l <= u); - l <= v && v <= u - } -} - -impl Bounds { - /// Calculate a common bound (glb, lub) for two bounds. - #[inline] - pub fn common(&self, &Bounds(l2, u2) : &Self) -> Self { - let &Bounds(l1, u1) = self; - debug_assert!(l1 <= u1 && l2 <= u2); - Bounds(l1.min(l2), u1.max(u2)) - } - - /// Indicates whether `Self` is a superset of the argument bound. - #[inline] - pub fn superset(&self, &Bounds(l2, u2) : &Self) -> bool { - let &Bounds(l1, u1) = self; - debug_assert!(l1 <= u1 && l2 <= u2); - l1 <= l2 && u2 <= u1 - } - - /// Returns the greatest bound contained by both argument bounds, if one exists. - #[inline] - pub fn glb(&self, &Bounds(l2, u2) : &Self) -> Option { - let &Bounds(l1, u1) = self; - debug_assert!(l1 <= u1 && l2 <= u2); - let l = l1.max(l2); - let u = u1.min(u2); - debug_assert!(l <= u); - if l < u { - Some(Bounds(l, u)) - } else { - None - } - } -} - -impl Aggregator for Bounds { - #[inline] - fn aggregate(&mut self, aggregates : I) - where I : Iterator { + fn aggregate(&mut self, aggregates: I) + where + I: Iterator, + { *self = aggregates.fold(*self, |a, b| a + b); } #[inline] - fn summarise<'a, I>(&'a mut self, mut aggregates : I) - where I : Iterator { + fn summarise<'a, I>(&'a mut self, mut aggregates: I) + where + I: Iterator, + { *self = match aggregates.next() { None => Bounds(F::ZERO, F::ZERO), // No parts in this cube; the function is zero - Some(&bounds) => { - aggregates.fold(bounds, |a, b| a.common(b)) - } + Some(&bounds) => aggregates.fold(bounds, |a, b| a.common(b)), } } diff -r 1f19c6bbf07b -r 3868555d135c src/bisection_tree/bt.rs --- a/src/bisection_tree/bt.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/bisection_tree/bt.rs Fri May 15 14:46:30 2026 -0500 @@ -1,34 +1,28 @@ - /*! Bisection tree basics, [`BT`] type and the [`BTImpl`] trait. */ -use std::slice::IterMut; -use std::iter::once; -use std::sync::Arc; -use serde::{Serialize, Deserialize}; +use itertools::izip; pub(super) use nalgebra::Const; -use itertools::izip; +use serde::{Deserialize, Serialize}; +use std::iter::once; +use std::slice::IterMut; +use std::sync::Arc; -use crate::types::{Float, Num}; -use crate::parallelism::{with_task_budget, TaskBudget}; +use super::aggregator::*; +use super::support::*; use crate::coefficients::pow; -use crate::maputil::{ - array_init, - map2, - map2_indexed, - collect_into_array_unchecked -}; +use crate::loc::Loc; +use crate::maputil::{array_init, collect_into_array_unchecked, map2, map2_indexed}; +use crate::parallelism::{with_task_budget, TaskBudget}; use crate::sets::Cube; -use crate::loc::Loc; -use super::support::*; -use super::aggregator::*; +use crate::types::{Float, Num}; /// An enum that indicates whether a [`Node`] of a [`BT`] is uninitialised, leaf, or branch. /// /// For the type and const parametere, see the [module level documentation][super]. -#[derive(Clone,Debug)] -pub(super) enum NodeOption { +#[derive(Clone, Debug)] +pub(super) enum NodeOption { /// Indicates an uninitilised node; may become a branch or a leaf. // TODO: Could optimise Uninitialised away by simply treat Leaf with an empty Vec as // something that can be still replaced with Branches. @@ -44,39 +38,41 @@ /// /// For the type and const parameteres, see the [module level documentation][super]. #[derive(Clone, Debug)] -pub struct Node { +pub struct Node { /// The data or branches under the node. - pub(super) data : NodeOption, + pub(super) data: NodeOption, /// Aggregator for `data`. - pub(super) aggregator : A, + pub(super) aggregator: A, } /// Branching information of a [`Node`] of a [`BT`] bisection tree into `P` subnodes. /// /// For the type and const parameters, see the [module level documentation][super]. #[derive(Clone, Debug)] -pub(super) struct Branches { +pub(super) struct Branches { /// Point for subdivision of the (unstored) [`Cube`] corresponding to the node. - pub(super) branch_at : Loc, + pub(super) branch_at: Loc, /// Subnodes - pub(super) nodes : [Node; P], + pub(super) nodes: [Node; P], } /// Dirty workaround to broken Rust drop, see [https://github.com/rust-lang/rust/issues/58068](). -impl -Drop for Node { +impl Drop for Node { fn drop(&mut self) { use NodeOption as NO; - let process = |brc : Arc>, - to_drop : &mut Vec>>| { + let process = |brc: Arc>, + to_drop: &mut Vec>>| { // We only drop Branches if we have the only strong reference. // FIXME: update the RwLocks on Nodes. - Arc::try_unwrap(brc).ok().map(|branches| branches.nodes.map(|mut node| { - if let NO::Branches(brc2) = std::mem::replace(&mut node.data, NO::Uninitialised) { - to_drop.push(brc2) - } - })); + Arc::try_unwrap(brc).ok().map(|branches| { + branches.nodes.map(|mut node| { + if let NO::Branches(brc2) = std::mem::replace(&mut node.data, NO::Uninitialised) + { + to_drop.push(brc2) + } + }) + }); }; // We mark Self as NodeOption::Uninitialised, extracting the real contents. @@ -98,9 +94,9 @@ /// Trait for the depth of a [`BT`]. /// /// This will generally be either a runtime [`DynamicDepth`] or compile-time [`Const`] depth. -pub trait Depth : 'static + Copy + Send + Sync + std::fmt::Debug { +pub trait Depth: 'static + Copy + Send + Sync + std::fmt::Debug { /// Lower depth type. - type Lower : Depth; + type Lower: Depth; /// Returns a lower depth, if there still is one. fn lower(&self) -> Option; @@ -113,26 +109,26 @@ } /// Dynamic (runtime) [`Depth`] for a [`BT`]. -#[derive(Copy,Clone,Debug,Serialize,Deserialize)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] pub struct DynamicDepth( /// The depth - pub u8 + pub u8, ); impl Depth for DynamicDepth { type Lower = Self; #[inline] fn lower(&self) -> Option { - if self.0>0 { - Some(DynamicDepth(self.0-1)) - } else { + if self.0 > 0 { + Some(DynamicDepth(self.0 - 1)) + } else { None } } #[inline] fn lower_or(&self) -> Self { - DynamicDepth(if self.0>0 { self.0 - 1 } else { 0 }) + DynamicDepth(if self.0 > 0 { self.0 - 1 } else { 0 }) } #[inline] @@ -143,9 +139,15 @@ impl Depth for Const<0> { type Lower = Self; - fn lower(&self) -> Option { None } - fn lower_or(&self) -> Self::Lower { Const } - fn value(&self) -> u32 { 0 } + fn lower(&self) -> Option { + None + } + fn lower_or(&self) -> Self::Lower { + Const + } + fn value(&self) -> u32 { + 0 + } } macro_rules! impl_constdepth { @@ -165,7 +167,7 @@ /// The const parameter `P` from the [module level documentation][super] is required to satisfy /// `Const

: Branchcount`. /// This trait is implemented for `P=pow(2, N)` for small `N`. -pub trait BranchCount {} +pub trait BranchCount {} macro_rules! impl_branchcount { ($($n:literal)*) => { $( impl BranchCount<$n> for Const<{pow(2, $n)}>{} @@ -173,18 +175,19 @@ } impl_branchcount!(1 2 3 4 5 6 7 8); -impl Branches -where Const

: BranchCount, - A : Aggregator +impl Branches +where + Const

: BranchCount, + A: Aggregator, { /// Returns the index in {0, …, `P`-1} for the branch to which the point `x` corresponds. /// /// This only takes the branch subdivision point $d$ into account, so is always succesfull. /// Thus, for this point, each branch corresponds to a quadrant of $ℝ^N$ relative to $d$. - fn get_node_index(&self, x : &Loc) -> usize { - izip!(0..P, x.iter(), self.branch_at.iter()).map(|(i, x_i, branch_i)| - if x_i > branch_i { 1<) -> usize { + izip!(0..P, x.iter(), self.branch_at.iter()) + .map(|(i, x_i, branch_i)| if x_i > branch_i { 1 << i } else { 0 }) + .sum() } /// Returns the node within `Self` containing the point `x`. @@ -192,37 +195,37 @@ /// This only takes the branch subdivision point $d$ into account, so is always succesfull. /// Thus, for this point, each branch corresponds to a quadrant of $ℝ^N$ relative to $d$. #[inline] - fn get_node(&self, x : &Loc) -> &Node { - &self.nodes[self.get_node_index(x)] + fn get_node(&self, x: &Loc) -> &Node { + &self.nodes[self.get_node_index(x)] } } /// An iterator over the $P=2^N$ subcubes of a [`Cube`] subdivided at a point `d`. -pub(super) struct SubcubeIter<'b, F : Float, const N : usize, const P : usize> { - domain : &'b Cube, - branch_at : Loc, - index : usize, +pub(super) struct SubcubeIter<'b, F: Float, const N: usize, const P: usize> { + domain: &'b Cube, + branch_at: Loc, + index: usize, } /// Returns the `i`:th subcube of `domain` subdivided at `branch_at`. #[inline] -fn get_subcube( - branch_at : &Loc, - domain : &Cube, - i : usize -) -> Cube { +fn get_subcube( + branch_at: &Loc, + domain: &Cube, + i: usize, +) -> Cube { map2_indexed(branch_at, domain, move |j, &branch, &[start, end]| { if i & (1 << j) != 0 { [branch, end] } else { [start, branch] } - }).into() + }) + .into() } -impl<'a, 'b, F : Float, const N : usize, const P : usize> Iterator -for SubcubeIter<'b, F, N, P> { - type Item = Cube; +impl<'a, 'b, F: Float, const N: usize, const P: usize> Iterator for SubcubeIter<'b, F, N, P> { + type Item = Cube; #[inline] fn next(&mut self) -> Option { if self.index < P { @@ -235,30 +238,33 @@ } } -impl -Branches -where Const

: BranchCount, - A : Aggregator, - D : 'static + Copy + Send + Sync { - +impl Branches +where + Const

: BranchCount, + A: Aggregator, + D: 'static + Copy + Send + Sync, +{ /// Creates a new node branching structure, subdividing `domain` based on the /// [hint][Support::support_hint] of `support`. - pub(super) fn new_with>( - domain : &Cube, - support : &S + pub(super) fn new_with + Support>( + domain: &Cube, + support: &S, ) -> Self { let hint = support.bisection_hint(domain); let branch_at = map2(&hint, domain, |h, r| { - h.unwrap_or_else(|| (r[0]+r[1])/F::TWO).max(r[0]).min(r[1]) - }).into(); - Branches{ - branch_at : branch_at, - nodes : array_init(|| Node::new()), + h.unwrap_or_else(|| (r[0] + r[1]) / F::TWO) + .max(r[0]) + .min(r[1]) + }) + .into(); + Branches { + branch_at: branch_at, + nodes: array_init(|| Node::new()), } } /// Summarises the aggregators of these branches into `agg` - pub(super) fn summarise_into(&self, agg : &mut A) { + pub(super) fn summarise_into(&self, agg: &mut A) { // We need to create an array of the aggregators clones due to the RwLock. agg.summarise(self.nodes.iter().map(Node::get_aggregator)); } @@ -266,19 +272,18 @@ /// Returns an iterator over the subcubes of `domain` subdivided at the branching point /// of `self`. #[inline] - pub(super) fn iter_subcubes<'b>(&self, domain : &'b Cube) - -> SubcubeIter<'b, F, N, P> { + pub(super) fn iter_subcubes<'b>(&self, domain: &'b Cube) -> SubcubeIter<'b, F, N, P> { SubcubeIter { - domain : domain, - branch_at : self.branch_at, - index : 0, + domain: domain, + branch_at: self.branch_at, + index: 0, } } /* /// Returns an iterator over all nodes and corresponding subcubes of `self`. #[inline] - pub(super) fn nodes_and_cubes<'a, 'b>(&'a self, domain : &'b Cube) + pub(super) fn nodes_and_cubes<'a, 'b>(&'a self, domain : &'b Cube) -> std::iter::Zip>, SubcubeIter<'b, F, N, P>> { self.nodes.iter().zip(self.iter_subcubes(domain)) } @@ -286,8 +291,10 @@ /// Mutably iterate over all nodes and corresponding subcubes of `self`. #[inline] - pub(super) fn nodes_and_cubes_mut<'a, 'b>(&'a mut self, domain : &'b Cube) - -> std::iter::Zip>, SubcubeIter<'b, F, N, P>> { + pub(super) fn nodes_and_cubes_mut<'a, 'b>( + &'a mut self, + domain: &'b Cube, + ) -> std::iter::Zip>, SubcubeIter<'b, F, N, P>> { let subcube_iter = self.iter_subcubes(domain); self.nodes.iter_mut().zip(subcube_iter) } @@ -296,12 +303,16 @@ #[inline] fn recurse<'scope, 'smaller, 'refs>( &'smaller mut self, - domain : &'smaller Cube, - task_budget : TaskBudget<'scope, 'refs>, - guard : impl Fn(&Node, &Cube) -> bool + Send + 'smaller, - mut f : impl for<'a> FnMut(&mut Node, &Cube, TaskBudget<'smaller, 'a>) - + Send + Copy + 'smaller - ) where 'scope : 'smaller { + domain: &'smaller Cube, + task_budget: TaskBudget<'scope, 'refs>, + guard: impl Fn(&Node, &Cube) -> bool + Send + 'smaller, + mut f: impl for<'a> FnMut(&mut Node, &Cube, TaskBudget<'smaller, 'a>) + + Send + + Copy + + 'smaller, + ) where + 'scope: 'smaller, + { let subs = self.nodes_and_cubes_mut(domain); task_budget.zoom(move |s| { for (node, subcube) in subs { @@ -321,19 +332,23 @@ /// * `support` is the [`Support`] that is used determine with which subcubes of `domain` /// (at subdivision depth `new_leaf_depth`) the data `d` is to be associated with. /// - pub(super) fn insert<'refs, 'scope, M : Depth, S : LocalAnalysis>( + pub(super) fn insert<'refs, 'scope, M: Depth, S: LocalAnalysis + Support>( &mut self, - domain : &Cube, - d : D, - new_leaf_depth : M, - support : &S, - task_budget : TaskBudget<'scope, 'refs>, + domain: &Cube, + d: D, + new_leaf_depth: M, + support: &S, + task_budget: TaskBudget<'scope, 'refs>, ) { let support_hint = support.support_hint(); - self.recurse(domain, task_budget, - |_, subcube| support_hint.intersects(&subcube), - move |node, subcube, new_budget| node.insert(subcube, d, new_leaf_depth, support, - new_budget)); + self.recurse( + domain, + task_budget, + |_, subcube| support_hint.intersects(&subcube), + move |node, subcube, new_budget| { + node.insert(subcube, d, new_leaf_depth, support, new_budget) + }, + ); } /// Construct a new instance of the branch for a different aggregator. @@ -344,20 +359,24 @@ /// generator's `SupportType`. pub(super) fn convert_aggregator( self, - generator : &G, - domain : &Cube - ) -> Branches - where ANew : Aggregator, - G : SupportGenerator, - G::SupportType : LocalAnalysis { + generator: &G, + domain: &Cube, + ) -> Branches + where + ANew: Aggregator, + G: SupportGenerator, + G::SupportType: LocalAnalysis, + { let branch_at = self.branch_at; let subcube_iter = self.iter_subcubes(domain); - let new_nodes = self.nodes.into_iter().zip(subcube_iter).map(|(node, subcube)| { - Node::convert_aggregator(node, generator, &subcube) - }); + let new_nodes = self + .nodes + .into_iter() + .zip(subcube_iter) + .map(|(node, subcube)| Node::convert_aggregator(node, generator, &subcube)); Branches { - branch_at : branch_at, - nodes : collect_into_array_unchecked(new_nodes), + branch_at: branch_at, + nodes: collect_into_array_unchecked(new_nodes), } } @@ -367,37 +386,43 @@ /// [`Support`]s. The `domain` is the cube corresponding to `self`. pub(super) fn refresh_aggregator<'refs, 'scope, G>( &mut self, - generator : &G, - domain : &Cube, - task_budget : TaskBudget<'scope, 'refs>, - ) where G : SupportGenerator, - G::SupportType : LocalAnalysis { - self.recurse(domain, task_budget, - |_, _| true, - move |node, subcube, new_budget| node.refresh_aggregator(generator, subcube, - new_budget)); + generator: &G, + domain: &Cube, + task_budget: TaskBudget<'scope, 'refs>, + ) where + G: SupportGenerator, + G::SupportType: LocalAnalysis, + { + self.recurse( + domain, + task_budget, + |_, _| true, + move |node, subcube, new_budget| { + node.refresh_aggregator(generator, subcube, new_budget) + }, + ); } } -impl -Node -where Const

: BranchCount, - A : Aggregator, - D : 'static + Copy + Send + Sync { - +impl Node +where + Const

: BranchCount, + A: Aggregator, + D: 'static + Copy + Send + Sync, +{ /// Create a new node #[inline] pub(super) fn new() -> Self { Node { - data : NodeOption::Uninitialised, - aggregator : A::new(), + data: NodeOption::Uninitialised, + aggregator: A::new(), } } /* /// Get leaf data #[inline] - pub(super) fn get_leaf_data(&self, x : &Loc) -> Option<&Vec> { + pub(super) fn get_leaf_data(&self, x : &Loc) -> Option<&Vec> { match self.data { NodeOption::Uninitialised => None, NodeOption::Leaf(ref data) => Some(data), @@ -407,7 +432,7 @@ /// Get leaf data iterator #[inline] - pub(super) fn get_leaf_data_iter(&self, x : &Loc) -> Option> { + pub(super) fn get_leaf_data_iter(&self, x: &Loc) -> Option> { match self.data { NodeOption::Uninitialised => None, NodeOption::Leaf(ref data) => Some(data.iter()), @@ -434,13 +459,13 @@ /// If `self` is a [`NodeOption::Branches`], the data is passed to branches whose subcubes /// `support` intersects. If an [`NodeOption::Uninitialised`] node is encountered, a new leaf is /// created at a minimum depth of `new_leaf_depth`. - pub(super) fn insert<'refs, 'scope, M : Depth, S : LocalAnalysis >( + pub(super) fn insert<'refs, 'scope, M: Depth, S: LocalAnalysis + Support>( &mut self, - domain : &Cube, - d : D, - new_leaf_depth : M, - support : &S, - task_budget : TaskBudget<'scope, 'refs>, + domain: &Cube, + d: D, + new_leaf_depth: M, + support: &S, + task_budget: TaskBudget<'scope, 'refs>, ) { match &mut self.data { NodeOption::Uninitialised => { @@ -451,10 +476,10 @@ self.aggregator.aggregate(once(a)); // TODO: this is currently a dirty hard-coded heuristic; // should add capacity as a parameter - let mut vec = Vec::with_capacity(2*P+1); + let mut vec = Vec::with_capacity(2 * P + 1); vec.push(d); NodeOption::Leaf(vec) - }, + } Some(lower) => { let b = Arc::new({ let mut b0 = Branches::new_with(domain, support); @@ -465,19 +490,19 @@ NodeOption::Branches(b) } } - }, + } NodeOption::Leaf(leaf) => { leaf.push(d); let a = support.local_analysis(&domain); self.aggregator.aggregate(once(a)); - }, + } NodeOption::Branches(b) => { // FIXME: recursion that may cause stack overflow if the tree becomes // very deep, e.g. due to [`BTSearch::search_and_refine`]. let bm = Arc::make_mut(b); bm.insert(domain, d, new_leaf_depth.lower_or(), support, task_budget); bm.summarise_into(&mut self.aggregator); - }, + } } } @@ -489,18 +514,19 @@ /// generator's `SupportType`. pub(super) fn convert_aggregator( mut self, - generator : &G, - domain : &Cube - ) -> Node - where ANew : Aggregator, - G : SupportGenerator, - G::SupportType : LocalAnalysis { - + generator: &G, + domain: &Cube, + ) -> Node + where + ANew: Aggregator, + G: SupportGenerator, + G::SupportType: LocalAnalysis, + { // The mem::replace is needed due to the [`Drop`] implementation to extract self.data. match std::mem::replace(&mut self.data, NodeOption::Uninitialised) { NodeOption::Uninitialised => Node { - data : NodeOption::Uninitialised, - aggregator : ANew::new(), + data: NodeOption::Uninitialised, + aggregator: ANew::new(), }, NodeOption::Leaf(v) => { let mut anew = ANew::new(); @@ -510,10 +536,10 @@ })); Node { - data : NodeOption::Leaf(v), - aggregator : anew, + data: NodeOption::Leaf(v), + aggregator: anew, } - }, + } NodeOption::Branches(b) => { // FIXME: recursion that may cause stack overflow if the tree becomes // very deep, e.g. due to [`BTSearch::search_and_refine`]. @@ -521,8 +547,8 @@ let mut anew = ANew::new(); bnew.summarise_into(&mut anew); Node { - data : NodeOption::Branches(Arc::new(bnew)), - aggregator : anew, + data: NodeOption::Branches(Arc::new(bnew)), + aggregator: anew, } } } @@ -534,20 +560,22 @@ /// [`Support`]s. The `domain` is the cube corresponding to `self`. pub(super) fn refresh_aggregator<'refs, 'scope, G>( &mut self, - generator : &G, - domain : &Cube, - task_budget : TaskBudget<'scope, 'refs>, - ) where G : SupportGenerator, - G::SupportType : LocalAnalysis { + generator: &G, + domain: &Cube, + task_budget: TaskBudget<'scope, 'refs>, + ) where + G: SupportGenerator, + G::SupportType: LocalAnalysis, + { match &mut self.data { - NodeOption::Uninitialised => { }, + NodeOption::Uninitialised => {} NodeOption::Leaf(v) => { self.aggregator = A::new(); - self.aggregator.aggregate(v.iter().map(|d| { - generator.support_for(*d) - .local_analysis(&domain) - })); - }, + self.aggregator.aggregate( + v.iter() + .map(|d| generator.support_for(*d).local_analysis(&domain)), + ); + } NodeOption::Branches(ref mut b) => { // FIXME: recursion that may cause stack overflow if the tree becomes // very deep, e.g. due to [`BTSearch::search_and_refine`]. @@ -563,11 +591,13 @@ /// /// This can be removed and the methods implemented directly on [`BT`] once Rust's const generics /// are flexible enough to allow fixing `P=pow(2, N)`. -pub trait BTNode -where F : Float, - D : 'static + Copy, - A : Aggregator { - type Node : Clone + std::fmt::Debug; +pub trait BTNode +where + F: Float, + D: 'static + Copy, + A: Aggregator, +{ + type Node: Clone + std::fmt::Debug; } /// Helper structure for looking up a [`Node`] without the knowledge of `P`. @@ -580,56 +610,60 @@ /// Basic interface to a [`BT`] bisection tree. /// /// Further routines are provided by the [`BTSearch`][super::refine::BTSearch] trait. -pub trait BTImpl : std::fmt::Debug + Clone + GlobalAnalysis { +pub trait BTImpl: + std::fmt::Debug + Clone + GlobalAnalysis +{ /// The data type stored in the tree - type Data : 'static + Copy + Send + Sync; + type Data: 'static + Copy + Send + Sync; /// The depth type of the tree - type Depth : Depth; + type Depth: Depth; /// The type for the [aggregate information][Aggregator] about the `Data` stored in each node /// of the tree. - type Agg : Aggregator; + type Agg: Aggregator; /// The type of the tree with the aggregator converted to `ANew`. - type Converted : BTImpl where ANew : Aggregator; + type Converted: BTImpl + where + ANew: Aggregator; /// Insert the data `d` into the tree for `support`. /// /// Every leaf node of the tree that intersects the `support` will contain a copy of /// `d`. - fn insert>( + fn insert + Support>( &mut self, - d : Self::Data, - support : &S + d: Self::Data, + support: &S, ); /// Construct a new instance of the tree for a different aggregator /// /// The `generator` is used to convert the data of type [`Self::Data`] contained in the tree /// into corresponding [`Support`]s. - fn convert_aggregator(self, generator : &G) - -> Self::Converted - where ANew : Aggregator, - G : SupportGenerator, - G::SupportType : LocalAnalysis; - + fn convert_aggregator(self, generator: &G) -> Self::Converted + where + ANew: Aggregator, + G: SupportGenerator, + G::SupportType: LocalAnalysis; /// Refreshes the aggregator of the three after possible changes to the support generator. /// /// The `generator` is used to convert the data of type [`Self::Data`] contained in the tree /// into corresponding [`Support`]s. - fn refresh_aggregator(&mut self, generator : &G) - where G : SupportGenerator, - G::SupportType : LocalAnalysis; + fn refresh_aggregator(&mut self, generator: &G) + where + G: SupportGenerator, + G::SupportType: LocalAnalysis; /// Returns an iterator over all [`Self::Data`] items at the point `x` of the domain. - fn iter_at(&self, x : &Loc) -> std::slice::Iter<'_, Self::Data>; + fn iter_at(&self, x: &Loc) -> std::slice::Iter<'_, Self::Data>; /* /// Returns all [`Self::Data`] items at the point `x` of the domain. - fn data_at(&self, x : &Loc) -> Arc>; + fn data_at(&self, x : &Loc) -> Arc>; */ /// Create a new tree on `domain` of indicated `depth`. - fn new(domain : Cube, depth : Self::Depth) -> Self; + fn new(domain: Cube, depth: Self::Depth) -> Self; } /// The main bisection tree structure. @@ -637,20 +671,17 @@ /// It should be accessed via the [`BTImpl`] trait to hide the `const P : usize` parameter until /// const generics are flexible enough to fix `P=pow(2, N)` and thus also get rid of /// the `BTNodeLookup : BTNode` trait bound. -#[derive(Clone,Debug)] -pub struct BT< - M : Depth, - F : Float, - D : 'static + Copy, - A : Aggregator, - const N : usize, -> where BTNodeLookup : BTNode { +#[derive(Clone, Debug)] +pub struct BT +where + BTNodeLookup: BTNode, +{ /// The depth of the tree (initial, before refinement) - pub(super) depth : M, + pub(super) depth: M, /// The domain of the toplevel node - pub(super) domain : Cube, + pub(super) domain: Cube, /// The toplevel node of the tree - pub(super) topnode : >::Node, + pub(super) topnode: >::Node, } macro_rules! impl_bt { @@ -662,7 +693,7 @@ type Node = Node; } - impl BTImpl for BT + impl BTImpl<$n, F> for BT where M : Depth, F : Float, D : 'static + Copy + Send + Sync + std::fmt::Debug, @@ -672,7 +703,7 @@ type Agg = A; type Converted = BT where ANew : Aggregator; - fn insert>( + fn insert + Support< $n, F>>( &mut self, d : D, support : &S @@ -690,7 +721,7 @@ fn convert_aggregator(self, generator : &G) -> Self::Converted where ANew : Aggregator, - G : SupportGenerator, + G : SupportGenerator< $n, F, Id=D>, G::SupportType : LocalAnalysis { let topnode = self.topnode.convert_aggregator(generator, &self.domain); @@ -702,22 +733,22 @@ } fn refresh_aggregator(&mut self, generator : &G) - where G : SupportGenerator, + where G : SupportGenerator< $n, F, Id=Self::Data>, G::SupportType : LocalAnalysis { with_task_budget(|task_budget| self.topnode.refresh_aggregator(generator, &self.domain, task_budget) ) } - /*fn data_at(&self, x : &Loc) -> Arc> { + /*fn data_at(&self, x : &Loc<$n, F>) -> Arc> { self.topnode.get_leaf_data(x).unwrap_or_else(|| Arc::new(Vec::new())) }*/ - fn iter_at(&self, x : &Loc) -> std::slice::Iter<'_, D> { + fn iter_at(&self, x : &Loc<$n, F>) -> std::slice::Iter<'_, D> { self.topnode.get_leaf_data_iter(x).unwrap_or_else(|| [].iter()) } - fn new(domain : Cube, depth : M) -> Self { + fn new(domain : Cube<$n, F>, depth : M) -> Self { BT { depth : depth, domain : domain, @@ -739,4 +770,3 @@ } impl_bt!(1 2 3 4); - diff -r 1f19c6bbf07b -r 3868555d135c src/bisection_tree/btfn.rs --- a/src/bisection_tree/btfn.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/bisection_tree/btfn.rs Fri May 15 14:46:30 2026 -0500 @@ -1,65 +1,88 @@ - +use crate::instance::{ClosedSpace, Instance, MyCow, Ownable, Space}; +use crate::mapping::{BasicDecomposition, DifferentiableImpl, DifferentiableMapping, Mapping}; +use crate::types::Float; use numeric_literals::replace_float_literals; use std::iter::Sum; use std::marker::PhantomData; use std::sync::Arc; -use crate::types::Float; -use crate::mapping::{ - Instance, Mapping, DifferentiableImpl, DifferentiableMapping, Space, - BasicDecomposition, -}; //use crate::linops::{Apply, Linear}; -use crate::sets::Set; -use crate::sets::Cube; -use crate::loc::Loc; +use super::aggregator::*; +use super::bt::*; +use super::either::*; +use super::refine::*; use super::support::*; -use super::bt::*; -use super::refine::*; -use super::aggregator::*; -use super::either::*; +use crate::bounds::MinMaxMapping; use crate::fe_model::base::RealLocalModel; use crate::fe_model::p2_local_model::*; +use crate::loc::Loc; +use crate::sets::Cube; +use crate::sets::Set; -/// Presentation for (mathematical) functions constructed as a sum of components functions with +/// Presentation for (mathematical) functions constructed as a sum of components functions with /// typically small support. /// -/// The domain of the function is [`Loc`]``, where `F` is the type of floating point numbers, +/// The domain of the function is [`Loc`]``, where `F` is the type of floating point numbers, /// and `N` the dimension. /// /// The `generator` lists the component functions that have to implement [`Support`]. /// Identifiers of the components ([`SupportGenerator::Id`], usually `usize`) are stored stored /// in a [bisection tree][BTImpl], when one is provided as `bt`. However `bt` may also be `()` /// for a [`PreBTFN`] that is only useful for vector space operations with a full [`BTFN`]. -#[derive(Clone,Debug)] -pub struct BTFN< - F : Float, - G : SupportGenerator, - BT /*: BTImpl*/, - const N : usize -> /*where G::SupportType : LocalAnalysis*/ { - bt : BT, - generator : Arc, - _phantoms : PhantomData, +#[derive(Clone, Debug)] +pub struct BTFN, BT /*: BTImpl< N, F>*/, const N: usize> /*where G::SupportType : LocalAnalysis*/ +{ + bt: BT, + generator: Arc, + _phantoms: PhantomData, } -impl -Space for BTFN +impl Ownable for BTFN where - G : SupportGenerator, - G::SupportType : LocalAnalysis, - BT : BTImpl + G: SupportGenerator, + G::SupportType: LocalAnalysis, + BT: BTImpl, { + type OwnedVariant = Self; + + fn into_owned(self) -> Self::OwnedVariant { + self + } + + fn clone_owned(&self) -> Self::OwnedVariant { + self.clone() + } + + fn cow_owned<'b>(self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + MyCow::Owned(self) + } + + fn ref_cow_owned<'b>(&'b self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + MyCow::Borrowed(self) + } +} + +impl Space for BTFN +where + G: SupportGenerator, + G::SupportType: LocalAnalysis, + BT: BTImpl, +{ + type Principal = Self; type Decomp = BasicDecomposition; } -impl -BTFN +impl BTFN where - G : SupportGenerator, - G::SupportType : LocalAnalysis, - BT : BTImpl + G: SupportGenerator, + G::SupportType: LocalAnalysis, + BT: BTImpl, { - /// Create a new BTFN from a support generator and a pre-initialised bisection tree. /// /// The bisection tree `bt` should be pre-initialised to correspond to the `generator`. @@ -67,16 +90,12 @@ /// when the aggregators of the tree may need updates. /// /// See the documentation for [`BTFN`] on the role of the `generator`. - pub fn new(bt : BT, generator : G) -> Self { + pub fn new(bt: BT, generator: G) -> Self { Self::new_arc(bt, Arc::new(generator)) } - fn new_arc(bt : BT, generator : Arc) -> Self { - BTFN { - bt : bt, - generator : generator, - _phantoms : std::marker::PhantomData, - } + fn new_arc(bt: BT, generator: Arc) -> Self { + BTFN { bt, generator, _phantoms: std::marker::PhantomData } } /// Create a new BTFN support generator and a pre-initialised bisection tree, @@ -86,7 +105,7 @@ /// the aggregator may be out of date. /// /// See the documentation for [`BTFN`] on the role of the `generator`. - pub fn new_refresh(bt : &BT, generator : G) -> Self { + pub fn new_refresh(bt: &BT, generator: G) -> Self { // clone().refresh_aggregator(…) as opposed to convert_aggregator // ensures that type is maintained. Due to Rc-pointer copy-on-write, // the effort is not significantly different. @@ -100,11 +119,11 @@ /// The top node of the created [`BT`] will have the given `domain`. /// /// See the documentation for [`BTFN`] on the role of the `generator`. - pub fn construct(domain : Cube, depth : BT::Depth, generator : G) -> Self { + pub fn construct(domain: Cube, depth: BT::Depth, generator: G) -> Self { Self::construct_arc(domain, depth, Arc::new(generator)) } - fn construct_arc(domain : Cube, depth : BT::Depth, generator : Arc) -> Self { + fn construct_arc(domain: Cube, depth: BT::Depth, generator: Arc) -> Self { let mut bt = BT::new(domain, depth); for (d, support) in generator.all_data() { bt.insert(d, &support); @@ -117,14 +136,16 @@ /// This will construct a [`BTFN`] with the same components and generator as the (consumed) /// `self`, but a new `BT` with [`Aggregator`]s of type `ANew`. pub fn convert_aggregator(self) -> BTFN, N> - where ANew : Aggregator, - G : SupportGenerator, - G::SupportType : LocalAnalysis { + where + ANew: Aggregator, + G: SupportGenerator, + G::SupportType: LocalAnalysis, + { BTFN::new_arc(self.bt.convert_aggregator(&*self.generator), self.generator) } /// Change the generator (after, e.g., a scaling of the latter). - fn new_generator(&self, generator : G) -> Self { + fn new_generator(&self, generator: G) -> Self { BTFN::new_refresh(&self.bt, generator) } @@ -132,20 +153,24 @@ fn refresh_aggregator(&mut self) { self.bt.refresh_aggregator(&*self.generator); } - } -impl -BTFN -where G : SupportGenerator { +impl BTFN +where + G: SupportGenerator, +{ /// Change the [bisection tree][BTImpl] of the [`BTFN`] to a different one. /// /// This can be used to convert a [`PreBTFN`] to a full [`BTFN`], or the change - /// the aggreagator; see also [`self.convert_aggregator`]. - pub fn instantiate< - BTNew : BTImpl, - > (self, domain : Cube, depth : BTNew::Depth) -> BTFN - where G::SupportType : LocalAnalysis { + /// the aggreagator; see also [`Self::convert_aggregator`]. + pub fn instantiate>( + self, + domain: Cube, + depth: BTNew::Depth, + ) -> BTFN + where + G::SupportType: LocalAnalysis, + { BTFN::construct_arc(domain, depth, self.generator) } } @@ -155,31 +180,30 @@ /// Most BTFN methods are not available, but if a BTFN is going to be summed with another /// before other use, it will be more efficient to not construct an unnecessary bisection tree /// that would be shortly dropped. -pub type PreBTFN = BTFN; +pub type PreBTFN = BTFN; -impl PreBTFN where G : SupportGenerator { - +impl PreBTFN +where + G: SupportGenerator, +{ /// Create a new [`PreBTFN`] with no bisection tree. - pub fn new_pre(generator : G) -> Self { - BTFN { - bt : (), - generator : Arc::new(generator), - _phantoms : std::marker::PhantomData, - } + pub fn new_pre(generator: G) -> Self { + BTFN { bt: (), generator: Arc::new(generator), _phantoms: std::marker::PhantomData } } } -impl -BTFN -where G : SupportGenerator, - G::SupportType : LocalAnalysis, - BT : BTImpl { - +impl BTFN +where + G: SupportGenerator, + G::SupportType: LocalAnalysis, + BT: BTImpl, +{ /// Helper function for implementing [`std::ops::Add`]. - fn add_another(&self, g2 : Arc) -> BTFN, BT, N> - where G2 : SupportGenerator, - G2::SupportType : LocalAnalysis { - + fn add_another(&self, g2: Arc) -> BTFN, BT, N> + where + G2: SupportGenerator, + G2::SupportType: LocalAnalysis, + { let mut bt = self.bt.clone(); let both = BothGenerators(Arc::clone(&self.generator), g2); @@ -187,11 +211,7 @@ bt.insert(d, &support); } - BTFN { - bt : bt, - generator : Arc::new(both), - _phantoms : std::marker::PhantomData, - } + BTFN { bt: bt, generator: Arc::new(both), _phantoms: std::marker::PhantomData } } } @@ -200,9 +220,9 @@ impl<'a, F : Float, G1, G2, BT1, BT2, const N : usize> std::ops::Add> for $lhs - where BT1 : BTImpl, - G1 : SupportGenerator + $($extra_trait)?, - G2 : SupportGenerator, + where BT1 : BTImpl< N, F, Data=usize>, + G1 : SupportGenerator< N, F, Id=usize> + $($extra_trait)?, + G2 : SupportGenerator< N, F, Id=usize>, G1::SupportType : LocalAnalysis, G2::SupportType : LocalAnalysis { type Output = BTFN, BT1, N>; @@ -215,9 +235,9 @@ impl<'a, 'b, F : Float, G1, G2, BT1, BT2, const N : usize> std::ops::Add<&'b BTFN> for $lhs - where BT1 : BTImpl, - G1 : SupportGenerator + $($extra_trait)?, - G2 : SupportGenerator, + where BT1 : BTImpl< N, F, Data=usize>, + G1 : SupportGenerator< N, F, Id=usize> + $($extra_trait)?, + G2 : SupportGenerator< N, F, Id=usize>, G1::SupportType : LocalAnalysis, G2::SupportType : LocalAnalysis { @@ -231,16 +251,16 @@ } make_btfn_add!(BTFN, std::convert::identity, ); -make_btfn_add!(&'a BTFN, Clone::clone, ); +make_btfn_add!(&'a BTFN, Clone::clone,); macro_rules! make_btfn_sub { ($lhs:ty, $preprocess:path, $($extra_trait:ident)?) => { impl<'a, F : Float, G1, G2, BT1, BT2, const N : usize> std::ops::Sub> for $lhs - where BT1 : BTImpl, - G1 : SupportGenerator + $($extra_trait)?, - G2 : SupportGenerator, + where BT1 : BTImpl< N, F, Data=usize>, + G1 : SupportGenerator< N, F, Id=usize> + $($extra_trait)?, + G2 : SupportGenerator< N, F, Id=usize>, G1::SupportType : LocalAnalysis, G2::SupportType : LocalAnalysis { type Output = BTFN, BT1, N>; @@ -257,9 +277,9 @@ impl<'a, 'b, F : Float, G1, G2, BT1, BT2, const N : usize> std::ops::Sub<&'b BTFN> for $lhs - where BT1 : BTImpl, - G1 : SupportGenerator + $($extra_trait)?, - G2 : SupportGenerator + Clone, + where BT1 : BTImpl< N, F, Data=usize>, + G1 : SupportGenerator< N, F, Id=usize> + $($extra_trait)?, + G2 : SupportGenerator< N, F, Id=usize> + Clone, G1::SupportType : LocalAnalysis, G2::SupportType : LocalAnalysis, &'b G2 : std::ops::Neg { @@ -274,52 +294,52 @@ } make_btfn_sub!(BTFN, std::convert::identity, ); -make_btfn_sub!(&'a BTFN, std::convert::identity, ); +make_btfn_sub!(&'a BTFN, std::convert::identity,); macro_rules! make_btfn_scalarop_rhs { ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => { - impl - std::ops::$trait_assign - for BTFN - where BT : BTImpl, - G : SupportGenerator, - G::SupportType : LocalAnalysis { + impl std::ops::$trait_assign for BTFN + where + BT: BTImpl, + G: SupportGenerator, + G::SupportType: LocalAnalysis, + { #[inline] - fn $fn_assign(&mut self, t : F) { + fn $fn_assign(&mut self, t: F) { Arc::make_mut(&mut self.generator).$fn_assign(t); self.refresh_aggregator(); } } - impl - std::ops::$trait - for BTFN - where BT : BTImpl, - G : SupportGenerator, - G::SupportType : LocalAnalysis { + impl std::ops::$trait for BTFN + where + BT: BTImpl, + G: SupportGenerator, + G::SupportType: LocalAnalysis, + { type Output = Self; #[inline] - fn $fn(mut self, t : F) -> Self::Output { + fn $fn(mut self, t: F) -> Self::Output { Arc::make_mut(&mut self.generator).$fn_assign(t); self.refresh_aggregator(); self } } - impl<'a, F : Float, G, BT, const N : usize> - std::ops::$trait - for &'a BTFN - where BT : BTImpl, - G : SupportGenerator, - G::SupportType : LocalAnalysis, - &'a G : std::ops::$trait { + impl<'a, F: Float, G, BT, const N: usize> std::ops::$trait for &'a BTFN + where + BT: BTImpl, + G: SupportGenerator, + G::SupportType: LocalAnalysis, + &'a G: std::ops::$trait, + { type Output = BTFN; #[inline] - fn $fn(self, t : F) -> Self::Output { + fn $fn(self, t: F) -> Self::Output { self.new_generator(self.generator.$fn(t)) } } - } + }; } make_btfn_scalarop_rhs!(Mul, mul, MulAssign, mul_assign); @@ -330,8 +350,8 @@ impl std::ops::$trait> for $f - where BT : BTImpl<$f, N>, - G : SupportGenerator<$f, N, Id=BT::Data>, + where BT : BTImpl< N, $f>, + G : SupportGenerator< N, $f, Id=BT::Data>, G::SupportType : LocalAnalysis<$f, BT::Agg, N> { type Output = BTFN<$f, G, BT, N>; #[inline] @@ -345,8 +365,8 @@ impl<'a, G, BT, const N : usize> std::ops::$trait<&'a BTFN<$f, G, BT, N>> for $f - where BT : BTImpl<$f, N>, - G : SupportGenerator<$f, N, Id=BT::Data> + Clone, + where BT : BTImpl< N, $f>, + G : SupportGenerator< N, $f, Id=BT::Data> + Clone, G::SupportType : LocalAnalysis<$f, BT::Agg, N>, // FIXME: This causes compiler overflow /*&'a G : std::ops::$trait<$f,Output=G>*/ { @@ -368,12 +388,12 @@ macro_rules! make_btfn_unaryop { ($trait:ident, $fn:ident) => { - impl - std::ops::$trait - for BTFN - where BT : BTImpl, - G : SupportGenerator, - G::SupportType : LocalAnalysis { + impl std::ops::$trait for BTFN + where + BT: BTImpl, + G: SupportGenerator, + G::SupportType: LocalAnalysis, + { type Output = Self; #[inline] fn $fn(mut self) -> Self::Output { @@ -386,8 +406,8 @@ /*impl<'a, F : Float, G, BT, const N : usize> std::ops::$trait for &'a BTFN - where BT : BTImpl, - G : SupportGenerator, + where BT : BTImpl< N, F>, + G : SupportGenerator< N, F, Id=BT::Data>, G::SupportType : LocalAnalysis, &'a G : std::ops::$trait { type Output = BTFN; @@ -396,7 +416,7 @@ self.new_generator(std::ops::$trait::$fn(&self.generator)) } }*/ - } + }; } make_btfn_unaryop!(Neg, neg); @@ -405,39 +425,38 @@ // Apply, Mapping, Differentiate // -impl Mapping> -for BTFN +impl Mapping> for BTFN where - BT : BTImpl, - G : SupportGenerator, - G::SupportType : LocalAnalysis + Mapping, Codomain = V>, - V : Sum + Space, + BT: BTImpl, + G: SupportGenerator, + G::SupportType: LocalAnalysis + Mapping, Codomain = V>, + V: Sum + ClosedSpace, { - type Codomain = V; - fn apply>>(&self, x : I) -> Self::Codomain { - let xc = x.cow(); - self.bt.iter_at(&*xc) - .map(|&d| self.generator.support_for(d).apply(&*xc)).sum() + fn apply>>(&self, x: I) -> Self::Codomain { + let xc = x.decompose(); + self.bt + .iter_at(&*xc) + .map(|&d| self.generator.support_for(d).apply(&*xc)) + .sum() } } -impl DifferentiableImpl> -for BTFN +impl DifferentiableImpl> for BTFN where - BT : BTImpl, - G : SupportGenerator, - G::SupportType : LocalAnalysis - + DifferentiableMapping, DerivativeDomain = V>, - V : Sum + Space, + BT: BTImpl, + G: SupportGenerator, + G::SupportType: + LocalAnalysis + DifferentiableMapping, DerivativeDomain = V>, + V: Sum + ClosedSpace, { - type Derivative = V; - fn differential_impl>>(&self, x :I) -> Self::Derivative { - let xc = x.cow(); - self.bt.iter_at(&*xc) + fn differential_impl>>(&self, x: I) -> Self::Derivative { + let xc = x.decompose(); + self.bt + .iter_at(&*xc) .map(|&d| self.generator.support_for(d).differential(&*xc)) .sum() } @@ -447,12 +466,12 @@ // GlobalAnalysis // -impl GlobalAnalysis -for BTFN -where BT : BTImpl, - G : SupportGenerator, - G::SupportType : LocalAnalysis { - +impl GlobalAnalysis for BTFN +where + BT: BTImpl, + G: SupportGenerator, + G::SupportType: LocalAnalysis, +{ #[inline] fn global_analysis(&self) -> BT::Agg { self.bt.global_analysis() @@ -467,8 +486,8 @@ /* impl<'b, X, F : Float, G, BT, const N : usize> Apply<&'b X, F> for BTFN -where BT : BTImpl, - G : SupportGenerator, +where BT : BTImpl< N, F>, + G : SupportGenerator< N, F, Id=BT::Data>, G::SupportType : LocalAnalysis, X : for<'a> Apply<&'a BTFN, F> { @@ -480,8 +499,8 @@ impl Apply for BTFN -where BT : BTImpl, - G : SupportGenerator, +where BT : BTImpl< N, F>, + G : SupportGenerator< N, F, Id=BT::Data>, G::SupportType : LocalAnalysis, X : for<'a> Apply<&'a BTFN, F> { @@ -493,8 +512,8 @@ impl Linear for BTFN -where BT : BTImpl, - G : SupportGenerator, +where BT : BTImpl< N, F>, + G : SupportGenerator< N, F, Id=BT::Data>, G::SupportType : LocalAnalysis, X : for<'a> Apply<&'a BTFN, F> { type Codomain = F; @@ -505,24 +524,23 @@ /// /// `U` is the domain, generally [`Loc`]``, and `F` the type of floating point numbers. /// `Self` is generally a set of `U`, for example, [`Cube`]``. -pub trait P2Minimise : Set { +pub trait P2Minimise: Set { /// Minimise `g` over the set presented by `Self`. /// /// The function returns `(x, v)` where `x` is the minimiser `v` an approximation of `g(x)`. - fn p2_minimise F>(&self, g : G) -> (U, F); - + fn p2_minimise F>(&self, g: G) -> (U, F); } -impl P2Minimise, F> for Cube { - fn p2_minimise) -> F>(&self, g : G) -> (Loc, F) { +impl P2Minimise, F> for Cube<1, F> { + fn p2_minimise) -> F>(&self, g: G) -> (Loc<1, F>, F) { let interval = Simplex(self.corners()); interval.p2_model(&g).minimise(&interval) } } #[replace_float_literals(F::cast_from(literal))] -impl P2Minimise, F> for Cube { - fn p2_minimise) -> F>(&self, g : G) -> (Loc, F) { +impl P2Minimise, F> for Cube<2, F> { + fn p2_minimise) -> F>(&self, g: G) -> (Loc<2, F>, F) { if false { // Split into two triangle (simplex) with separate P2 model in each. // The six nodes of each triangle are the corners and the edges. @@ -537,49 +555,46 @@ let [vab, vbc, vca, vcd, vda] = [g(&ab), g(&bc), g(&ca), g(&cd), g(&da)]; let s1 = Simplex([a, b, c]); - let m1 = P2LocalModel::::new( - &[a, b, c, ab, bc, ca], - &[va, vb, vc, vab, vbc, vca] - ); + let m1 = + P2LocalModel::::new(&[a, b, c, ab, bc, ca], &[va, vb, vc, vab, vbc, vca]); - let r1@(_, v1) = m1.minimise(&s1); + let r1 @ (_, v1) = m1.minimise(&s1); let s2 = Simplex([c, d, a]); - let m2 = P2LocalModel::::new( - &[c, d, a, cd, da, ca], - &[vc, vd, va, vcd, vda, vca] - ); + let m2 = + P2LocalModel::::new(&[c, d, a, cd, da, ca], &[vc, vd, va, vcd, vda, vca]); + + let r2 @ (_, v2) = m2.minimise(&s2); - let r2@(_, v2) = m2.minimise(&s2); - - if v1 < v2 { r1 } else { r2 } + if v1 < v2 { + r1 + } else { + r2 + } } else { // Single P2 model for the entire cube. let [a, b, c, d] = self.corners(); let [va, vb, vc, vd] = [g(&a), g(&b), g(&c), g(&d)]; let [e, f] = match 'r' { - 'm' => [(&a + &b + &c) / 3.0, (&c + &d + &a) / 3.0], - 'c' => [midpoint(&a, &b), midpoint(&a, &d)], - 'w' => [(&a + &b * 2.0) / 3.0, (&a + &d * 2.0) / 3.0], - 'r' => { + 'm' => [(&a + &b + &c) / 3.0, (&c + &d + &a) / 3.0], + 'c' => [midpoint(&a, &b), midpoint(&a, &d)], + 'w' => [(&a + &b * 2.0) / 3.0, (&a + &d * 2.0) / 3.0], + 'r' => { // Pseudo-randomise edge midpoints let Loc([x, y]) = a; - let tmp : f64 = (x+y).as_(); + let tmp: f64 = (x + y).as_(); match tmp.to_bits() % 4 { - 0 => [midpoint(&a, &b), midpoint(&a, &d)], - 1 => [midpoint(&c, &d), midpoint(&a, &d)], - 2 => [midpoint(&a, &b), midpoint(&b, &c)], - _ => [midpoint(&c, &d), midpoint(&b, &c)], + 0 => [midpoint(&a, &b), midpoint(&a, &d)], + 1 => [midpoint(&c, &d), midpoint(&a, &d)], + 2 => [midpoint(&a, &b), midpoint(&b, &c)], + _ => [midpoint(&c, &d), midpoint(&b, &c)], } - }, - _ => [self.center(), (&a + &b) / 2.0], + } + _ => [self.center(), (&a + &b) / 2.0], }; let [ve, vf] = [g(&e), g(&f)]; - let m1 = P2LocalModel::::new( - &[a, b, c, d, e, f], - &[va, vb, vc, vd, ve, vf], - ); + let m1 = P2LocalModel::::new(&[a, b, c, d, e, f], &[va, vb, vc, vd, ve, vf]); m1.minimise(self) } @@ -595,44 +610,46 @@ /// A bisection tree [`Refiner`] for maximising or minimising a [`BTFN`]. /// /// The type parameter `T` should be either [`RefineMax`] or [`RefineMin`]. -struct P2Refiner { +struct P2Refiner { /// The maximum / minimum should be above / below this threshold. /// If the threshold cannot be satisfied, the refiner will return `None`. - bound : Option, + bound: Option, /// Tolerance for function value estimation. - tolerance : F, + tolerance: F, /// Maximum number of steps to execute the refiner for - max_steps : usize, + max_steps: usize, /// Either [`RefineMax`] or [`RefineMin`]. Used only for type system purposes. #[allow(dead_code)] // `how` is just for type system purposes. - how : T, + how: T, } -impl Refiner, G, N> -for P2Refiner -where Cube : P2Minimise, F>, - G : SupportGenerator, - G::SupportType : Mapping, Codomain=F> - + LocalAnalysis, N> { - type Result = Option<(Loc, F)>; +impl Refiner, G, N> for P2Refiner +where + Cube: P2Minimise, F>, + G: SupportGenerator, + G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, +{ + type Result = Option<(Loc, F)>; type Sorting = UpperBoundSorting; fn refine( &self, - aggregator : &Bounds, - cube : &Cube, - data : &[G::Id], - generator : &G, - step : usize + aggregator: &Bounds, + cube: &Cube, + data: &[G::Id], + generator: &G, + step: usize, ) -> RefinerResult, Self::Result> { - - if self.bound.map_or(false, |b| aggregator.upper() <= b + self.tolerance) { + if self + .bound + .map_or(false, |b| aggregator.upper() <= b + self.tolerance) + { // The upper bound is below the maximisation threshold. Don't bother with this cube. - return RefinerResult::Uncertain(*aggregator, None) + return RefinerResult::Uncertain(*aggregator, None); } // g gives the negative of the value of the function presented by `data` and `generator`. - let g = move |x : &Loc| { + let g = move |x: &Loc| { let f = move |&d| generator.support_for(d).apply(x); -data.iter().map(f).sum::() }; @@ -641,8 +658,9 @@ //let v = -neg_v; let v = -g(&x); - if step < self.max_steps && (aggregator.upper() > v + self.tolerance - /*|| aggregator.lower() > v - self.tolerance*/) { + if step < self.max_steps + && (aggregator.upper() > v + self.tolerance/*|| aggregator.lower() > v - self.tolerance*/) + { // The function isn't refined enough in `cube`, so return None // to indicate that further subdivision is required. RefinerResult::NeedRefinement @@ -655,41 +673,46 @@ } } - fn fuse_results(r1 : &mut Self::Result, r2 : Self::Result) { + fn fuse_results(r1: &mut Self::Result, r2: Self::Result) { match (*r1, r2) { - (Some((_, v1)), Some((_, v2))) => if v1 < v2 { *r1 = r2 } + (Some((_, v1)), Some((_, v2))) => { + if v1 < v2 { + *r1 = r2 + } + } (None, Some(_)) => *r1 = r2, - (_, _) => {}, + (_, _) => {} } } } - -impl Refiner, G, N> -for P2Refiner -where Cube : P2Minimise, F>, - G : SupportGenerator, - G::SupportType : Mapping, Codomain=F> - + LocalAnalysis, N> { - type Result = Option<(Loc, F)>; +impl Refiner, G, N> for P2Refiner +where + Cube: P2Minimise, F>, + G: SupportGenerator, + G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, +{ + type Result = Option<(Loc, F)>; type Sorting = LowerBoundSorting; fn refine( &self, - aggregator : &Bounds, - cube : &Cube, - data : &[G::Id], - generator : &G, - step : usize + aggregator: &Bounds, + cube: &Cube, + data: &[G::Id], + generator: &G, + step: usize, ) -> RefinerResult, Self::Result> { - - if self.bound.map_or(false, |b| aggregator.lower() >= b - self.tolerance) { + if self + .bound + .map_or(false, |b| aggregator.lower() >= b - self.tolerance) + { // The lower bound is above the minimisation threshold. Don't bother with this cube. - return RefinerResult::Uncertain(*aggregator, None) + return RefinerResult::Uncertain(*aggregator, None); } // g gives the value of the function presented by `data` and `generator`. - let g = move |x : &Loc| { + let g = move |x: &Loc| { let f = move |&d| generator.support_for(d).apply(x); data.iter().map(f).sum::() }; @@ -697,8 +720,9 @@ let (x, _v) = cube.p2_minimise(g); let v = g(&x); - if step < self.max_steps && (aggregator.lower() < v - self.tolerance - /*|| aggregator.upper() < v + self.tolerance*/) { + if step < self.max_steps + && (aggregator.lower() < v - self.tolerance/*|| aggregator.upper() < v + self.tolerance*/) + { // The function isn't refined enough in `cube`, so return None // to indicate that further subdivision is required. RefinerResult::NeedRefinement @@ -717,47 +741,51 @@ } } - fn fuse_results(r1 : &mut Self::Result, r2 : Self::Result) { + fn fuse_results(r1: &mut Self::Result, r2: Self::Result) { match (*r1, r2) { - (Some((_, v1)), Some((_, v2))) => if v1 > v2 { *r1 = r2 } + (Some((_, v1)), Some((_, v2))) => { + if v1 > v2 { + *r1 = r2 + } + } (_, Some(_)) => *r1 = r2, - (_, _) => {}, + (_, _) => {} } } } - /// A bisection tree [`Refiner`] for checking that a [`BTFN`] is within a stated //// upper or lower bound. /// /// The type parameter `T` should be either [`RefineMax`] for upper bound or [`RefineMin`] /// for lower bound. -struct BoundRefiner { +struct BoundRefiner { /// The upper/lower bound to check for - bound : F, + bound: F, /// Tolerance for function value estimation. - tolerance : F, + tolerance: F, /// Maximum number of steps to execute the refiner for - max_steps : usize, + max_steps: usize, #[allow(dead_code)] // `how` is just for type system purposes. /// Either [`RefineMax`] or [`RefineMin`]. Used only for type system purposes. - how : T, + how: T, } -impl Refiner, G, N> -for BoundRefiner -where G : SupportGenerator { +impl Refiner, G, N> for BoundRefiner +where + G: SupportGenerator, +{ type Result = bool; type Sorting = UpperBoundSorting; fn refine( &self, - aggregator : &Bounds, - _cube : &Cube, - _data : &[G::Id], - _generator : &G, - step : usize + aggregator: &Bounds, + _cube: &Cube, + _data: &[G::Id], + _generator: &G, + step: usize, ) -> RefinerResult, Self::Result> { if aggregator.upper() <= self.bound + self.tolerance { // Below upper bound within tolerances. Indicate uncertain success. @@ -774,24 +802,25 @@ } } - fn fuse_results(r1 : &mut Self::Result, r2 : Self::Result) { + fn fuse_results(r1: &mut Self::Result, r2: Self::Result) { *r1 = *r1 && r2; } } -impl Refiner, G, N> -for BoundRefiner -where G : SupportGenerator { +impl Refiner, G, N> for BoundRefiner +where + G: SupportGenerator, +{ type Result = bool; type Sorting = UpperBoundSorting; fn refine( &self, - aggregator : &Bounds, - _cube : &Cube, - _data : &[G::Id], - _generator : &G, - step : usize + aggregator: &Bounds, + _cube: &Cube, + _data: &[G::Id], + _generator: &G, + step: usize, ) -> RefinerResult, Self::Result> { if aggregator.lower() >= self.bound - self.tolerance { // Above lower bound within tolerances. Indicate uncertain success. @@ -808,7 +837,7 @@ } } - fn fuse_results(r1 : &mut Self::Result, r2 : Self::Result) { + fn fuse_results(r1: &mut Self::Result, r2: Self::Result) { *r1 = *r1 && r2; } } @@ -828,66 +857,64 @@ // there should be a result, or new nodes above the `glb` inserted into the queue. Then the waiting // threads can also continue processing. If, however, numerical inaccuracy destroyes the `glb`, // the queue may run out, and we get “Refiner failure”. -impl BTFN -where BT : BTSearch>, - G : SupportGenerator, - G::SupportType : Mapping, Codomain=F> - + LocalAnalysis, N>, - Cube : P2Minimise, F> { - - /// Maximise the `BTFN` within stated value `tolerance`. - /// - /// At most `max_steps` refinement steps are taken. - /// Returns the approximate maximiser and the corresponding function value. - pub fn maximise(&mut self, tolerance : F, max_steps : usize) -> (Loc, F) { - let refiner = P2Refiner{ tolerance, max_steps, how : RefineMax, bound : None }; - self.bt.search_and_refine(refiner, &self.generator).expect("Refiner failure.").unwrap() +impl MinMaxMapping, F> for BTFN +where + BT: BTSearch>, + G: SupportGenerator, + G::SupportType: Mapping, Codomain = F> + LocalAnalysis, N>, + Cube: P2Minimise, F>, +{ + fn maximise(&mut self, tolerance: F, max_steps: usize) -> (Loc, F) { + let refiner = P2Refiner { tolerance, max_steps, how: RefineMax, bound: None }; + self.bt + .search_and_refine(refiner, &self.generator) + .expect("Refiner failure.") + .unwrap() } - /// Maximise the `BTFN` within stated value `tolerance` subject to a lower bound. - /// - /// At most `max_steps` refinement steps are taken. - /// Returns the approximate maximiser and the corresponding function value when one is found - /// above the `bound` threshold, otherwise `None`. - pub fn maximise_above(&mut self, bound : F, tolerance : F, max_steps : usize) - -> Option<(Loc, F)> { - let refiner = P2Refiner{ tolerance, max_steps, how : RefineMax, bound : Some(bound) }; - self.bt.search_and_refine(refiner, &self.generator).expect("Refiner failure.") + fn maximise_above( + &mut self, + bound: F, + tolerance: F, + max_steps: usize, + ) -> Option<(Loc, F)> { + let refiner = P2Refiner { tolerance, max_steps, how: RefineMax, bound: Some(bound) }; + self.bt + .search_and_refine(refiner, &self.generator) + .expect("Refiner failure.") } - /// Minimise the `BTFN` within stated value `tolerance`. - /// - /// At most `max_steps` refinement steps are taken. - /// Returns the approximate minimiser and the corresponding function value. - pub fn minimise(&mut self, tolerance : F, max_steps : usize) -> (Loc, F) { - let refiner = P2Refiner{ tolerance, max_steps, how : RefineMin, bound : None }; - self.bt.search_and_refine(refiner, &self.generator).expect("Refiner failure.").unwrap() + fn minimise(&mut self, tolerance: F, max_steps: usize) -> (Loc, F) { + let refiner = P2Refiner { tolerance, max_steps, how: RefineMin, bound: None }; + self.bt + .search_and_refine(refiner, &self.generator) + .expect("Refiner failure.") + .unwrap() } - /// Minimise the `BTFN` within stated value `tolerance` subject to a lower bound. - /// - /// At most `max_steps` refinement steps are taken. - /// Returns the approximate minimiser and the corresponding function value when one is found - /// above the `bound` threshold, otherwise `None`. - pub fn minimise_below(&mut self, bound : F, tolerance : F, max_steps : usize) - -> Option<(Loc, F)> { - let refiner = P2Refiner{ tolerance, max_steps, how : RefineMin, bound : Some(bound) }; - self.bt.search_and_refine(refiner, &self.generator).expect("Refiner failure.") + fn minimise_below( + &mut self, + bound: F, + tolerance: F, + max_steps: usize, + ) -> Option<(Loc, F)> { + let refiner = P2Refiner { tolerance, max_steps, how: RefineMin, bound: Some(bound) }; + self.bt + .search_and_refine(refiner, &self.generator) + .expect("Refiner failure.") } - /// Verify that the `BTFN` has a given upper `bound` within indicated `tolerance`. - /// - /// At most `max_steps` refinement steps are taken. - pub fn has_upper_bound(&mut self, bound : F, tolerance : F, max_steps : usize) -> bool { - let refiner = BoundRefiner{ bound, tolerance, max_steps, how : RefineMax }; - self.bt.search_and_refine(refiner, &self.generator).expect("Refiner failure.") + fn has_upper_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool { + let refiner = BoundRefiner { bound, tolerance, max_steps, how: RefineMax }; + self.bt + .search_and_refine(refiner, &self.generator) + .expect("Refiner failure.") } - /// Verify that the `BTFN` has a given lower `bound` within indicated `tolerance`. - /// - /// At most `max_steps` refinement steps are taken. - pub fn has_lower_bound(&mut self, bound : F, tolerance : F, max_steps : usize) -> bool { - let refiner = BoundRefiner{ bound, tolerance, max_steps, how : RefineMin }; - self.bt.search_and_refine(refiner, &self.generator).expect("Refiner failure.") + fn has_lower_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool { + let refiner = BoundRefiner { bound, tolerance, max_steps, how: RefineMin }; + self.bt + .search_and_refine(refiner, &self.generator) + .expect("Refiner failure.") } } diff -r 1f19c6bbf07b -r 3868555d135c src/bisection_tree/either.rs --- a/src/bisection_tree/either.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/bisection_tree/either.rs Fri May 15 14:46:30 2026 -0500 @@ -1,36 +1,28 @@ - use std::iter::Chain; use std::sync::Arc; -use crate::types::*; +use crate::iter::{MapF, MapZ, Mappable}; +use crate::loc::Loc; use crate::mapping::{ - Instance, - Mapping, - DifferentiableImpl, - DifferentiableMapping, - Space, + ClosedSpace, DifferentiableImpl, DifferentiableMapping, Instance, Mapping, Space, }; -use crate::iter::{Mappable, MapF, MapZ}; use crate::sets::Cube; -use crate::loc::Loc; +use crate::types::*; +use super::aggregator::*; use super::support::*; -use super::aggregator::*; /// A structure for storing two [`SupportGenerator`]s summed/chain together. /// /// This is needed to work with sums of different types of [`Support`]s. -#[derive(Debug,Clone)] -pub struct BothGenerators( - pub(super) Arc, - pub(super) Arc, -); +#[derive(Debug, Clone)] +pub struct BothGenerators(pub(super) Arc, pub(super) Arc); /// A structure for a [`Support`] that can be either `A` or `B`. /// /// This is needed to work with sums of different types of [`Support`]s. -#[derive(Debug,Clone)] -pub enum EitherSupport { +#[derive(Debug, Clone)] +pub enum EitherSupport { Left(A), Right(B), } @@ -38,46 +30,55 @@ // We need type alias bounds to access associate types. #[allow(type_alias_bounds)] type BothAllDataIter< - 'a, F, - G1 : SupportGenerator, - G2 : SupportGenerator, - const N : usize + 'a, + F, + G1: SupportGenerator, + G2: SupportGenerator, + const N: usize, > = Chain< - MapF, (usize, EitherSupport)>, - MapZ, usize, (usize, EitherSupport)>, + MapF, (usize, EitherSupport)>, + MapZ, usize, (usize, EitherSupport)>, >; impl BothGenerators { /// Helper for [`all_left_data`]. #[inline] - fn map_left((d, support) : (G1::Id, G1::SupportType)) - -> (usize, EitherSupport) - where G1 : SupportGenerator, - G2 : SupportGenerator { - - let id : usize = d.into(); + fn map_left( + (d, support): (G1::Id, G1::SupportType), + ) -> (usize, EitherSupport) + where + G1: SupportGenerator, + G2: SupportGenerator, + { + let id: usize = d.into(); (id.into(), EitherSupport::Left(support)) } /// Helper for [`all_right_data`]. #[inline] - fn map_right(n0 : &usize, (d, support) : (G2::Id, G2::SupportType)) - -> (usize, EitherSupport) - where G1 : SupportGenerator, - G2 : SupportGenerator { - - let id : usize = d.into(); - ((n0+id).into(), EitherSupport::Right(support)) + fn map_right( + n0: &usize, + (d, support): (G2::Id, G2::SupportType), + ) -> (usize, EitherSupport) + where + G1: SupportGenerator, + G2: SupportGenerator, + { + let id: usize = d.into(); + ((n0 + id).into(), EitherSupport::Right(support)) } /// Calls [`SupportGenerator::all_data`] on the “left” support generator. /// /// Converts both the id and the [`Support`] into a form that corresponds to `BothGenerators`. #[inline] - pub(super) fn all_left_data(&self) - -> MapF, (usize, EitherSupport)> - where G1 : SupportGenerator, - G2 : SupportGenerator { + pub(super) fn all_left_data( + &self, + ) -> MapF, (usize, EitherSupport)> + where + G1: SupportGenerator, + G2: SupportGenerator, + { self.0.all_data().mapF(Self::map_left) } @@ -85,33 +86,38 @@ /// /// Converts both the id and the [`Support`] into a form that corresponds to `BothGenerators`. #[inline] - pub(super) fn all_right_data(&self) - -> MapZ, usize, (usize, EitherSupport)> - where G1 : SupportGenerator, - G2 : SupportGenerator { + pub(super) fn all_right_data( + &self, + ) -> MapZ, usize, (usize, EitherSupport)> + where + G1: SupportGenerator, + G2: SupportGenerator, + { let n0 = self.0.support_count(); self.1.all_data().mapZ(n0, Self::map_right) } } -impl -SupportGenerator -for BothGenerators -where G1 : SupportGenerator, - G2 : SupportGenerator { - +impl SupportGenerator for BothGenerators +where + G1: SupportGenerator, + G2: SupportGenerator, +{ type Id = usize; - type SupportType = EitherSupport; - type AllDataIter<'a> = BothAllDataIter<'a, F, G1, G2, N> where G1 : 'a, G2 : 'a; + type SupportType = EitherSupport; + type AllDataIter<'a> + = BothAllDataIter<'a, F, G1, G2, N> + where + G1: 'a, + G2: 'a; #[inline] - fn support_for(&self, id : Self::Id) - -> Self::SupportType { + fn support_for(&self, id: Self::Id) -> Self::SupportType { let n0 = self.0.support_count(); if id < n0 { EitherSupport::Left(self.0.support_for(id.into())) } else { - EitherSupport::Right(self.1.support_for((id-n0).into())) + EitherSupport::Right(self.1.support_for((id - n0).into())) } } @@ -126,12 +132,13 @@ } } -impl Support for EitherSupport -where S1 : Support, - S2 : Support { - +impl Support for EitherSupport +where + S1: Support, + S2: Support, +{ #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { match self { EitherSupport::Left(ref a) => a.support_hint(), EitherSupport::Right(ref b) => b.support_hint(), @@ -139,7 +146,7 @@ } #[inline] - fn in_support(&self, x : &Loc) -> bool { + fn in_support(&self, x: &Loc) -> bool { match self { EitherSupport::Left(ref a) => a.in_support(x), EitherSupport::Right(ref b) => b.in_support(x), @@ -147,7 +154,7 @@ } #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { match self { EitherSupport::Left(ref a) => a.bisection_hint(cube), EitherSupport::Right(ref b) => b.bisection_hint(cube), @@ -155,13 +162,14 @@ } } -impl LocalAnalysis for EitherSupport -where A : Aggregator, - S1 : LocalAnalysis, - S2 : LocalAnalysis, { - +impl LocalAnalysis for EitherSupport +where + A: Aggregator, + S1: LocalAnalysis, + S2: LocalAnalysis, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> A { + fn local_analysis(&self, cube: &Cube) -> A { match self { EitherSupport::Left(ref a) => a.local_analysis(cube), EitherSupport::Right(ref b) => b.local_analysis(cube), @@ -169,11 +177,12 @@ } } -impl GlobalAnalysis for EitherSupport -where A : Aggregator, - S1 : GlobalAnalysis, - S2 : GlobalAnalysis, { - +impl GlobalAnalysis for EitherSupport +where + A: Aggregator, + S1: GlobalAnalysis, + S2: GlobalAnalysis, +{ #[inline] fn global_analysis(&self) -> A { match self { @@ -183,17 +192,17 @@ } } -impl Mapping for EitherSupport +impl Mapping for EitherSupport where - F : Space, - X : Space, - S1 : Mapping, - S2 : Mapping, + F: ClosedSpace, + X: Space, + S1: Mapping, + S2: Mapping, { type Codomain = F; #[inline] - fn apply>(&self, x : I) -> F { + fn apply>(&self, x: I) -> F { match self { EitherSupport::Left(ref a) => a.apply(x), EitherSupport::Right(ref b) => b.apply(x), @@ -201,17 +210,17 @@ } } -impl DifferentiableImpl for EitherSupport +impl DifferentiableImpl for EitherSupport where - O : Space, - X : Space, - S1 : DifferentiableMapping, - S2 : DifferentiableMapping, + O: ClosedSpace, + X: ClosedSpace, + S1: DifferentiableMapping, + S2: DifferentiableMapping, { type Derivative = O; #[inline] - fn differential_impl>(&self, x : I) -> O { + fn differential_impl>(&self, x: I) -> O { match self { EitherSupport::Left(ref a) => a.differential(x), EitherSupport::Right(ref b) => b.differential(x), @@ -221,44 +230,47 @@ macro_rules! make_either_scalarop_rhs { ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => { - impl - std::ops::$trait_assign - for BothGenerators - where G1 : std::ops::$trait_assign + Clone, - G2 : std::ops::$trait_assign + Clone, { + impl std::ops::$trait_assign for BothGenerators + where + G1: std::ops::$trait_assign + Clone, + G2: std::ops::$trait_assign + Clone, + { #[inline] - fn $fn_assign(&mut self, t : F) { + fn $fn_assign(&mut self, t: F) { Arc::make_mut(&mut self.0).$fn_assign(t); Arc::make_mut(&mut self.1).$fn_assign(t); } } - impl<'a, F : Float, G1, G2> - std::ops::$trait - for &'a BothGenerators - where &'a G1 : std::ops::$trait, - &'a G2 : std::ops::$trait { + impl<'a, F: Float, G1, G2> std::ops::$trait for &'a BothGenerators + where + &'a G1: std::ops::$trait, + &'a G2: std::ops::$trait, + { type Output = BothGenerators; #[inline] - fn $fn(self, t : F) -> BothGenerators { - BothGenerators(Arc::new(self.0.$fn(t)), - Arc::new(self.1.$fn(t))) + fn $fn(self, t: F) -> BothGenerators { + BothGenerators(Arc::new(self.0.$fn(t)), Arc::new(self.1.$fn(t))) } } - } + }; } make_either_scalarop_rhs!(Mul, mul, MulAssign, mul_assign); make_either_scalarop_rhs!(Div, div, DivAssign, div_assign); impl std::ops::Neg for BothGenerators -where G1 : std::ops::Neg + Clone, - G2 : std::ops::Neg + Clone, { +where + G1: std::ops::Neg + Clone, + G2: std::ops::Neg + Clone, +{ type Output = BothGenerators; #[inline] fn neg(self) -> Self::Output { - BothGenerators(Arc::new(Arc::unwrap_or_clone(self.0).neg()), - Arc::new(Arc::unwrap_or_clone(self.1).neg())) + BothGenerators( + Arc::new(Arc::unwrap_or_clone(self.0).neg()), + Arc::new(Arc::unwrap_or_clone(self.1).neg()), + ) } } /* @@ -270,4 +282,4 @@ BothGenerators(self.0.neg(), self.1.neg()) } } -*/ \ No newline at end of file +*/ diff -r 1f19c6bbf07b -r 3868555d135c src/bisection_tree/refine.rs --- a/src/bisection_tree/refine.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/bisection_tree/refine.rs Fri May 15 14:46:30 2026 -0500 @@ -1,32 +1,31 @@ - -use std::collections::BinaryHeap; -use std::cmp::{PartialOrd, Ord, Ordering, Ordering::*, max}; -use std::marker::PhantomData; -use std::sync::{Arc, Mutex, MutexGuard, Condvar}; -use crate::types::*; +use super::aggregator::*; +use super::bt::*; +use super::support::*; use crate::nanleast::NaNLeast; +use crate::parallelism::TaskBudget; +use crate::parallelism::{thread_pool, thread_pool_size}; use crate::sets::Cube; -use crate::parallelism::{thread_pool_size, thread_pool}; -use super::support::*; -use super::bt::*; -use super::aggregator::*; -use crate::parallelism::TaskBudget; +use crate::types::*; +use std::cmp::{max, Ord, Ordering, Ordering::*, PartialOrd}; +use std::collections::BinaryHeap; +use std::marker::PhantomData; +use std::sync::{Arc, Condvar, Mutex, MutexGuard}; /// Trait for sorting [`Aggregator`]s for [`BT`] refinement. /// /// The sorting involves two sorting keys, the “upper” and the “lower” key. Any [`BT`] nodes /// with upper key less the lower key of another are discarded from the refinement process. /// Nodes with the highest upper sorting key are picked for refinement. -pub trait AggregatorSorting : Sync + Send + 'static { +pub trait AggregatorSorting: Sync + Send + 'static { // Priority - type Agg : Aggregator; - type Sort : Ord + Copy + std::fmt::Debug + Sync + Send; + type Agg: Aggregator; + type Sort: Ord + Copy + std::fmt::Debug + Sync + Send; /// Returns lower sorting key - fn sort_lower(aggregator : &Self::Agg) -> Self::Sort; + fn sort_lower(aggregator: &Self::Agg) -> Self::Sort; /// Returns upper sorting key - fn sort_upper(aggregator : &Self::Agg) -> Self::Sort; + fn sort_upper(aggregator: &Self::Agg) -> Self::Sort; /// Returns a sorting key that is less than any other sorting key. fn bottom() -> Self::Sort; @@ -35,53 +34,64 @@ /// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the upper/lower key. /// /// See [`LowerBoundSorting`] for the opposite ordering. -pub struct UpperBoundSorting(PhantomData); +pub struct UpperBoundSorting(PhantomData); /// An [`AggregatorSorting`] for [`Bounds`], using the upper/lower bound as the lower/upper key. /// /// See [`UpperBoundSorting`] for the opposite ordering. -pub struct LowerBoundSorting(PhantomData); +pub struct LowerBoundSorting(PhantomData); -impl AggregatorSorting for UpperBoundSorting { +impl AggregatorSorting for UpperBoundSorting { type Agg = Bounds; type Sort = NaNLeast; #[inline] - fn sort_lower(aggregator : &Bounds) -> Self::Sort { NaNLeast(aggregator.lower()) } - - #[inline] - fn sort_upper(aggregator : &Bounds) -> Self::Sort { NaNLeast(aggregator.upper()) } + fn sort_lower(aggregator: &Bounds) -> Self::Sort { + NaNLeast(aggregator.lower()) + } #[inline] - fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) } + fn sort_upper(aggregator: &Bounds) -> Self::Sort { + NaNLeast(aggregator.upper()) + } + + #[inline] + fn bottom() -> Self::Sort { + NaNLeast(F::NEG_INFINITY) + } } - -impl AggregatorSorting for LowerBoundSorting { +impl AggregatorSorting for LowerBoundSorting { type Agg = Bounds; type Sort = NaNLeast; #[inline] - fn sort_upper(aggregator : &Bounds) -> Self::Sort { NaNLeast(-aggregator.lower()) } + fn sort_upper(aggregator: &Bounds) -> Self::Sort { + NaNLeast(-aggregator.lower()) + } #[inline] - fn sort_lower(aggregator : &Bounds) -> Self::Sort { NaNLeast(-aggregator.upper()) } + fn sort_lower(aggregator: &Bounds) -> Self::Sort { + NaNLeast(-aggregator.upper()) + } #[inline] - fn bottom() -> Self::Sort { NaNLeast(F::NEG_INFINITY) } + fn bottom() -> Self::Sort { + NaNLeast(F::NEG_INFINITY) + } } /// Return type of [`Refiner::refine`]. /// /// The parameter `R` is the result type of the refiner acting on an [`Aggregator`] of type `A`. -pub enum RefinerResult { +pub enum RefinerResult { /// Indicates an insufficiently refined state: the [`BT`] needs to be further refined. NeedRefinement, /// Indicates a certain result `R`, stop refinement immediately. Certain(R), /// Indicates an uncertain result: continue refinement until candidates have been exhausted /// or a certain result found. - Uncertain(A, R) + Uncertain(A, R), } use RefinerResult::*; @@ -92,16 +102,17 @@ /// The `Refiner` is used to determine whether an [`Aggregator`] `A` stored in the [`BT`] is /// sufficiently refined within a [`Cube`], and in such a case, produce a desired result (e.g. /// a maximum value of a function). -pub trait Refiner : Sync + Send + 'static -where F : Num, - A : Aggregator, - G : SupportGenerator { - +pub trait Refiner: Sync + Send + 'static +where + F: Num, + A: Aggregator, + G: SupportGenerator, +{ /// The result type of the refiner - type Result : std::fmt::Debug + Sync + Send + 'static; + type Result: std::fmt::Debug + Sync + Send + 'static; /// The sorting to be employed by [`BTSearch::search_and_refine`] on node aggregators /// to detemrine node priority. - type Sorting : AggregatorSorting; + type Sorting: AggregatorSorting; /// Determines whether `aggregator` is sufficiently refined within `domain`. /// @@ -124,42 +135,45 @@ /// number of steps is reached. fn refine( &self, - aggregator : &A, - domain : &Cube, - data : &[G::Id], - generator : &G, - step : usize, + aggregator: &A, + domain: &Cube, + data: &[G::Id], + generator: &G, + step: usize, ) -> RefinerResult; /// Fuse two [`Self::Result`]s (needed in threaded refinement). - fn fuse_results(r1 : &mut Self::Result, r2 : Self::Result); + fn fuse_results(r1: &mut Self::Result, r2: Self::Result); } /// Structure for tracking the refinement process in a [`BinaryHeap`]. -struct RefinementInfo<'a, F, D, A, S, RResult, const N : usize, const P : usize> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting { +struct RefinementInfo<'a, F, D, A, S, RResult, const N: usize, const P: usize> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting, +{ /// Domain of `node` - cube : Cube, + cube: Cube, /// Node to be refined - node : &'a mut Node, + node: &'a mut Node, /// Result and improve aggregator for the [`Refiner`] - refiner_info : Option<(A, RResult)>, + refiner_info: Option<(A, RResult)>, /// For [`AggregatorSorting`] being used for the type system - sorting : PhantomData, + sorting: PhantomData, } -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> -RefinementInfo<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting { - +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> + RefinementInfo<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting, +{ #[inline] - fn with_aggregator(&self, f : impl FnOnce(&A) -> U) -> U { + fn with_aggregator(&self, f: impl FnOnce(&A) -> U) -> U { match self.refiner_info { Some((ref agg, _)) => f(agg), None => f(&self.node.aggregator), @@ -177,93 +191,105 @@ } } -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialEq -for RefinementInfo<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting { - +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> PartialEq + for RefinementInfo<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting, +{ #[inline] - fn eq(&self, other : &Self) -> bool { self.cmp(other) == Equal } -} - -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> PartialOrd -for RefinementInfo<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting { - - #[inline] - fn partial_cmp(&self, other : &Self) -> Option { Some(self.cmp(other)) } + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Equal + } } -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Eq -for RefinementInfo<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting { +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> PartialOrd + for RefinementInfo<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting, +{ + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } } -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> Ord -for RefinementInfo<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static, - A : Aggregator, - S : AggregatorSorting { +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> Eq + for RefinementInfo<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting, +{ +} +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> Ord + for RefinementInfo<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static, + A: Aggregator, + S: AggregatorSorting, +{ #[inline] - fn cmp(&self, other : &Self) -> Ordering { - self.with_aggregator(|agg1| other.with_aggregator(|agg2| { - match S::sort_upper(agg1).cmp(&S::sort_upper(agg2)) { + fn cmp(&self, other: &Self) -> Ordering { + self.with_aggregator(|agg1| { + other.with_aggregator(|agg2| match S::sort_upper(agg1).cmp(&S::sort_upper(agg2)) { Equal => S::sort_lower(agg1).cmp(&S::sort_lower(agg2)), order => order, - } - })) + }) + }) } } /// This is a container for a [`BinaryHeap`] of [`RefinementInfo`]s together with tracking of /// the greatest lower bound of the [`Aggregator`]s of the [`Node`]s therein accroding to /// chosen [`AggregatorSorting`]. -struct HeapContainer<'a, F, D, A, S, RResult, const N : usize, const P : usize> -where F : Float, - D : 'static + Copy, - Const

: BranchCount, - A : Aggregator, - S : AggregatorSorting { +struct HeapContainer<'a, F, D, A, S, RResult, const N: usize, const P: usize> +where + F: Float, + D: 'static + Copy, + Const

: BranchCount, + A: Aggregator, + S: AggregatorSorting, +{ /// Priority queue of nodes to be refined - heap : BinaryHeap>, + heap: BinaryHeap>, /// Maximum of node sorting lower bounds seen in the heap - glb : S::Sort, + glb: S::Sort, /// Number of insertions in the heap since previous prune - insert_counter : usize, + insert_counter: usize, /// If a result has been found by some refinment threat, it is stored here - result : Option, + result: Option, /// Refinement step counter - step : usize, + step: usize, /// Number of threads currently processing (not sleeping) - n_processing : usize, + n_processing: usize, /// Threshold for heap pruning - heap_prune_threshold : usize, + heap_prune_threshold: usize, } -impl<'a, F, D, A, S, RResult, const N : usize, const P : usize> -HeapContainer<'a, F, D, A, S, RResult, N, P> -where F : Float, - D : 'static + Copy, - Const

: BranchCount, - A : Aggregator, - S : AggregatorSorting { - +impl<'a, F, D, A, S, RResult, const N: usize, const P: usize> + HeapContainer<'a, F, D, A, S, RResult, N, P> +where + F: Float, + D: 'static + Copy, + Const

: BranchCount, + A: Aggregator, + S: AggregatorSorting, +{ /// Push `ri` into the [`BinaryHeap`]. Do greatest lower bound maintenance. /// /// Returns a boolean indicating whether the push was actually performed due to glb /// filtering or not. #[inline] - fn push(&mut self, ri : RefinementInfo<'a, F, D, A, S, RResult, N, P>) -> bool { + fn push(&mut self, ri: RefinementInfo<'a, F, D, A, S, RResult, N, P>) -> bool { if ri.sort_upper() >= self.glb { let l = ri.sort_lower(); self.heap.push(ri); @@ -276,37 +302,38 @@ } } -impl -Branches -where Const

: BranchCount, - A : Aggregator, - D : 'static + Copy + Send + Sync { - +impl Branches +where + Const

: BranchCount, + A: Aggregator, + D: 'static + Copy + Send + Sync, +{ /// Stage all subnodes of `self` into the refinement queue `container`. fn stage_refine<'a, S, RResult>( &'a mut self, - domain : Cube, - container : &mut HeapContainer<'a, F, D, A, S, RResult, N, P>, - ) where S : AggregatorSorting { + domain: Cube, + container: &mut HeapContainer<'a, F, D, A, S, RResult, N, P>, + ) where + S: AggregatorSorting, + { // Insert all subnodes into the refinement heap. for (node, cube) in self.nodes_and_cubes_mut(&domain) { container.push(RefinementInfo { cube, node, - refiner_info : None, - sorting : PhantomData, + refiner_info: None, + sorting: PhantomData, }); } } } - -impl -Node -where Const

: BranchCount, - A : Aggregator, - D : 'static + Copy + Send + Sync { - +impl Node +where + Const

: BranchCount, + A: Aggregator, + D: 'static + Copy + Send + Sync, +{ /// If `self` is a leaf node, uses the `refiner` to determine whether further subdivision /// is required to get a sufficiently refined solution for the problem the refiner is used /// to solve. If the refiner returns [`RefinerResult::Certain`] result, it is returned. @@ -316,17 +343,18 @@ /// /// `domain`, as usual, indicates the spatial area corresponding to `self`. fn search_and_refine<'a, 'b, 'c, R, G>( - self : &'a mut Self, - domain : Cube, - refiner : &R, - generator : &G, - container_arc : &'c Arc>>, - step : usize + self: &'a mut Self, + domain: Cube, + refiner: &R, + generator: &G, + container_arc: &'c Arc>>, + step: usize, ) -> Result>> - where R : Refiner, - G : SupportGenerator, - G::SupportType : LocalAnalysis { - + where + R: Refiner, + G: SupportGenerator, + G::SupportType: LocalAnalysis, + { //drop(container); // Refine a leaf. @@ -368,26 +396,27 @@ unsafe { Arc::get_mut_unchecked(arc_b) } .stage_refine(domain, &mut *container); #[cfg(not(nightly))] - Arc::get_mut(arc_b).unwrap() + Arc::get_mut(arc_b) + .unwrap() .stage_refine(domain, &mut *container); - - return Err(container) - }, + + return Err(container); + } _ => unreachable!("This cannot happen"), } } } res - }, + } NodeOption::Branches(ref mut b) => { // Insert branches into refinement priority queue. let mut container = container_arc.lock().unwrap(); Arc::make_mut(b).stage_refine(domain, &mut *container); - return Err(container) - }, + return Err(container); + } NodeOption::Uninitialised => { refiner.refine(&self.aggregator, &domain, &[], generator, step) - }, + } }; match res { @@ -399,17 +428,17 @@ // aggregator. let mut container = container_arc.lock().unwrap(); container.push(RefinementInfo { - cube : domain, - node : self, - refiner_info : Some((agg, val)), - sorting : PhantomData, + cube: domain, + node: self, + refiner_info: Some((agg, val)), + sorting: PhantomData, }); Err(container) - }, + } Certain(val) => { // The refiner gave a certain result so return it to allow early termination Ok(val) - }, + } NeedRefinement => { // This should only happen when we run into NodeOption::Uninitialised above. // There's really nothing to do. @@ -423,9 +452,10 @@ /// /// This can be removed and the methods implemented directly on [`BT`] once Rust's const generics /// are flexible enough to allow fixing `P=pow(2, N)`. -pub trait BTSearch : BTImpl -where F : Float { - +pub trait BTSearch: BTImpl +where + F: Float, +{ /// Perform a search on on `Self`, as determined by `refiner`. /// /// Nodes are inserted in a priority queue and processed in the order determined by the @@ -437,26 +467,28 @@ /// The `generator` converts [`BTImpl::Data`] stored in the bisection tree into a [`Support`]. fn search_and_refine<'b, R, G>( &'b mut self, - refiner : R, - generator : &Arc, + refiner: R, + generator: &Arc, ) -> Option - where R : Refiner + Sync + Send + 'static, - G : SupportGenerator + Sync + Send + 'static, - G::SupportType : LocalAnalysis; + where + R: Refiner + Sync + Send + 'static, + G: SupportGenerator + Sync + Send + 'static, + G::SupportType: LocalAnalysis; } -fn refinement_loop ( - wakeup : Option>, - refiner : &R, - generator_arc : &Arc, - container_arc : &Arc>>, -) where A : Aggregator, - R : Refiner, - G : SupportGenerator, - G::SupportType : LocalAnalysis, - Const

: BranchCount, - D : 'static + Copy + Sync + Send + std::fmt::Debug { - +fn refinement_loop( + wakeup: Option>, + refiner: &R, + generator_arc: &Arc, + container_arc: &Arc>>, +) where + A: Aggregator, + R: Refiner, + G: SupportGenerator, + G::SupportType: LocalAnalysis, + Const

: BranchCount, + D: 'static + Copy + Sync + Send + std::fmt::Debug, +{ let mut did_park = true; let mut container = container_arc.lock().unwrap(); @@ -471,7 +503,7 @@ // Some refinement task/thread has found a result, return if container.result.is_some() { container.n_processing -= 1; - break 'main + break 'main; } match container.heap.pop() { @@ -489,7 +521,7 @@ container = c.wait(container).unwrap(); continue 'get_next; } else { - break 'main + break 'main; } } }; @@ -502,7 +534,7 @@ // Terminate based on a “best possible” result. container.result = Some(result); container.n_processing -= 1; - break 'main + break 'main; } // Do priority queue maintenance @@ -513,10 +545,8 @@ container.glb = glb; // Prune container.heap.retain(|ri| ri.sort_upper() >= glb); - }, - None => { - container.glb = R::Sorting::bottom() } + None => container.glb = R::Sorting::bottom(), } container.insert_counter = 0; } @@ -525,8 +555,14 @@ drop(container); // … and process the node. We may get returned an already unlocked mutex. - match Node::search_and_refine(ri.node, ri.cube, refiner, &**generator_arc, - &container_arc, step) { + match Node::search_and_refine( + ri.node, + ri.cube, + refiner, + &**generator_arc, + &container_arc, + step, + ) { Ok(r) => { let mut container = container_arc.lock().unwrap(); // Terminate based on a certain result from the refiner @@ -534,8 +570,8 @@ Some(ref mut r_prev) => R::fuse_results(r_prev, r), None => container.result = Some(r), } - break 'main - }, + break 'main; + } Err(cnt) => { container = cnt; // Wake up another thread if one is sleeping; there should be now work in the @@ -545,7 +581,6 @@ } } } - } // Make sure no task is sleeping @@ -558,9 +593,9 @@ macro_rules! impl_btsearch { ($($n:literal)*) => { $( impl<'a, M, F, D, A> - BTSearch + BTSearch<$n, F> for BT - where //Self : BTImpl, // <== automatically deduced + where //Self : BTImpl<$n, F, Data=D,Agg=A, Depth=M>, // <== automatically deduced M : Depth, F : Float + Send, A : Aggregator, @@ -571,7 +606,7 @@ generator : &Arc, ) -> Option where R : Refiner, - G : SupportGenerator, + G : SupportGenerator< $n, F, Id=D>, G::SupportType : LocalAnalysis { let mut init_container = HeapContainer { heap : BinaryHeap::new(), @@ -620,4 +655,3 @@ } impl_btsearch!(1 2 3 4); - diff -r 1f19c6bbf07b -r 3868555d135c src/bisection_tree/support.rs --- a/src/bisection_tree/support.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/bisection_tree/support.rs Fri May 15 14:46:30 2026 -0500 @@ -1,31 +1,29 @@ - /*! Traits for representing the support of a [`Mapping`], and analysing the mapping on a [`Cube`]. */ -use serde::Serialize; -use std::ops::{MulAssign,DivAssign,Neg}; -use crate::types::{Float, Num}; +use super::aggregator::Bounds; +pub use crate::bounds::{GlobalAnalysis, LocalAnalysis}; +use crate::loc::Loc; +use crate::mapping::{ClosedSpace, DifferentiableImpl, DifferentiableMapping, Instance, Mapping}; use crate::maputil::map2; -use crate::mapping::{ - Instance, Mapping, DifferentiableImpl, DifferentiableMapping, Space -}; +use crate::norms::{Linfinity, Norm, L1, L2}; +pub use crate::operator_arithmetic::{Constant, Weighted}; use crate::sets::Cube; -use crate::loc::Loc; -use super::aggregator::Bounds; -use crate::norms::{Norm, L1, L2, Linfinity}; -pub use crate::operator_arithmetic::{Weighted, Constant}; +use crate::types::{Float, Num}; +use serde::Serialize; +use std::ops::{DivAssign, MulAssign, Neg}; /// A trait for working with the supports of [`Mapping`]s. /// /// `Mapping` is not a super-trait to allow more general use. -pub trait Support : Sized + Sync + Send + 'static { +pub trait Support: Sized + Sync + Send + 'static { /// Return a cube containing the support of the function represented by `self`. /// /// The hint may be larger than the actual support, but must contain it. - fn support_hint(&self) -> Cube; + fn support_hint(&self) -> Cube; /// Indicate whether `x` is in the support of the function represented by `self`. - fn in_support(&self, x : &Loc) -> bool; + fn in_support(&self, x: &Loc) -> bool; // Indicate whether `cube` is fully in the support of the function represented by `self`. //fn fully_in_support(&self, cube : &Cube) -> bool; @@ -41,139 +39,99 @@ /// The default implementation returns `[None; N]`. #[inline] #[allow(unused_variables)] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { [None; N] } /// Translate `self` by `x`. #[inline] - fn shift(self, x : Loc) -> Shift { - Shift { shift : x, base_fn : self } + fn shift(self, x: Loc) -> Shift { + Shift { shift: x, base_fn: self } } } -/// Trait for globally analysing a property `A` of a [`Mapping`]. -/// -/// Typically `A` is an [`Aggregator`][super::aggregator::Aggregator] such as -/// [`Bounds`][super::aggregator::Bounds]. -pub trait GlobalAnalysis { - /// Perform global analysis of the property `A` of `Self`. - /// - /// As an example, in the case of `A` being [`Bounds`][super::aggregator::Bounds], - /// this function will return global upper and lower bounds for the mapping - /// represented by `self`. - fn global_analysis(&self) -> A; +/// Shift of [`Support`] and [`Mapping`]; output of [`Support::shift`]. +#[derive(Copy, Clone, Debug, Serialize)] // Serialize! but not implemented by Loc. +pub struct Shift { + shift: Loc, + base_fn: T, } -// default impl GlobalAnalysis for L -// where L : LocalAnalysis { -// #[inline] -// fn global_analysis(&self) -> Bounds { -// self.local_analysis(&self.support_hint()) -// } -// } - -/// Trait for locally analysing a property `A` of a [`Mapping`] (implementing [`Support`]) -/// within a [`Cube`]. -/// -/// Typically `A` is an [`Aggregator`][super::aggregator::Aggregator] such as -/// [`Bounds`][super::aggregator::Bounds]. -pub trait LocalAnalysis : GlobalAnalysis + Support { - /// Perform local analysis of the property `A` of `Self`. - /// - /// As an example, in the case of `A` being [`Bounds`][super::aggregator::Bounds], - /// this function will return upper and lower bounds within `cube` for the mapping - /// represented by `self`. - fn local_analysis(&self, cube : &Cube) -> A; -} - -/// Trait for determining the upper and lower bounds of an float-valued [`Mapping`]. -/// -/// This is a blanket-implemented alias for [`GlobalAnalysis`]`>` -/// [`Mapping`] is not a supertrait to allow flexibility in the implementation of either -/// reference or non-reference arguments. -pub trait Bounded : GlobalAnalysis> { - /// Return lower and upper bounds for the values of of `self`. - #[inline] - fn bounds(&self) -> Bounds { - self.global_analysis() - } -} - -impl>> Bounded for T { } - -/// Shift of [`Support`] and [`Mapping`]; output of [`Support::shift`]. -#[derive(Copy,Clone,Debug,Serialize)] // Serialize! but not implemented by Loc. -pub struct Shift { - shift : Loc, - base_fn : T, -} - -impl<'a, T, V : Space, F : Float, const N : usize> Mapping> for Shift -where T : Mapping, Codomain=V> { +impl<'a, T, V: ClosedSpace, F: Float, const N: usize> Mapping> for Shift +where + T: Mapping, Codomain = V>, +{ type Codomain = V; #[inline] - fn apply>>(&self, x : I) -> Self::Codomain { + fn apply>>(&self, x: I) -> Self::Codomain { self.base_fn.apply(x.own() - &self.shift) } } -impl<'a, T, V : Space, F : Float, const N : usize> DifferentiableImpl> for Shift -where T : DifferentiableMapping, DerivativeDomain=V> { +impl<'a, T, V: ClosedSpace, F: Float, const N: usize> DifferentiableImpl> + for Shift +where + T: DifferentiableMapping, DerivativeDomain = V>, +{ type Derivative = V; #[inline] - fn differential_impl>>(&self, x : I) -> Self::Derivative { + fn differential_impl>>(&self, x: I) -> Self::Derivative { self.base_fn.differential(x.own() - &self.shift) } } -impl<'a, T, F : Float, const N : usize> Support for Shift -where T : Support { +impl<'a, T, F: Float, const N: usize> Support for Shift +where + T: Support, +{ #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { self.base_fn.support_hint().shift(&self.shift) } #[inline] - fn in_support(&self, x : &Loc) -> bool { + fn in_support(&self, x: &Loc) -> bool { self.base_fn.in_support(&(x - &self.shift)) } - + // fn fully_in_support(&self, _cube : &Cube) -> bool { // //self.base_fn.fully_in_support(cube.shift(&vectorneg(self.shift))) // todo!("Not implemented, but not used at the moment") // } #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { let base_hint = self.base_fn.bisection_hint(cube); map2(base_hint, &self.shift, |h, s| h.map(|z| z + *s)) } - } -impl<'a, T, F : Float, const N : usize> GlobalAnalysis> for Shift -where T : LocalAnalysis, N> { +impl<'a, T, F: Float, const N: usize> GlobalAnalysis> for Shift +where + T: LocalAnalysis, N>, +{ #[inline] fn global_analysis(&self) -> Bounds { self.base_fn.global_analysis() } } -impl<'a, T, F : Float, const N : usize> LocalAnalysis, N> for Shift -where T : LocalAnalysis, N> { +impl<'a, T, F: Float, const N: usize> LocalAnalysis, N> for Shift +where + T: LocalAnalysis, N>, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { self.base_fn.local_analysis(&cube.shift(&(-self.shift))) } } macro_rules! impl_shift_norm { ($($norm:ident)*) => { $( - impl<'a, T, F : Float, const N : usize> Norm for Shift - where T : Norm { + impl<'a, T, F : Float, const N : usize> Norm<$norm, F> for Shift + where T : Norm<$norm, F> { #[inline] fn norm(&self, n : $norm) -> F { self.base_fn.norm(n) @@ -184,33 +142,36 @@ impl_shift_norm!(L1 L2 Linfinity); -impl<'a, T, F : Float, C, const N : usize> Support for Weighted -where T : Support, - C : Constant { - +impl<'a, T, F: Float, C, const N: usize> Support for Weighted +where + T: Support, + C: Constant, +{ #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { self.base_fn.support_hint() } #[inline] - fn in_support(&self, x : &Loc) -> bool { + fn in_support(&self, x: &Loc) -> bool { self.base_fn.in_support(x) } - + // fn fully_in_support(&self, cube : &Cube) -> bool { // self.base_fn.fully_in_support(cube) // } #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { self.base_fn.bisection_hint(cube) } } -impl<'a, T, F : Float, C> GlobalAnalysis> for Weighted -where T : GlobalAnalysis>, - C : Constant { +impl<'a, T, F: Float, C> GlobalAnalysis> for Weighted +where + T: GlobalAnalysis>, + C: Constant, +{ #[inline] fn global_analysis(&self) -> Bounds { let Bounds(lower, upper) = self.base_fn.global_analysis(); @@ -222,11 +183,13 @@ } } -impl<'a, T, F : Float, C, const N : usize> LocalAnalysis, N> for Weighted -where T : LocalAnalysis, N>, - C : Constant { +impl<'a, T, F: Float, C, const N: usize> LocalAnalysis, N> for Weighted +where + T: LocalAnalysis, N>, + C: Constant, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { let Bounds(lower, upper) = self.base_fn.local_analysis(cube); debug_assert!(lower <= upper); match self.weight.value() { @@ -238,31 +201,33 @@ macro_rules! make_weighted_scalarop_rhs { ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => { - impl std::ops::$trait_assign for Weighted { + impl std::ops::$trait_assign for Weighted { #[inline] - fn $fn_assign(&mut self, t : F) { + fn $fn_assign(&mut self, t: F) { self.weight.$fn_assign(t); } } - impl<'a, F : Float, T> std::ops::$trait for Weighted { + impl<'a, F: Float, T> std::ops::$trait for Weighted { type Output = Self; #[inline] - fn $fn(mut self, t : F) -> Self { + fn $fn(mut self, t: F) -> Self { self.weight.$fn_assign(t); self } } - impl<'a, F : Float, T> std::ops::$trait for &'a Weighted - where T : Clone { + impl<'a, F: Float, T> std::ops::$trait for &'a Weighted + where + T: Clone, + { type Output = Weighted; #[inline] - fn $fn(self, t : F) -> Self::Output { - Weighted { weight : self.weight.$fn(t), base_fn : self.base_fn.clone() } + fn $fn(self, t: F) -> Self::Output { + Weighted { weight: self.weight.$fn(t), base_fn: self.base_fn.clone() } } } - } + }; } make_weighted_scalarop_rhs!(Mul, mul, MulAssign, mul_assign); @@ -270,8 +235,8 @@ macro_rules! impl_weighted_norm { ($($norm:ident)*) => { $( - impl<'a, T, F : Float> Norm for Weighted - where T : Norm { + impl<'a, T, F : Float> Norm<$norm, F> for Weighted + where T : Norm<$norm, F> { #[inline] fn norm(&self, n : $norm) -> F { self.base_fn.norm(n) * self.weight.abs() @@ -282,52 +247,60 @@ impl_weighted_norm!(L1 L2 Linfinity); - /// Normalisation of [`Support`] and [`Mapping`] to L¹ norm 1. /// /// Currently only scalar-valued functions are supported. #[derive(Copy, Clone, Debug, Serialize, PartialEq)] pub struct Normalised( /// The base [`Support`] or [`Mapping`]. - pub T + pub T, ); -impl<'a, T, F : Float, const N : usize> Mapping> for Normalised -where T : Norm + Mapping, Codomain=F> { +impl<'a, T, F: Float, const N: usize> Mapping> for Normalised +where + T: Norm + Mapping, Codomain = F>, +{ type Codomain = F; #[inline] - fn apply>>(&self, x : I) -> Self::Codomain { + fn apply>>(&self, x: I) -> Self::Codomain { let w = self.0.norm(L1); - if w == F::ZERO { F::ZERO } else { self.0.apply(x) / w } + if w == F::ZERO { + F::ZERO + } else { + self.0.apply(x) / w + } } } -impl<'a, T, F : Float, const N : usize> Support for Normalised -where T : Norm + Support { - +impl<'a, T, F: Float, const N: usize> Support for Normalised +where + T: Norm + Support, +{ #[inline] - fn support_hint(&self) -> Cube { + fn support_hint(&self) -> Cube { self.0.support_hint() } #[inline] - fn in_support(&self, x : &Loc) -> bool { + fn in_support(&self, x: &Loc) -> bool { self.0.in_support(x) } - + // fn fully_in_support(&self, cube : &Cube) -> bool { // self.0.fully_in_support(cube) // } #[inline] - fn bisection_hint(&self, cube : &Cube) -> [Option; N] { + fn bisection_hint(&self, cube: &Cube) -> [Option; N] { self.0.bisection_hint(cube) } } -impl<'a, T, F : Float> GlobalAnalysis> for Normalised -where T : Norm + GlobalAnalysis> { +impl<'a, T, F: Float> GlobalAnalysis> for Normalised +where + T: Norm + GlobalAnalysis>, +{ #[inline] fn global_analysis(&self) -> Bounds { let Bounds(lower, upper) = self.0.global_analysis(); @@ -338,10 +311,12 @@ } } -impl<'a, T, F : Float, const N : usize> LocalAnalysis, N> for Normalised -where T : Norm + LocalAnalysis, N> { +impl<'a, T, F: Float, const N: usize> LocalAnalysis, N> for Normalised +where + T: Norm + LocalAnalysis, N>, +{ #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube: &Cube) -> Bounds { let Bounds(lower, upper) = self.0.local_analysis(cube); debug_assert!(lower <= upper); let w = self.0.norm(L1); @@ -350,19 +325,25 @@ } } -impl<'a, T, F : Float> Norm for Normalised -where T : Norm { +impl<'a, T, F: Float> Norm for Normalised +where + T: Norm, +{ #[inline] - fn norm(&self, _ : L1) -> F { + fn norm(&self, _: L1) -> F { let w = self.0.norm(L1); - if w == F::ZERO { F::ZERO } else { F::ONE } + if w == F::ZERO { + F::ZERO + } else { + F::ONE + } } } macro_rules! impl_normalised_norm { ($($norm:ident)*) => { $( - impl<'a, T, F : Float> Norm for Normalised - where T : Norm + Norm { + impl<'a, T, F : Float> Norm<$norm, F> for Normalised + where T : Norm<$norm, F> + Norm { #[inline] fn norm(&self, n : $norm) -> F { let w = self.0.norm(L1); @@ -375,37 +356,39 @@ impl_normalised_norm!(L2 Linfinity); /* -impl, const N : usize> LocalAnalysis for S { - fn local_analysis(&self, _cube : &Cube) -> NullAggregator { NullAggregator } +impl, const N : usize> LocalAnalysis for S { + fn local_analysis(&self, _cube : &Cube) -> NullAggregator { NullAggregator } } impl, const N : usize> LocalAnalysis, N> for S { #[inline] - fn local_analysis(&self, cube : &Cube) -> Bounds { + fn local_analysis(&self, cube : &Cube) -> Bounds { self.bounds(cube) } }*/ /// Generator of [`Support`]-implementing component functions based on low storage requirement /// [ids][`Self::Id`]. -pub trait SupportGenerator -: MulAssign + DivAssign + Neg + Clone + Sync + Send + 'static { +pub trait SupportGenerator: + MulAssign + DivAssign + Neg + Clone + Sync + Send + 'static +{ /// The identification type - type Id : 'static + Copy; + type Id: 'static + Copy; /// The type of the [`Support`] (often also a [`Mapping`]). - type SupportType : 'static + Support; + type SupportType: 'static + Support; /// An iterator over all the [`Support`]s of the generator. - type AllDataIter<'a> : Iterator where Self : 'a; + type AllDataIter<'a>: Iterator + where + Self: 'a; /// Returns the component identified by `id`. /// /// Panics if `id` is an invalid identifier. - fn support_for(&self, id : Self::Id) -> Self::SupportType; - + fn support_for(&self, id: Self::Id) -> Self::SupportType; + /// Returns the number of different components in this generator. fn support_count(&self) -> usize; /// Returns an iterator over all pairs of `(id, support)`. fn all_data(&self) -> Self::AllDataIter<'_>; } - diff -r 1f19c6bbf07b -r 3868555d135c src/bounds.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/bounds.rs Fri May 15 14:46:30 2026 -0500 @@ -0,0 +1,252 @@ +/*! +Bounded and minimizable/maximizable mappings. +*/ + +use crate::instance::{Instance, Space}; +use crate::mapping::Mapping; +use crate::sets::{Cube, Set}; +use crate::types::{Float, Num}; + +/// Trait for globally analysing a property `A` of a [`Mapping`]. +/// +/// Typically `A` is an [`Aggregator`][super::bisection_tree::Aggregator] +/// such as [`Bounds`]. +pub trait GlobalAnalysis { + /// Perform global analysis of the property `A` of `Self`. + /// + /// As an example, in the case of `A` being [`Bounds`], + /// this function will return global upper and lower bounds for the mapping + /// represented by `self`. + fn global_analysis(&self) -> A; +} + +// default impl GlobalAnalysis for L +// where L : LocalAnalysis { +// #[inline] +// fn global_analysis(&self) -> Bounds { +// self.local_analysis(&self.support_hint()) +// } +// } + +/// Trait for locally analysing a property `A` of a [`Mapping`] (implementing [`super::bisection_tree::Support`]) +/// within a [`Cube`]. +/// +/// Typically `A` is an [`Aggregator`][super::bisection_tree::Aggregator] such as [`Bounds`]. +pub trait LocalAnalysis: GlobalAnalysis { + /// Perform local analysis of the property `A` of `Self`. + /// + /// As an example, in the case of `A` being [`Bounds`], + /// this function will return upper and lower bounds within `cube` for the mapping + /// represented by `self`. + fn local_analysis(&self, cube: &Cube) -> A; +} + +/// Trait for determining the upper and lower bounds of an float-valued [`Mapping`]. +/// +/// This is a blanket-implemented alias for [`GlobalAnalysis`]`>` +/// [`Mapping`] is not a supertrait to allow flexibility in the implementation of either +/// reference or non-reference arguments. +pub trait Bounded: GlobalAnalysis> { + /// Return lower and upper bounds for the values of of `self`. + #[inline] + fn bounds(&self) -> Bounds { + self.global_analysis() + } +} + +impl>> Bounded for T {} + +/// A real-valued [`Mapping`] that provides rough bounds as well as minimisation and maximisation. +pub trait MinMaxMapping: + Mapping + Bounded +{ + /// Maximise the mapping within stated value `tolerance`. + /// + /// At most `max_steps` refinement steps are taken. + /// Returns the approximate maximiser and the corresponding function value. + fn maximise(&mut self, tolerance: F, max_steps: usize) -> (Domain, F); + + /// Maximise the mapping within stated value `tolerance` subject to a lower bound. + /// + /// At most `max_steps` refinement steps are taken. + /// Returns the approximate maximiser and the corresponding function value when one is found + /// above the `bound` threshold, otherwise `None`. + fn maximise_above(&mut self, bound: F, tolerance: F, max_steps: usize) -> Option<(Domain, F)> { + let res @ (_, v) = self.maximise(tolerance, max_steps); + (v > bound).then_some(res) + } + + /// Minimise the mapping within stated value `tolerance`. + /// + /// At most `max_steps` refinement steps are taken. + /// Returns the approximate minimiser and the corresponding function value. + fn minimise(&mut self, tolerance: F, max_steps: usize) -> (Domain, F); + + /// Minimise the mapping within stated value `tolerance` subject to a lower bound. + /// + /// At most `max_steps` refinement steps are taken. + /// Returns the approximate minimiser and the corresponding function value when one is found + /// below the `bound` threshold, otherwise `None`. + fn minimise_below(&mut self, bound: F, tolerance: F, max_steps: usize) -> Option<(Domain, F)> { + let res @ (_, v) = self.minimise(tolerance, max_steps); + (v < bound).then_some(res) + } + + /// Verify that the mapping has a given upper `bound` within indicated `tolerance`. + /// + /// At most `max_steps` refinement steps are taken. + fn has_upper_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool { + match self.maximise_above(bound, tolerance, max_steps) { + None => true, + Some((_, v)) => v <= bound, + } + } + + /// Verify that the mapping has a given lower `bound` within indicated `tolerance`. + /// + /// At most `max_steps` refinement steps are taken. + fn has_lower_bound(&mut self, bound: F, tolerance: F, max_steps: usize) -> bool { + match self.minimise_below(bound, tolerance, max_steps) { + None => true, + Some((_, v)) => v >= bound, + } + } +} + +/// Upper and lower bounds on an `F`-valued function. +#[derive(Copy, Clone, Debug)] +pub struct Bounds( + /// Lower bound + pub F, + /// Upper bound + pub F, +); + +impl Bounds { + /// Returns the lower bound + #[inline] + pub fn lower(&self) -> F { + self.0 + } + + /// Returns the upper bound + #[inline] + pub fn upper(&self) -> F { + self.1 + } +} + +impl Bounds { + /// Returns a uniform bound. + /// + /// This is maximum over the absolute values of the upper and lower bound. + #[inline] + pub fn uniform(&self) -> F { + let &Bounds(lower, upper) = self; + lower.abs().max(upper.abs()) + } + + /// Construct a bounds, making sure `lower` bound is less than `upper` + #[inline] + pub fn corrected(lower: F, upper: F) -> Self { + if lower <= upper { + Bounds(lower, upper) + } else { + Bounds(upper, lower) + } + } + + /// Refine the lower bound + #[inline] + pub fn refine_lower(&self, lower: F) -> Self { + let &Bounds(l, u) = self; + debug_assert!(l <= u); + Bounds(l.max(lower), u.max(lower)) + } + + /// Refine the lower bound + #[inline] + pub fn refine_upper(&self, upper: F) -> Self { + let &Bounds(l, u) = self; + debug_assert!(l <= u); + Bounds(l.min(upper), u.min(upper)) + } +} + +impl<'a, F: Float> std::ops::Add for Bounds { + type Output = Self; + #[inline] + fn add(self, Bounds(l2, u2): Self) -> Self::Output { + let Bounds(l1, u1) = self; + debug_assert!(l1 <= u1 && l2 <= u2); + Bounds(l1 + l2, u1 + u2) + } +} + +impl<'a, F: Float> std::ops::Mul for Bounds { + type Output = Self; + #[inline] + fn mul(self, Bounds(l2, u2): Self) -> Self::Output { + let Bounds(l1, u1) = self; + debug_assert!(l1 <= u1 && l2 <= u2); + let a = l1 * l2; + let b = u1 * u2; + // The order may flip when negative numbers are involved, so need min/max + Bounds(a.min(b), a.max(b)) + } +} + +impl std::iter::Product for Bounds { + #[inline] + fn product(mut iter: I) -> Self + where + I: Iterator, + { + match iter.next() { + None => Bounds(F::ZERO, F::ZERO), + Some(init) => iter.fold(init, |a, b| a * b), + } + } +} + +impl Set for Bounds { + fn contains>(&self, item: I) -> bool { + let v = item.own(); + let &Bounds(l, u) = self; + debug_assert!(l <= u); + l <= v && v <= u + } +} + +impl Bounds { + /// Calculate a common bound (glb, lub) for two bounds. + #[inline] + pub fn common(&self, &Bounds(l2, u2): &Self) -> Self { + let &Bounds(l1, u1) = self; + debug_assert!(l1 <= u1 && l2 <= u2); + Bounds(l1.min(l2), u1.max(u2)) + } + + /// Indicates whether `Self` is a superset of the argument bound. + #[inline] + pub fn superset(&self, &Bounds(l2, u2): &Self) -> bool { + let &Bounds(l1, u1) = self; + debug_assert!(l1 <= u1 && l2 <= u2); + l1 <= l2 && u2 <= u1 + } + + /// Returns the greatest bound contained by both argument bounds, if one exists. + #[inline] + pub fn glb(&self, &Bounds(l2, u2): &Self) -> Option { + let &Bounds(l1, u1) = self; + debug_assert!(l1 <= u1 && l2 <= u2); + let l = l1.max(l2); + let u = u1.min(u2); + debug_assert!(l <= u); + if l < u { + Some(Bounds(l, u)) + } else { + None + } + } +} diff -r 1f19c6bbf07b -r 3868555d135c src/collection.rs --- a/src/collection.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/collection.rs Fri May 15 14:46:30 2026 -0500 @@ -5,20 +5,24 @@ use crate::loc::Loc; /// An abstract collection of elements. -pub trait Collection : IntoIterator { +pub trait Collection: IntoIterator { /// Type of elements of the collection type Element; /// Iterator over references to elements of the collection - type RefsIter<'a> : Iterator where Self : 'a; + type RefsIter<'a>: Iterator + where + Self: 'a; /// Returns an iterator over references to elements of the collection. fn iter_refs(&self) -> Self::RefsIter<'_>; } /// An abstract collection of mutable elements. -pub trait CollectionMut : Collection { +pub trait CollectionMut: Collection { /// Iterator over references to elements of the collection - type RefsIterMut<'a> : Iterator where Self : 'a; + type RefsIterMut<'a>: Iterator + where + Self: 'a; /// Returns an iterator over references to elements of the collection. fn iter_refs_mut(&mut self) -> Self::RefsIterMut<'_>; @@ -51,4 +55,4 @@ slice_like_collection!(Vec where E); slice_like_collection!([E; N] where E, const N : usize); -slice_like_collection!(Loc where E, const N : usize); +slice_like_collection!(Loc where E, const N : usize); diff -r 1f19c6bbf07b -r 3868555d135c src/convex.rs --- a/src/convex.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/convex.rs Fri May 15 14:46:30 2026 -0500 @@ -2,27 +2,36 @@ Some convex analysis basics */ -use std::marker::PhantomData; +use crate::error::DynResult; +use crate::euclidean::Euclidean; +use crate::instance::{ClosedSpace, DecompositionMut, Instance}; +use crate::linops::{IdOp, Scaled, SimpleZeroOp, AXPY}; +use crate::mapping::{DifferentiableImpl, LipschitzDifferentiableImpl, Mapping, Space}; +use crate::norms::*; +use crate::operator_arithmetic::{Constant, Weighted}; use crate::types::*; -use crate::mapping::{Mapping, Space}; -use crate::linops::IdOp; -use crate::instance::{Instance, InstanceMut, DecompositionMut}; -use crate::operator_arithmetic::{Constant, Weighted}; -use crate::norms::*; +use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; /// Trait for convex mappings. Has no features, just serves as a constraint /// /// TODO: should constrain `Mapping::Codomain` to implement a partial order, /// but this makes everything complicated with little benefit. -pub trait ConvexMapping : Mapping -{} +pub trait ConvexMapping: Mapping { + /// Returns (a lower estimate of) the factor of strong convexity in the norm of `Domain`. + fn factor_of_strong_convexity(&self) -> F { + F::ZERO + } +} /// Trait for mappings with a Fenchel conjugate /// /// The conjugate type has to implement [`ConvexMapping`], but a `Conjugable` mapping need /// not be convex. -pub trait Conjugable, F : Num = f64> : Mapping { - type Conjugate<'a> : ConvexMapping where Self : 'a; +pub trait Conjugable, F: Num = f64>: Mapping { + type Conjugate<'a>: ConvexMapping + where + Self: 'a; fn conjugate(&self) -> Self::Conjugate<'_>; } @@ -31,12 +40,14 @@ /// /// In contrast to [`Conjugable`], the preconjugate need not implement [`ConvexMapping`], /// but a `Preconjugable` mapping has to be convex. -pub trait Preconjugable : ConvexMapping +pub trait Preconjugable: ConvexMapping where - Domain : Space, - Predual : HasDual + Domain: Space, + Predual: HasDual, { - type Preconjugate<'a> : Mapping where Self : 'a; + type Preconjugate<'a>: Mapping + where + Self: 'a; fn preconjugate(&self) -> Self::Preconjugate<'_>; } @@ -45,53 +56,55 @@ /// /// The conjugate type has to implement [`ConvexMapping`], but a `Conjugable` mapping need /// not be convex. -pub trait Prox : Mapping { - type Prox<'a> : Mapping where Self : 'a; +pub trait Prox: Mapping { + type Prox<'a>: Mapping + where + Self: 'a; /// Returns a proximal mapping with weight τ - fn prox_mapping(&self, τ : Self::Codomain) -> Self::Prox<'_>; + fn prox_mapping(&self, τ: Self::Codomain) -> Self::Prox<'_>; /// Calculate the proximal mapping with weight τ - fn prox>(&self, τ : Self::Codomain, z : I) -> Domain { + fn prox>(&self, τ: Self::Codomain, z: I) -> Domain::Principal { self.prox_mapping(τ).apply(z) } /// Calculate the proximal mapping with weight τ in-place - fn prox_mut<'b>(&self, τ : Self::Codomain, y : &'b mut Domain) + fn prox_mut<'b>(&self, τ: Self::Codomain, y: &'b mut Domain::Principal) where - &'b mut Domain : InstanceMut, - Domain:: Decomp : DecompositionMut, - for<'a> &'a Domain : Instance, + Domain::Decomp: DecompositionMut, + for<'a> &'a Domain::Principal: Instance, { *y = self.prox(τ, &*y); } } - /// Constraint to the unit ball of the norm described by `E`. -pub struct NormConstraint { - radius : F, - norm : NormMapping, +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +pub struct NormConstraint { + radius: F, + norm: NormMapping, } impl ConvexMapping for NormMapping where - Domain : Space, - E : NormExponent, - F : Float, - Self : Mapping -{} - + Domain: Space, + E: NormExponent, + F: Float, + Self: Mapping, +{ +} impl Mapping for NormConstraint where - Domain : Space + Norm, - F : Float, - E : NormExponent, + Domain: Space, + Domain::Principal: Norm, + F: Float, + E: NormExponent, { type Codomain = F; - fn apply>(&self, d : I) -> F { + fn apply>(&self, d: I) -> F { if d.eval(|x| x.norm(self.norm.exponent)) <= self.radius { F::ZERO } else { @@ -102,68 +115,78 @@ impl ConvexMapping for NormConstraint where - Domain : Space, - E : NormExponent, - F : Float, - Self : Mapping -{} - + Domain: Space, + E: NormExponent, + F: Float, + Self: Mapping, +{ +} impl Conjugable for NormMapping where - E : HasDualExponent, - F : Float, - Domain : HasDual + Norm + Space, - >::DualSpace : Norm + E: HasDualExponent, + F: Float, + Domain: HasDual, + Domain::Principal: Norm, + >::DualSpace: Norm, { - type Conjugate<'a> = NormConstraint where Self : 'a; + type Conjugate<'a> + = NormConstraint + where + Self: 'a; fn conjugate(&self) -> Self::Conjugate<'_> { - NormConstraint { - radius : F::ONE, - norm : self.exponent.dual_exponent().as_mapping() - } + NormConstraint { radius: F::ONE, norm: self.exponent.dual_exponent().as_mapping() } } } impl Conjugable for Weighted, C> where - C : Constant, - E : HasDualExponent, - F : Float, - Domain : HasDual + Norm + Space, - >::DualSpace : Norm + C: Constant, + E: HasDualExponent, + F: Float, + Domain: HasDual, + Domain::Principal: Norm, + >::DualSpace: Norm, { - type Conjugate<'a> = NormConstraint where Self : 'a; + type Conjugate<'a> + = NormConstraint + where + Self: 'a; fn conjugate(&self) -> Self::Conjugate<'_> { NormConstraint { - radius : self.weight.value(), - norm : self.base_fn.exponent.dual_exponent().as_mapping() + radius: self.weight.value(), + norm: self.base_fn.exponent.dual_exponent().as_mapping(), } } } impl Prox for NormConstraint where - Domain : Space + Norm, - E : NormExponent, - F : Float, - NormProjection : Mapping, + Domain: Space, + Domain::Principal: Norm, + E: NormExponent, + F: Float, + NormProjection: Mapping, { - type Prox<'a> = NormProjection where Self : 'a; + type Prox<'a> + = NormProjection + where + Self: 'a; #[inline] - fn prox_mapping(&self, _τ : Self::Codomain) -> Self::Prox<'_> { + fn prox_mapping(&self, _τ: Self::Codomain) -> Self::Prox<'_> { assert!(self.radius >= F::ZERO); - NormProjection{ radius : self.radius, exponent : self.norm.exponent } + NormProjection { radius: self.radius, exponent: self.norm.exponent } } } /// Projection to the unit ball of the norm described by `E`. -pub struct NormProjection { - radius : F, - exponent : E, +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +pub struct NormProjection { + radius: F, + exponent: E, } /* @@ -182,41 +205,44 @@ impl Mapping for NormProjection where - Domain : Space + Projection, - F : Float, - E : NormExponent, + Domain: Space, + Domain::Principal: ClosedSpace + Projection, + F: Float, + E: NormExponent, { - type Codomain = Domain; + type Codomain = Domain::Principal; - fn apply>(&self, d : I) -> Domain { + fn apply>(&self, d: I) -> Self::Codomain { d.own().proj_ball(self.radius, self.exponent) } } +/// The zero mapping +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +pub struct Zero(PhantomData<(Domain, F)>); -/// The zero mapping -pub struct Zero(PhantomData<(Domain, F)>); - -impl Zero { +impl Zero { pub fn new() -> Self { Zero(PhantomData) } } -impl Mapping for Zero { +impl Mapping for Zero { type Codomain = F; /// Compute the value of `self` at `x`. - fn apply>(&self, _x : I) -> Self::Codomain { + fn apply>(&self, _x: I) -> Self::Codomain { F::ZERO } } -impl ConvexMapping for Zero { } - +impl ConvexMapping for Zero {} -impl, F : Float> Conjugable for Zero { - type Conjugate<'a> = ZeroIndicator where Self : 'a; +impl, F: Float> Conjugable for Zero { + type Conjugate<'a> + = ZeroIndicator + where + Self: 'a; #[inline] fn conjugate(&self) -> Self::Conjugate<'_> { @@ -224,12 +250,15 @@ } } -impl Preconjugable for Zero +impl Preconjugable for Zero where - Domain : Space, - Predual : HasDual + Domain: Normed, + Predual: HasDual, { - type Preconjugate<'a> = ZeroIndicator where Self : 'a; + type Preconjugate<'a> + = ZeroIndicator + where + Self: 'a; #[inline] fn preconjugate(&self) -> Self::Preconjugate<'_> { @@ -237,38 +266,61 @@ } } -impl Prox for Zero { - type Prox<'a> = IdOp where Self : 'a; +impl Prox for Zero { + type Prox<'a> + = IdOp + where + Self: 'a; #[inline] - fn prox_mapping(&self, _τ : Self::Codomain) -> Self::Prox<'_> { + fn prox_mapping(&self, _τ: Self::Codomain) -> Self::Prox<'_> { IdOp::new() } } +/// The zero indicator +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +pub struct ZeroIndicator(PhantomData<(Domain, F)>); -/// The zero indicator -pub struct ZeroIndicator(PhantomData<(Domain, F)>); - -impl ZeroIndicator { +impl ZeroIndicator { pub fn new() -> Self { ZeroIndicator(PhantomData) } } -impl, F : Float> Mapping for ZeroIndicator { +impl Mapping for ZeroIndicator +where + F: Float, + Domain: Space, + Domain::Principal: Normed, +{ type Codomain = F; /// Compute the value of `self` at `x`. - fn apply>(&self, x : I) -> Self::Codomain { + fn apply>(&self, x: I) -> Self::Codomain { x.eval(|x̃| if x̃.is_zero() { F::ZERO } else { F::INFINITY }) } } -impl, F : Float> ConvexMapping for ZeroIndicator { } +impl ConvexMapping for ZeroIndicator +where + Domain: Space, + Domain::Principal: Normed, +{ + fn factor_of_strong_convexity(&self) -> F { + F::INFINITY + } +} -impl, F : Float> Conjugable for ZeroIndicator { - type Conjugate<'a> = Zero where Self : 'a; +impl Conjugable for ZeroIndicator +where + Domain: HasDual, + Domain::PrincipalV: Normed, +{ + type Conjugate<'a> + = Zero + where + Self: 'a; #[inline] fn conjugate(&self) -> Self::Conjugate<'_> { @@ -276,15 +328,123 @@ } } -impl Preconjugable for ZeroIndicator +impl Preconjugable for ZeroIndicator where - Domain : Normed, - Predual : HasDual + Domain: Space, + Domain::Principal: Normed, + Predual: HasDual, { - type Preconjugate<'a> = Zero where Self : 'a; + type Preconjugate<'a> + = Zero + where + Self: 'a; #[inline] fn preconjugate(&self) -> Self::Preconjugate<'_> { Zero::new() } } + +impl Prox for ZeroIndicator +where + Domain: AXPY + Normed, + F: Float, +{ + type Prox<'a> + = SimpleZeroOp + where + Self: 'a; + + /// Returns a proximal mapping with weight τ + fn prox_mapping(&self, _τ: F) -> Self::Prox<'_> { + return SimpleZeroOp; + } +} + +/// The squared Euclidean norm divided by two +#[derive(Copy, Clone, Serialize, Deserialize)] +pub struct Norm222(PhantomData); + +impl,*/ F: Float> Norm222 { + pub fn new() -> Self { + Norm222(PhantomData) + } +} + +impl, F: Float> Mapping for Norm222 { + type Codomain = F; + + /// Compute the value of `self` at `x`. + fn apply>(&self, x: I) -> Self::Codomain { + x.eval(|z| z.norm2_squared() / F::TWO) + } +} + +impl, F: Float> ConvexMapping for Norm222 { + fn factor_of_strong_convexity(&self) -> F { + F::ONE + } +} + +impl, F: Float> Conjugable for Norm222 { + type Conjugate<'a> + = Self + where + Self: 'a; + + #[inline] + fn conjugate(&self) -> Self::Conjugate<'_> { + Self::new() + } +} + +impl, F: Float> Preconjugable for Norm222 { + type Preconjugate<'a> + = Self + where + Self: 'a; + + #[inline] + fn preconjugate(&self) -> Self::Preconjugate<'_> { + Self::new() + } +} + +impl Prox for Norm222 +where + F: Float, + X: Euclidean, +{ + type Prox<'a> + = Scaled + where + Self: 'a; + + fn prox_mapping(&self, τ: F) -> Self::Prox<'_> { + Scaled(F::ONE / (F::ONE + τ)) + } +} + +impl DifferentiableImpl for Norm222 +where + F: Float, + X: Euclidean, +{ + type Derivative = X::PrincipalV; + + fn differential_impl>(&self, x: I) -> Self::Derivative { + x.own() + } +} + +impl LipschitzDifferentiableImpl for Norm222 +where + F: Float, + X: Euclidean, +{ + type FloatType = F; + + fn diff_lipschitz_factor(&self, _: L2) -> DynResult { + Ok(F::ONE) + } +} diff -r 1f19c6bbf07b -r 3868555d135c src/direct_product.rs --- a/src/direct_product.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/direct_product.rs Fri May 15 14:46:30 2026 -0500 @@ -6,11 +6,11 @@ */ use crate::euclidean::Euclidean; -use crate::instance::{Decomposition, DecompositionMut, Instance, InstanceMut, MyCow}; -use crate::linops::AXPY; +use crate::instance::{Decomposition, DecompositionMut, Instance, InstanceMut, MyCow, Ownable}; +use crate::linops::{VectorSpace, AXPY}; use crate::loc::Loc; use crate::mapping::Space; -use crate::norms::{HasDual, Norm, NormExponent, Normed, PairNorm, L2}; +use crate::norms::{Dist, HasDual, Norm, NormExponent, Normed, PairNorm, L2}; use crate::types::{Float, Num}; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use serde::{Deserialize, Serialize}; @@ -39,239 +39,290 @@ } } -macro_rules! impl_binop { - (($a : ty, $b : ty), $trait : ident, $fn : ident, $refl:ident, $refr:ident) => { - impl_binop!(@doit: $a, $b, $trait, $fn; - maybe_lifetime!($refl, &'l Pair<$a,$b>), - (maybe_lifetime!($refl, &'l $a), - maybe_lifetime!($refl, &'l $b)); - maybe_lifetime!($refr, &'r Pair), - (maybe_lifetime!($refr, &'r Ai), - maybe_lifetime!($refr, &'r Bi)); - $refl, $refr); - }; - - (@doit: $a:ty, $b:ty, - $trait:ident, $fn:ident; - $self:ty, ($aself:ty, $bself:ty); - $in:ty, ($ain:ty, $bin:ty); - $refl:ident, $refr:ident) => { - impl<'l, 'r, Ai, Bi> $trait<$in> - for $self - where $aself: $trait<$ain>, - $bself: $trait<$bin> { - type Output = Pair<<$aself as $trait<$ain>>::Output, - <$bself as $trait<$bin>>::Output>; - - #[inline] - fn $fn(self, y : $in) -> Self::Output { - Pair(maybe_ref!($refl, self.0).$fn(maybe_ref!($refr, y.0)), - maybe_ref!($refl, self.1).$fn(maybe_ref!($refr, y.1))) +macro_rules! impl_unary { + ($trait:ident, $fn:ident) => { + impl $trait for Pair + where + A: $trait, + B: $trait, + { + type Output = Pair; + fn $fn(self) -> Self::Output { + let Pair(a, b) = self; + Pair(a.$fn(), b.$fn()) } } + + // Compiler overflow + // impl<'a, A, B> $trait for &'a Pair + // where + // &'a A: $trait, + // &'a B: $trait, + // { + // type Output = Pair<<&'a A as $trait>::Output, <&'a B as $trait>::Output>; + // fn $fn(self) -> Self::Output { + // let Pair(ref a, ref b) = self; + // Pair(a.$fn(), b.$fn()) + // } + // } }; } -macro_rules! impl_assignop { - (($a : ty, $b : ty), $trait : ident, $fn : ident, $refr:ident) => { - impl_assignop!(@doit: $a, $b, - $trait, $fn; - maybe_lifetime!($refr, &'r Pair), - (maybe_lifetime!($refr, &'r Ai), - maybe_lifetime!($refr, &'r Bi)); - $refr); - }; - (@doit: $a : ty, $b : ty, - $trait:ident, $fn:ident; - $in:ty, ($ain:ty, $bin:ty); - $refr:ident) => { - impl<'r, Ai, Bi> $trait<$in> - for Pair<$a,$b> - where $a: $trait<$ain>, - $b: $trait<$bin> { - #[inline] - fn $fn(&mut self, y : $in) -> () { - self.0.$fn(maybe_ref!($refr, y.0)); - self.1.$fn(maybe_ref!($refr, y.1)); +impl_unary!(Neg, neg); + +macro_rules! impl_binary { + ($trait:ident, $fn:ident) => { + impl $trait> for Pair + where + A: $trait, + B: $trait, + { + type Output = Pair; + fn $fn(self, Pair(c, d): Pair) -> Self::Output { + let Pair(a, b) = self; + Pair(a.$fn(c), b.$fn(d)) } } - } -} -macro_rules! impl_scalarop { - (($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident, $refl:ident) => { - impl_scalarop!(@doit: $field, - $trait, $fn; - maybe_lifetime!($refl, &'l Pair<$a,$b>), - (maybe_lifetime!($refl, &'l $a), - maybe_lifetime!($refl, &'l $b)); - $refl); - }; - (@doit: $field : ty, - $trait:ident, $fn:ident; - $self:ty, ($aself:ty, $bself:ty); - $refl:ident) => { - // Scalar as Rhs - impl<'l> $trait<$field> - for $self - where $aself: $trait<$field>, - $bself: $trait<$field> { - type Output = Pair<<$aself as $trait<$field>>::Output, - <$bself as $trait<$field>>::Output>; - #[inline] - fn $fn(self, a : $field) -> Self::Output { - Pair(maybe_ref!($refl, self.0).$fn(a), - maybe_ref!($refl, self.1).$fn(a)) + impl<'a, A, B, C, D> $trait> for &'a Pair + where + &'a A: $trait, + &'a B: $trait, + { + type Output = Pair<<&'a A as $trait>::Output, <&'a B as $trait>::Output>; + fn $fn(self, Pair(c, d): Pair) -> Self::Output { + let Pair(ref a, ref b) = self; + Pair(a.$fn(c), b.$fn(d)) } } - } -} -// Not used due to compiler overflow -#[allow(unused_macros)] -macro_rules! impl_scalarlhs_op { - (($a : ty, $b : ty), $field : ty, $trait:ident, $fn:ident, $refr:ident) => { - impl_scalarlhs_op!(@doit: $trait, $fn, - maybe_lifetime!($refr, &'r Pair<$a,$b>), - (maybe_lifetime!($refr, &'r $a), - maybe_lifetime!($refr, &'r $b)); - $refr, $field); - }; - (@doit: $trait:ident, $fn:ident, - $in:ty, ($ain:ty, $bin:ty); - $refr:ident, $field:ty) => { - impl<'r> $trait<$in> - for $field - where $field : $trait<$ain> - + $trait<$bin> { - type Output = Pair<<$field as $trait<$ain>>::Output, - <$field as $trait<$bin>>::Output>; - #[inline] - fn $fn(self, x : $in) -> Self::Output { - Pair(self.$fn(maybe_ref!($refr, x.0)), - self.$fn(maybe_ref!($refr, x.1))) + impl<'a, 'b, A, B, C, D> $trait<&'b Pair> for &'a Pair + where + &'a A: $trait<&'b C>, + &'a B: $trait<&'b D>, + { + type Output = Pair<<&'a A as $trait<&'b C>>::Output, <&'a B as $trait<&'b D>>::Output>; + fn $fn(self, Pair(ref c, ref d): &'b Pair) -> Self::Output { + let Pair(ref a, ref b) = self; + Pair(a.$fn(c), b.$fn(d)) + } + } + + impl<'b, A, B, C, D> $trait<&'b Pair> for Pair + where + A: $trait<&'b C>, + B: $trait<&'b D>, + { + type Output = Pair<>::Output, >::Output>; + fn $fn(self, Pair(ref c, ref d): &'b Pair) -> Self::Output { + let Pair(a, b) = self; + Pair(a.$fn(c), b.$fn(d)) } } }; } -macro_rules! impl_scalar_assignop { - (($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident) => { - impl<'r> $trait<$field> for Pair<$a, $b> +impl_binary!(Add, add); +impl_binary!(Sub, sub); + +macro_rules! impl_scalar { + ($trait:ident, $fn:ident) => { + impl $trait for Pair + where + A: $trait, + B: $trait, + { + type Output = Pair; + fn $fn(self, t: F) -> Self::Output { + let Pair(a, b) = self; + Pair(a.$fn(t), b.$fn(t)) + } + } + + impl<'a, A, B, F: Num> $trait for &'a Pair where - $a: $trait<$field>, - $b: $trait<$field>, + &'a A: $trait, + &'a B: $trait, + { + type Output = Pair<<&'a A as $trait>::Output, <&'a B as $trait>::Output>; + fn $fn(self, t: F) -> Self::Output { + let Pair(ref a, ref b) = self; + Pair(a.$fn(t), b.$fn(t)) + } + } + + // impl<'a, 'b, A, B> $trait<&'b $F> for &'a Pair + // where + // &'a A: $trait<&'b $F>, + // &'a B: $trait<&'b $F>, + // { + // type Output = + // Pair<<&'a A as $trait<&'b $F>>::Output, <&'a B as $trait<&'b $F>>::Output>; + // fn $fn(self, t: &'b $F) -> Self::Output { + // let Pair(ref a, ref b) = self; + // Pair(a.$fn(t), b.$fn(t)) + // } + // } + + // impl<'b, A, B> $trait<&'b $F> for Pair + // where + // A: $trait<&'b $F>, + // B: $trait<&'b $F>, + // { + // type Output = Pair<>::Output, >::Output>; + // fn $fn(self, t: &'b $F) -> Self::Output { + // let Pair(a, b) = self; + // Pair(a.$fn(t), b.$fn(t)) + // } + // } + }; +} + +impl_scalar!(Mul, mul); +impl_scalar!(Div, div); + +macro_rules! impl_scalar_lhs { + ($trait:ident, $fn:ident, $F:ty) => { + impl $trait> for $F + where + $F: $trait + $trait, { - #[inline] - fn $fn(&mut self, a: $field) -> () { - self.0.$fn(a); - self.1.$fn(a); + type Output = Pair<<$F as $trait>::Output, <$F as $trait>::Output>; + fn $fn(self, Pair(a, b): Pair) -> Self::Output { + Pair(self.$fn(a), self.$fn(b)) + } + } + + // Compiler overflow: + // + // impl<'a, A, B> $trait<&'a Pair> for $F + // where + // $F: $trait<&'a A> + $trait<&'a B>, + // { + // type Output = Pair<<$F as $trait<&'a A>>::Output, <$F as $trait<&'a B>>::Output>; + // fn $fn(self, Pair(a, b): &'a Pair) -> Self::Output { + // Pair(self.$fn(a), self.$fn(b)) + // } + // } + }; +} + +impl_scalar_lhs!(Mul, mul, f32); +impl_scalar_lhs!(Mul, mul, f64); +impl_scalar_lhs!(Div, div, f32); +impl_scalar_lhs!(Div, div, f64); + +macro_rules! impl_binary_mut { + ($trait:ident, $fn:ident) => { + impl<'a, A, B, C, D> $trait> for Pair + where + A: $trait, + B: $trait, + { + fn $fn(&mut self, Pair(c, d): Pair) { + let Pair(ref mut a, ref mut b) = self; + a.$fn(c); + b.$fn(d); + } + } + + impl<'a, 'b, A, B, C, D> $trait<&'b Pair> for Pair + where + A: $trait<&'b C>, + B: $trait<&'b D>, + { + fn $fn(&mut self, Pair(ref c, ref d): &'b Pair) { + let Pair(ref mut a, ref mut b) = self; + a.$fn(c); + b.$fn(d); } } }; } -macro_rules! impl_unaryop { - (($a : ty, $b : ty), $trait:ident, $fn:ident, $refl:ident) => { - impl_unaryop!(@doit: $trait, $fn; - maybe_lifetime!($refl, &'l Pair<$a,$b>), - (maybe_lifetime!($refl, &'l $a), - maybe_lifetime!($refl, &'l $b)); - $refl); - }; - (@doit: $trait:ident, $fn:ident; - $self:ty, ($aself:ty, $bself:ty); - $refl : ident) => { - impl<'l> $trait - for $self - where $aself: $trait, - $bself: $trait { - type Output = Pair<<$aself as $trait>::Output, - <$bself as $trait>::Output>; - #[inline] - fn $fn(self) -> Self::Output { - Pair(maybe_ref!($refl, self.0).$fn(), - maybe_ref!($refl, self.1).$fn()) +impl_binary_mut!(AddAssign, add_assign); +impl_binary_mut!(SubAssign, sub_assign); + +macro_rules! impl_scalar_mut { + ($trait:ident, $fn:ident) => { + impl<'a, A, B, F: Num> $trait for Pair + where + A: $trait, + B: $trait, + { + fn $fn(&mut self, t: F) { + let Pair(ref mut a, ref mut b) = self; + a.$fn(t); + b.$fn(t); } } + }; +} + +impl_scalar_mut!(MulAssign, mul_assign); +impl_scalar_mut!(DivAssign, div_assign); + +/// Trait for ownable-by-consumption objects +impl Ownable for Pair +where + A: Ownable, + B: Ownable, +{ + type OwnedVariant = Pair; + + #[inline] + fn into_owned(self) -> Self::OwnedVariant { + Pair(self.0.into_owned(), self.1.into_owned()) + } + + #[inline] + fn clone_owned(&self) -> Self::OwnedVariant { + Pair(self.0.clone_owned(), self.1.clone_owned()) + } + + #[inline] + fn cow_owned<'b>(self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + MyCow::Owned(self.into_owned()) + } + + #[inline] + fn ref_cow_owned<'b>(&'b self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + MyCow::Owned(self.clone_owned()) } } -#[macro_export] -macro_rules! impl_pair_vectorspace_ops { - (($a:ty, $b:ty), $field:ty) => { - impl_pair_vectorspace_ops!(@binary, ($a, $b), Add, add); - impl_pair_vectorspace_ops!(@binary, ($a, $b), Sub, sub); - impl_pair_vectorspace_ops!(@assign, ($a, $b), AddAssign, add_assign); - impl_pair_vectorspace_ops!(@assign, ($a, $b), SubAssign, sub_assign); - impl_pair_vectorspace_ops!(@scalar, ($a, $b), $field, Mul, mul); - impl_pair_vectorspace_ops!(@scalar, ($a, $b), $field, Div, div); - // Compiler overflow - // $( - // impl_pair_vectorspace_ops!(@scalar_lhs, ($a, $b), $field, $impl_scalarlhs_op, Mul, mul); - // )* - impl_pair_vectorspace_ops!(@scalar_assign, ($a, $b), $field, MulAssign, mul_assign); - impl_pair_vectorspace_ops!(@scalar_assign, ($a, $b), $field, DivAssign, div_assign); - impl_pair_vectorspace_ops!(@unary, ($a, $b), Neg, neg); - }; - (@binary, ($a : ty, $b : ty), $trait : ident, $fn : ident) => { - impl_binop!(($a, $b), $trait, $fn, ref, ref); - impl_binop!(($a, $b), $trait, $fn, ref, noref); - impl_binop!(($a, $b), $trait, $fn, noref, ref); - impl_binop!(($a, $b), $trait, $fn, noref, noref); - }; - (@assign, ($a : ty, $b : ty), $trait : ident, $fn :ident) => { - impl_assignop!(($a, $b), $trait, $fn, ref); - impl_assignop!(($a, $b), $trait, $fn, noref); - }; - (@scalar, ($a : ty, $b : ty), $field : ty, $trait : ident, $fn :ident) => { - impl_scalarop!(($a, $b), $field, $trait, $fn, ref); - impl_scalarop!(($a, $b), $field, $trait, $fn, noref); - }; - (@scalar_lhs, ($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident) => { - impl_scalarlhs_op!(($a, $b), $field, $trait, $fn, ref); - impl_scalarlhs_op!(($a, $b), $field, $trait, $fn, noref); - }; - (@scalar_assign, ($a : ty, $b : ty), $field : ty, $trait : ident, $fn : ident) => { - impl_scalar_assignop!(($a, $b), $field, $trait, $fn); - }; - (@unary, ($a : ty, $b : ty), $trait : ident, $fn : ident) => { - impl_unaryop!(($a, $b), $trait, $fn, ref); - impl_unaryop!(($a, $b), $trait, $fn, noref); - }; -} - -impl_pair_vectorspace_ops!((f32, f32), f32); -impl_pair_vectorspace_ops!((f64, f64), f64); - -type PairOutput = Pair<>::Output, >::Output>; - -impl Euclidean for Pair +/// We only support 'closed' `Euclidean` `Pair`s, as more general ones cause +/// compiler overflows. +impl Euclidean for Pair where A: Euclidean, B: Euclidean, - F: Float, - PairOutput: Euclidean, - Self: Sized - + Mul> - + MulAssign - + Div> - + DivAssign - + Add> - + Sub> - + for<'b> Add<&'b Self, Output = PairOutput> - + for<'b> Sub<&'b Self, Output = PairOutput> - + AddAssign - + for<'b> AddAssign<&'b Self> - + SubAssign - + for<'b> SubAssign<&'b Self> - + Neg>, + // //Pair: Euclidean, + // Self: Sized + // + Mul + // + MulAssign + // + Div + // + DivAssign + // + Add + // + Sub + // + for<'b> Add<&'b Self, Output = Self::OwnedEuclidean> + // + for<'b> Sub<&'b Self, Output = Self::OwnedEuclidean> + // + AddAssign + // + for<'b> AddAssign<&'b Self> + // + SubAssign + // + for<'b> SubAssign<&'b Self> + // + Neg, { - type Output = PairOutput; + type PrincipalE = Pair; fn dot>(&self, other: I) -> F { - let Pair(u, v) = other.decompose(); - self.0.dot(u) + self.1.dot(v) + other.eval_decompose(|Pair(u, v)| self.0.dot(u) + self.1.dot(v)) } fn norm2_squared(&self) -> F { @@ -279,45 +330,59 @@ } fn dist2_squared>(&self, other: I) -> F { - let Pair(u, v) = other.decompose(); - self.0.dist2_squared(u) + self.1.dist2_squared(v) + other.eval_decompose(|Pair(u, v)| self.0.dist2_squared(u) + self.1.dist2_squared(v)) } } -impl AXPY> for Pair +impl VectorSpace for Pair +where + A: VectorSpace, + B: VectorSpace, + F: Num, +{ + type Field = F; + type PrincipalV = Pair; + + /// Return a similar zero as `self`. + fn similar_origin(&self) -> Self::PrincipalV { + Pair(self.0.similar_origin(), self.1.similar_origin()) + } + + // #[inline] + // fn into_owned(self) -> Self::Owned { + // Pair(self.0.into_owned(), self.1.into_owned()) + // } +} + +impl AXPY> for Pair where U: Space, V: Space, - A: AXPY, - B: AXPY, + A: AXPY, + B: AXPY, F: Num, - Self: MulAssign, - Pair: MulAssign, - Pair: AXPY>, + // Self: MulAssign + DivAssign, + // Pair: MulAssign + DivAssign, { - type Owned = Pair; - fn axpy>>(&mut self, α: F, x: I, β: F) { - let Pair(u, v) = x.decompose(); - self.0.axpy(α, u, β); - self.1.axpy(α, v, β); + x.eval_decompose(|Pair(u, v)| { + self.0.axpy(α, u, β); + self.1.axpy(α, v, β); + }) } fn copy_from>>(&mut self, x: I) { - let Pair(u, v) = x.decompose(); - self.0.copy_from(u); - self.1.copy_from(v); + x.eval_decompose(|Pair(u, v)| { + self.0.copy_from(u); + self.1.copy_from(v); + }) } fn scale_from>>(&mut self, α: F, x: I) { - let Pair(u, v) = x.decompose(); - self.0.scale_from(α, u); - self.1.scale_from(α, v); - } - - /// Return a similar zero as `self`. - fn similar_origin(&self) -> Self::Owned { - Pair(self.0.similar_origin(), self.1.similar_origin()) + x.eval_decompose(|Pair(u, v)| { + self.0.scale_from(α, u); + self.1.scale_from(α, v); + }) } /// Set self to zero. @@ -332,6 +397,7 @@ pub struct PairDecomposition(D, Q); impl Space for Pair { + type Principal = Pair; type Decomp = PairDecomposition; } @@ -367,25 +433,16 @@ V: Instance, { #[inline] - fn decompose<'b>( - self, - ) -> as Decomposition>>::Decomposition<'b> + fn eval_ref<'b, R>(&'b self, f: impl FnOnce(Pair, Q::Reference<'b>>) -> R) -> R where + Pair: 'b, Self: 'b, - Pair: 'b, { - Pair(self.0.decompose(), self.1.decompose()) + self.0.eval_ref(|a| self.1.eval_ref(|b| f(Pair(a, b)))) } #[inline] - fn ref_instance( - &self, - ) -> as Decomposition>>::Reference<'_> { - Pair(self.0.ref_instance(), self.1.ref_instance()) - } - - #[inline] - fn cow<'b>(self) -> MyCow<'b, Pair> + fn cow<'b>(self) -> MyCow<'b, Pair> where Self: 'b, { @@ -393,9 +450,17 @@ } #[inline] - fn own(self) -> Pair { + fn own(self) -> Pair { Pair(self.0.own(), self.1.own()) } + + #[inline] + fn decompose<'b>(self) -> Pair, Q::Decomposition<'b>> + where + Self: 'b, + { + Pair(self.0.decompose(), self.1.decompose()) + } } impl<'a, A, B, U, V, D, Q> Instance, PairDecomposition> for &'a Pair @@ -409,29 +474,16 @@ &'a U: Instance, &'a V: Instance, { - #[inline] - fn decompose<'b>( - self, - ) -> as Decomposition>>::Decomposition<'b> + fn eval_ref<'b, R>(&'b self, f: impl FnOnce(Pair, Q::Reference<'b>>) -> R) -> R where + Pair: 'b, Self: 'b, - Pair: 'b, { - Pair( - D::lift(self.0.ref_instance()), - Q::lift(self.1.ref_instance()), - ) + self.0.eval_ref(|a| self.1.eval_ref(|b| f(Pair(a, b)))) } #[inline] - fn ref_instance( - &self, - ) -> as Decomposition>>::Reference<'_> { - Pair(self.0.ref_instance(), self.1.ref_instance()) - } - - #[inline] - fn cow<'b>(self) -> MyCow<'b, Pair> + fn cow<'b>(self) -> MyCow<'b, Pair> where Self: 'b, { @@ -439,10 +491,19 @@ } #[inline] - fn own(self) -> Pair { + fn own(self) -> Pair { let Pair(ref u, ref v) = self; Pair(u.own(), v.own()) } + + #[inline] + fn decompose<'b>(self) -> Pair, Q::Decomposition<'b>> + where + Self: 'b, + { + let Pair(u, v) = self; + Pair(u.decompose(), v.decompose()) + } } impl DecompositionMut> for PairDecomposition @@ -492,18 +553,63 @@ } } -impl Norm> for Pair +impl Norm, F> for Pair +where + F: Num, + ExpA: NormExponent, + ExpB: NormExponent, + ExpJ: NormExponent, + A: Norm, + B: Norm, + Loc<2, F>: Norm, +{ + fn norm(&self, PairNorm(expa, expb, expj): PairNorm) -> F { + Loc([self.0.norm(expa), self.1.norm(expb)]).norm(expj) + } +} + +impl Dist, F> for Pair where F: Num, ExpA: NormExponent, ExpB: NormExponent, ExpJ: NormExponent, - A: Norm, - B: Norm, - Loc: Norm, + A: Dist, + B: Dist, + Loc<2, F>: Norm, { - fn norm(&self, PairNorm(expa, expb, expj): PairNorm) -> F { - Loc([self.0.norm(expa), self.1.norm(expb)]).norm(expj) + fn dist>( + &self, + x: I, + PairNorm(expa, expb, expj): PairNorm, + ) -> F { + x.eval_decompose(|Pair(x1, x2)| { + Loc([self.0.dist(x1, expa), self.1.dist(x2, expb)]).norm(expj) + }) + } +} + +impl Norm for Pair +where + F: Num, + A: Norm, + B: Norm, + Loc<2, F>: Norm, +{ + fn norm(&self, _: L2) -> F { + Loc([self.0.norm(L2), self.1.norm(L2)]).norm(L2) + } +} + +impl Dist for Pair +where + F: Num, + A: Dist, + B: Dist, + Loc<2, F>: Norm, +{ + fn dist>(&self, x: I, _: L2) -> F { + x.eval_decompose(|Pair(x1, x2)| Loc([self.0.dist(x1, L2), self.1.dist(x2, L2)]).norm(L2)) } } @@ -531,4 +637,72 @@ B: HasDual, { type DualSpace = Pair; + + fn dual_origin(&self) -> ::PrincipalV { + Pair(self.0.dual_origin(), self.1.dual_origin()) + } } + +#[cfg(feature = "pyo3")] +mod python { + use super::Pair; + use pyo3::conversion::FromPyObject; + use pyo3::types::{PyAny, PyTuple}; + use pyo3::{Borrowed, Bound, IntoPyObject, PyErr, Python}; + + impl<'py, A, B> IntoPyObject<'py> for Pair + where + A: IntoPyObject<'py>, + B: IntoPyObject<'py>, + { + type Target = PyTuple; + type Error = PyErr; + type Output = Bound<'py, Self::Target>; + + fn into_pyobject(self, py: Python<'py>) -> Result { + (self.0, self.1).into_pyobject(py) + } + } + + /* + impl<'a, 'py, A, B> IntoPyObject<'py> for &'a mut Pair + where + &'a mut A: IntoPyObject<'py>, + &'a mut B: IntoPyObject<'py>, + { + type Target = PyTuple; + type Error = PyErr; + type Output = Bound<'py, Self::Target>; + + fn into_pyobject(self, py: Python<'py>) -> Result { + (&mut self.0, &mut self.1).into_pyobject(py) + } + } + + impl<'a, 'py, A, B> IntoPyObject<'py> for &'a Pair + where + &'a A: IntoPyObject<'py>, + &'a B: IntoPyObject<'py>, + { + type Target = PyTuple; + type Error = PyErr; + type Output = Bound<'py, Self::Target>; + + fn into_pyobject(self, py: Python<'py>) -> Result { + (&self.0, &self.1).into_pyobject(py) + } + } + */ + + impl<'a, 'py, A, B> FromPyObject<'a, 'py> for Pair + where + A: Clone + FromPyObject<'a, 'py>, + B: Clone + FromPyObject<'a, 'py>, + { + type Error = PyErr; + + fn extract(ob: Borrowed<'a, 'py, PyAny>) -> Result { + FromPyObject::extract(ob).map(|(a, b)| Pair(a, b)) + } + } +} diff -r 1f19c6bbf07b -r 3868555d135c src/discrete_gradient.rs --- a/src/discrete_gradient.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/discrete_gradient.rs Fri May 15 14:46:30 2026 -0500 @@ -1,14 +1,13 @@ /*! Simple disrete gradient operators */ +use crate::error::DynResult; +use crate::instance::Instance; +use crate::linops::{Adjointable, BoundedLinear, Linear, Mapping, SimplyAdjointable, GEMV}; +use crate::norms::{Norm, L2}; +use crate::types::Float; +use nalgebra::{DVector, Dyn, Matrix, Storage, StorageMut, U1}; use numeric_literals::replace_float_literals; -use nalgebra::{ - DVector, Matrix, U1, Storage, StorageMut, Dyn -}; -use crate::types::Float; -use crate::instance::Instance; -use crate::linops::{Mapping, Linear, BoundedLinear, Adjointable, GEMV}; -use crate::norms::{Norm, L2}; #[derive(Copy, Clone, Debug)] /// Forward differences with Neumann boundary conditions @@ -27,26 +26,17 @@ pub struct BackwardNeumann; /// Finite differences gradient -pub struct Grad< - F : Float + nalgebra::RealField, - B : Discretisation, - const N : usize -> { - dims : [usize; N], - h : F, // may be negative to implement adjoints! - discretisation : B, +pub struct Grad, const N: usize> { + dims: [usize; N], + h: F, // may be negative to implement adjoints! + discretisation: B, } - /// Finite differences divergence -pub struct Div< - F : Float + nalgebra::RealField, - B : Discretisation, - const N : usize -> { - dims : [usize; N], - h : F, // may be negative to implement adjoints! - discretisation : B, +pub struct Div, const N: usize> { + dims: [usize; N], + h: F, // may be negative to implement adjoints! + discretisation: B, } /// Internal: classification of a point in a 1D discretisation @@ -62,53 +52,53 @@ use DiscretisationOrInterior::*; /// Trait for different discretisations -pub trait Discretisation : Copy { +pub trait Discretisation: Copy { /// Opposite discretisation, appropriate for adjoints with negated cell width. - type Opposite : Discretisation; + type Opposite: Discretisation; /// Add to appropiate index of `v` (as determined by `b`) the appropriate difference /// of `x` with cell width `h`. fn add_diff_mut( &self, - v : &mut Matrix, - x : &Matrix, - α : F, - b : DiscretisationOrInterior, + v: &mut Matrix, + x: &Matrix, + α: F, + b: DiscretisationOrInterior, ) where - SMut : StorageMut, - S : Storage; + SMut: StorageMut, + S: Storage; /// Give the opposite discretisation, appropriate for adjoints with negated `h`. fn opposite(&self) -> Self::Opposite; /// Bound for the corresponding operator norm. #[replace_float_literals(F::cast_from(literal))] - fn opnorm_bound(&self, h : F) -> F { + fn opnorm_bound(&self, h: F) -> DynResult { // See: Chambolle, “An Algorithm for Total Variation Minimization and Applications”. // Ok for forward and backward differences. // // Fuck nalgebra for polluting everything with its own shit. - num_traits::Float::sqrt(8.0) / h + Ok(num_traits::Float::sqrt(8.0) / h) } } -impl Discretisation for ForwardNeumann { +impl Discretisation for ForwardNeumann { type Opposite = BackwardDirichlet; #[inline] fn add_diff_mut( &self, - v : &mut Matrix, - x : &Matrix, - α : F, - b : DiscretisationOrInterior, + v: &mut Matrix, + x: &Matrix, + α: F, + b: DiscretisationOrInterior, ) where - SMut : StorageMut, - S : Storage + SMut: StorageMut, + S: Storage, { match b { - Interior(c, (_, f)) | LeftBoundary(c, f) => { v[c] += (x[f] - x[c]) * α }, - RightBoundary(_c, _b) => { }, + Interior(c, (_, f)) | LeftBoundary(c, f) => v[c] += (x[f] - x[c]) * α, + RightBoundary(_c, _b) => {} } } @@ -118,23 +108,23 @@ } } -impl Discretisation for BackwardNeumann { +impl Discretisation for BackwardNeumann { type Opposite = ForwardDirichlet; #[inline] fn add_diff_mut( &self, - v : &mut Matrix, - x : &Matrix, - α : F, - b : DiscretisationOrInterior, + v: &mut Matrix, + x: &Matrix, + α: F, + b: DiscretisationOrInterior, ) where - SMut : StorageMut, - S : Storage + SMut: StorageMut, + S: Storage, { match b { - Interior(c, (b, _)) | RightBoundary(c, b) => { v[c] += (x[c] - x[b]) * α }, - LeftBoundary(_c, _f) => { }, + Interior(c, (b, _)) | RightBoundary(c, b) => v[c] += (x[c] - x[b]) * α, + LeftBoundary(_c, _f) => {} } } @@ -144,24 +134,24 @@ } } -impl Discretisation for BackwardDirichlet { +impl Discretisation for BackwardDirichlet { type Opposite = ForwardNeumann; #[inline] fn add_diff_mut( &self, - v : &mut Matrix, - x : &Matrix, - α : F, - b : DiscretisationOrInterior, + v: &mut Matrix, + x: &Matrix, + α: F, + b: DiscretisationOrInterior, ) where - SMut : StorageMut, - S : Storage + SMut: StorageMut, + S: Storage, { match b { - Interior(c, (b, _f)) => { v[c] += (x[c] - x[b]) * α }, - LeftBoundary(c, _f) => { v[c] += x[c] * α }, - RightBoundary(c, b) => { v[c] -= x[b] * α }, + Interior(c, (b, _f)) => v[c] += (x[c] - x[b]) * α, + LeftBoundary(c, _f) => v[c] += x[c] * α, + RightBoundary(c, b) => v[c] -= x[b] * α, } } @@ -171,24 +161,24 @@ } } -impl Discretisation for ForwardDirichlet { +impl Discretisation for ForwardDirichlet { type Opposite = BackwardNeumann; #[inline] fn add_diff_mut( &self, - v : &mut Matrix, - x : &Matrix, - α : F, - b : DiscretisationOrInterior, + v: &mut Matrix, + x: &Matrix, + α: F, + b: DiscretisationOrInterior, ) where - SMut : StorageMut, - S : Storage + SMut: StorageMut, + S: Storage, { match b { - Interior(c, (_b, f)) => { v[c] += (x[f] - x[c]) * α }, - LeftBoundary(c, f) => { v[c] += x[f] * α }, - RightBoundary(c, _b) => { v[c] -= x[c] * α }, + Interior(c, (_b, f)) => v[c] += (x[f] - x[c]) * α, + LeftBoundary(c, f) => v[c] += x[f] * α, + RightBoundary(c, _b) => v[c] -= x[c] * α, } } @@ -198,30 +188,30 @@ } } -struct Iter<'a, const N : usize> { +struct Iter<'a, const N: usize> { /// Dimensions - dims : &'a [usize; N], + dims: &'a [usize; N], /// Dimension along which to calculate differences - d : usize, + d: usize, /// Stride along coordinate d - d_stride : usize, + d_stride: usize, /// Cartesian indices - i : [usize; N], + i: [usize; N], /// Linear index - k : usize, + k: usize, /// Maximal linear index - len : usize + len: usize, } -impl<'a, const N : usize> Iter<'a, N> { - fn new(dims : &'a [usize; N], d : usize) -> Self { +impl<'a, const N: usize> Iter<'a, N> { + fn new(dims: &'a [usize; N], d: usize) -> Self { let d_stride = dims[0..d].iter().product::(); let len = dims.iter().product::(); - Iter{ dims, d, d_stride, i : [0; N], k : 0, len } + Iter { dims, d, d_stride, i: [0; N], k: 0, len } } } -impl<'a, const N : usize> Iterator for Iter<'a, N> { +impl<'a, const N: usize> Iterator for Iter<'a, N> { type Item = DiscretisationOrInterior; fn next(&mut self) -> Option { let res = if self.k >= self.len { @@ -243,7 +233,7 @@ for j in 0..N { if self.i[j] + 1 < self.dims[j] { self.i[j] += 1; - break + break; } self.i[j] = 0 } @@ -251,14 +241,13 @@ } } -impl Mapping> -for Grad +impl Mapping> for Grad where - B : Discretisation, - F : Float + nalgebra::RealField, + B: Discretisation, + F: Float + nalgebra::RealField, { type Codomain = DVector; - fn apply>>(&self, i : I) -> DVector { + fn apply>>(&self, i: I) -> DVector { let mut y = DVector::zeros(N * self.len()); self.apply_add(&mut y, i); y @@ -266,15 +255,12 @@ } #[replace_float_literals(F::cast_from(literal))] -impl GEMV> -for Grad +impl GEMV> for Grad where - B : Discretisation, - F : Float + nalgebra::RealField, + B: Discretisation, + F: Float + nalgebra::RealField, { - fn gemv>>( - &self, y : &mut DVector, α : F, i : I, β : F - ) { + fn gemv>>(&self, y: &mut DVector, α: F, i: I, β: F) { if β == 0.0 { y.as_mut_slice().iter_mut().for_each(|x| *x = 0.0); } else if β != 1.0 { @@ -286,23 +272,22 @@ i.eval(|x| { assert_eq!(x.len(), m); for d in 0..N { - let mut v = y.generic_view_mut((d*m, 0), (Dyn(m), U1)); + let mut v = y.generic_view_mut((d * m, 0), (Dyn(m), U1)); for b in Iter::new(&self.dims, d) { - self.discretisation.add_diff_mut(&mut v, x, α/h, b) + self.discretisation.add_diff_mut(&mut v, x, α / h, b) } } }) } } -impl Mapping> -for Div +impl Mapping> for Div where - B : Discretisation, - F : Float + nalgebra::RealField, + B: Discretisation, + F: Float + nalgebra::RealField, { type Codomain = DVector; - fn apply>>(&self, i : I) -> DVector { + fn apply>>(&self, i: I) -> DVector { let mut y = DVector::zeros(self.len()); self.apply_add(&mut y, i); y @@ -310,15 +295,12 @@ } #[replace_float_literals(F::cast_from(literal))] -impl GEMV> -for Div +impl GEMV> for Div where - B : Discretisation, - F : Float + nalgebra::RealField, + B: Discretisation, + F: Float + nalgebra::RealField, { - fn gemv>>( - &self, y : &mut DVector, α : F, i : I, β : F - ) { + fn gemv>>(&self, y: &mut DVector, α: F, i: I, β: F) { if β == 0.0 { y.as_mut_slice().iter_mut().for_each(|x| *x = 0.0); } else if β != 1.0 { @@ -327,29 +309,46 @@ } let h = self.h; let m = self.len(); - i.eval(|x| { + i.eval_decompose(|x| { assert_eq!(x.len(), N * m); for d in 0..N { - let v = x.generic_view((d*m, 0), (Dyn(m), U1)); + let v = x.generic_view((d * m, 0), (Dyn(m), U1)); for b in Iter::new(&self.dims, d) { - self.discretisation.add_diff_mut(y, &v, α/h, b) + self.discretisation.add_diff_mut(y, &v, α / h, b) } } }) } } -impl Grad +impl Grad where - B : Discretisation, - F : Float + nalgebra::RealField + B: Discretisation, + F: Float + nalgebra::RealField, { /// Creates a new discrete gradient operator for the vector `u`, verifying dimensions. - pub fn new_for(u : &DVector, h : F, dims : [usize; N], discretisation : B) - -> Option - { + pub fn new_for(u: &DVector, h: F, dims: [usize; N], discretisation: B) -> Option { if u.len() == dims.iter().product::() { - Some(Grad { dims, h, discretisation } ) + Some(Grad { dims, h, discretisation }) + } else { + None + } + } + + fn len(&self) -> usize { + self.dims.iter().product::() + } +} + +impl Div +where + B: Discretisation, + F: Float + nalgebra::RealField, +{ + /// Creates a new discrete gradient operator for the vector `u`, verifying dimensions. + pub fn new_for(u: &DVector, h: F, dims: [usize; N], discretisation: B) -> Option { + if u.len() == dims.iter().product::() * N { + Some(Div { dims, h, discretisation }) } else { None } @@ -360,109 +359,101 @@ } } - -impl Div +impl Linear> for Grad where - B : Discretisation, - F : Float + nalgebra::RealField + B: Discretisation, + F: Float + nalgebra::RealField, +{ +} + +impl Linear> for Div +where + B: Discretisation, + F: Float + nalgebra::RealField, { - /// Creates a new discrete gradient operator for the vector `u`, verifying dimensions. - pub fn new_for(u : &DVector, h : F, dims : [usize; N], discretisation : B) - -> Option - { - if u.len() == dims.iter().product::() * N { - Some(Div { dims, h, discretisation } ) - } else { - None - } - } - - fn len(&self) -> usize { - self.dims.iter().product::() +} + +impl BoundedLinear, L2, L2, F> for Grad +where + B: Discretisation, + F: Float + nalgebra::RealField, + DVector: Norm, +{ + fn opnorm_bound(&self, _: L2, _: L2) -> DynResult { + // Fuck nalgebra. + self.discretisation + .opnorm_bound(num_traits::Float::abs(self.h)) } } -impl Linear> -for Grad -where - B : Discretisation, - F : Float + nalgebra::RealField, -{ -} - -impl Linear> -for Div +impl BoundedLinear, L2, L2, F> for Div where - B : Discretisation, - F : Float + nalgebra::RealField, + B: Discretisation, + F: Float + nalgebra::RealField, + DVector: Norm, { -} - -impl BoundedLinear, L2, L2, F> -for Grad -where - B : Discretisation, - F : Float + nalgebra::RealField, - DVector : Norm, -{ - fn opnorm_bound(&self, _ : L2, _ : L2) -> F { + fn opnorm_bound(&self, _: L2, _: L2) -> DynResult { // Fuck nalgebra. - self.discretisation.opnorm_bound(num_traits::Float::abs(self.h)) + self.discretisation + .opnorm_bound(num_traits::Float::abs(self.h)) } } - -impl BoundedLinear, L2, L2, F> -for Div +impl Adjointable, DVector> for Grad where - B : Discretisation, - F : Float + nalgebra::RealField, - DVector : Norm, + B: Discretisation, + F: Float + nalgebra::RealField, { - fn opnorm_bound(&self, _ : L2, _ : L2) -> F { - // Fuck nalgebra. - self.discretisation.opnorm_bound(num_traits::Float::abs(self.h)) + type AdjointCodomain = DVector; + type Adjoint<'a> + = Div + where + Self: 'a; + + fn adjoint(&self) -> Self::Adjoint<'_> { + Div { dims: self.dims, h: -self.h, discretisation: self.discretisation.opposite() } } } -impl -Adjointable, DVector> -for Grad +impl SimplyAdjointable, DVector> for Grad where - B : Discretisation, - F : Float + nalgebra::RealField, + B: Discretisation, + F: Float + nalgebra::RealField, { type AdjointCodomain = DVector; - type Adjoint<'a> = Div where Self : 'a; + type SimpleAdjoint = Div; - /// Form the adjoint operator of `self`. - fn adjoint(&self) -> Self::Adjoint<'_> { - Div { - dims : self.dims, - h : -self.h, - discretisation : self.discretisation.opposite(), - } + fn adjoint(&self) -> Self::SimpleAdjoint { + Div { dims: self.dims, h: -self.h, discretisation: self.discretisation.opposite() } } } - -impl -Adjointable, DVector> -for Div +impl Adjointable, DVector> for Div where - B : Discretisation, - F : Float + nalgebra::RealField, + B: Discretisation, + F: Float + nalgebra::RealField, { type AdjointCodomain = DVector; - type Adjoint<'a> = Grad where Self : 'a; + type Adjoint<'a> + = Grad + where + Self: 'a; - /// Form the adjoint operator of `self`. fn adjoint(&self) -> Self::Adjoint<'_> { - Grad { - dims : self.dims, - h : -self.h, - discretisation : self.discretisation.opposite(), - } + Grad { dims: self.dims, h: -self.h, discretisation: self.discretisation.opposite() } + } +} + +impl SimplyAdjointable, DVector> for Div +where + B: Discretisation, + F: Float + nalgebra::RealField, +{ + type AdjointCodomain = DVector; + type SimpleAdjoint = Grad; + + fn adjoint(&self) -> Self::SimpleAdjoint { + Grad { dims: self.dims, h: -self.h, discretisation: self.discretisation.opposite() } } } @@ -472,8 +463,8 @@ #[test] fn grad_adjoint() { - let im = DVector::from( (0..9).map(|t| t as f64).collect::>()); - let v = DVector::from( (0..18).map(|t| t as f64).collect::>()); + let im = DVector::from((0..9).map(|t| t as f64).collect::>()); + let v = DVector::from((0..18).map(|t| t as f64).collect::>()); let grad = Grad::new_for(&im, 1.0, [3, 3], ForwardNeumann).unwrap(); let a = grad.apply(&im).dot(&v); @@ -484,6 +475,5 @@ let a = grad.apply(&im).dot(&v); let b = grad.adjoint().apply(&v).dot(&im); assert_eq!(a, b); - } } diff -r 1f19c6bbf07b -r 3868555d135c src/euclidean.rs --- a/src/euclidean.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/euclidean.rs Fri May 15 14:46:30 2026 -0500 @@ -2,31 +2,30 @@ Euclidean spaces. */ -use std::ops::{Mul, MulAssign, Div, DivAssign, Add, Sub, AddAssign, SubAssign, Neg}; +use crate::instance::Instance; +use crate::linops::{VectorSpace, AXPY}; +use crate::norms::{HasDual, Norm, Normed, Reflexive, L2}; use crate::types::*; -use crate::instance::Instance; -use crate::norms::{HasDual, Reflexive}; + +pub mod wrap; +// TODO: Euclidean & EuclideanMut +// /// Space (type) with Euclidean and vector space structure /// /// The type should implement vector space operations (addition, subtraction, scalar /// multiplication and scalar division) along with their assignment versions, as well /// as an inner product. -pub trait Euclidean : HasDual + Reflexive - + Mul>::Output> + MulAssign - + Div>::Output> + DivAssign - + Add>::Output> - + Sub>::Output> - + for<'b> Add<&'b Self, Output=>::Output> - + for<'b> Sub<&'b Self, Output=>::Output> - + AddAssign + for<'b> AddAssign<&'b Self> - + SubAssign + for<'b> SubAssign<&'b Self> - + Neg>::Output> +// TODO: remove F parameter, use VectorSpace::Field +pub trait Euclidean: + VectorSpace + Reflexive { - type Output : Euclidean; + /// Principal form of the space; always equal to [`crate::linops::Space::Principal`] and + /// [`VectorSpace::PrincipalV`], but with more traits guaranteed. + type PrincipalE: ClosedEuclidean; // Inner product - fn dot>(&self, other : I) -> F; + fn dot>(&self, other: I) -> F; /// Calculate the square of the 2-norm, $\frac{1}{2}\\|x\\|_2^2$, where `self` is $x$. /// @@ -38,7 +37,7 @@ /// where `self` is $x$. #[inline] fn norm2_squared_div2(&self) -> F { - self.norm2_squared()/F::TWO + self.norm2_squared() / F::TWO } /// Calculate the 2-norm $‖x‖_2$, where `self` is $x$. @@ -48,33 +47,119 @@ } /// Calculate the 2-distance squared $\\|x-y\\|_2^2$, where `self` is $x$. - fn dist2_squared>(&self, y : I) -> F; + fn dist2_squared>(&self, y: I) -> F; /// Calculate the 2-distance $\\|x-y\\|_2$, where `self` is $x$. #[inline] - fn dist2>(&self, y : I) -> F { + fn dist2>(&self, y: I) -> F { self.dist2_squared(y).sqrt() } /// Projection to the 2-ball. #[inline] - fn proj_ball2(mut self, ρ : F) -> Self { - self.proj_ball2_mut(ρ); - self + fn proj_ball2(self, ρ: F) -> Self::PrincipalV { + let r = self.norm2(); + if r > ρ { + self * (ρ / r) + } else { + self.into_owned() + } } +} +pub trait ClosedEuclidean: + Instance + Euclidean +{ +} +impl + Euclidean> ClosedEuclidean for X {} + +// TODO: remove F parameter, use AXPY::Field +pub trait EuclideanMut: Euclidean + AXPY { /// In-place projection to the 2-ball. #[inline] - fn proj_ball2_mut(&mut self, ρ : F) { + fn proj_ball2_mut(&mut self, ρ: F) { let r = self.norm2(); - if r>ρ { - *self *= ρ/r + if r > ρ { + *self *= ρ / r } } } +impl EuclideanMut for X where X: Euclidean + AXPY {} + /// Trait for [`Euclidean`] spaces with dimensions known at compile time. -pub trait StaticEuclidean : Euclidean { +pub trait StaticEuclidean: Euclidean { /// Returns the origin - fn origin() -> >::Output; + fn origin() -> >::PrincipalE; } + +macro_rules! scalar_euclidean { + ($f:ident) => { + impl VectorSpace for $f { + type Field = $f; + type PrincipalV = $f; + + #[inline] + fn similar_origin(&self) -> Self::PrincipalV { + 0.0 + } + } + impl AXPY for $f { + #[inline] + fn axpy>(&mut self, α: $f, x: I, β: $f) { + *self = β * *self + α * x.own() + } + + #[inline] + fn set_zero(&mut self) { + *self = 0.0 + } + } + + impl Norm for $f { + fn norm(&self, _p: L2) -> $f { + self.abs() + } + } + + impl Normed<$f> for $f { + type NormExp = L2; + + fn norm_exponent(&self) -> Self::NormExp { + L2 + } + } + + impl HasDual<$f> for $f { + type DualSpace = $f; + + #[inline] + fn dual_origin(&self) -> $f { + 0.0 + } + } + + impl Euclidean<$f> for $f { + type PrincipalE = $f; + + #[inline] + fn dot>(&self, other: I) -> $f { + *self * other.own() + } + + #[inline] + fn norm2_squared(&self) -> $f { + *self * *self + } + + #[inline] + fn dist2_squared>(&self, y: I) -> $f { + let d = *self - y.own(); + d * d + } + } + }; +} + +scalar_euclidean!(f64); +scalar_euclidean!(f32); diff -r 1f19c6bbf07b -r 3868555d135c src/euclidean/wrap.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/euclidean/wrap.rs Fri May 15 14:46:30 2026 -0500 @@ -0,0 +1,313 @@ +/*! +Wrappers for implemention [`Euclidean`] operations. +*/ + +use crate::euclidean::Euclidean; +use crate::instance::Space; +use crate::types::Float; + +pub trait WrapGuard<'a, F: Float> { + type View<'b>: Euclidean + where + Self: 'b; + fn get_view(&self) -> Self::View<'_>; +} + +pub trait WrapGuardMut<'a, F: Float> { + type ViewMut<'b>: Euclidean + where + Self: 'b; + fn get_view_mut(&mut self) -> Self::ViewMut<'_>; +} + +pub trait Wrapped: Space { + type WrappedField: Float; + type Guard<'a>: WrapGuard<'a, Self::WrappedField> + where + Self: 'a; + type GuardMut<'a>: WrapGuardMut<'a, Self::WrappedField> + where + Self: 'a; + type UnwrappedOutput; + type WrappedOutput; + fn get_guard(&self) -> Self::Guard<'_>; + fn get_guard_mut(&mut self) -> Self::GuardMut<'_>; + fn wrap(output: Self::UnwrappedOutput) -> Self::WrappedOutput; +} + +#[macro_export] +macro_rules! wrap { + // Rust macros are totally fucked up. $trait:path does not work, have to + // manually code paths through $($trait:ident)::+. + (impl_unary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+ for $type { + type Output = ::WrappedOutput; + fn $fn(self) -> Self::Output { + let a = self.get_guard(); + Self::wrap(a.get_view().$fn()) + } + } + }; + (impl_binary $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+<$type> for $type { + type Output = ::WrappedOutput; + fn $fn(self, other: $type) -> Self::Output { + let a = self.get_guard(); + let b = other.get_guard(); + Self::wrap(a.get_view().$fn(b.get_view())) + } + } + + impl<'a, $($qual)*> $($trait)::+<$type> for &'a $type { + type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput; + fn $fn(self, other: $type) -> Self::Output { + let a = self.get_guard(); + let b = other.get_guard(); + <$type>::wrap(a.get_view().$fn(b.get_view())) + } + } + + impl<'a, 'b, $($qual)*> $($trait)::+<&'b $type> for &'a $type { + type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput; + fn $fn(self, other: &'b $type) -> Self::Output { + let a = self.get_guard(); + let b = other.get_guard(); + <$type>::wrap(a.get_view().$fn(b.get_view())) + } + } + + impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type { + type Output = ::WrappedOutput; + fn $fn(self, other: &'b $type) -> Self::Output { + let a = self.get_guard(); + let b = other.get_guard(); + Self::wrap(a.get_view().$fn(b.get_view())) + } + } + }; + (impl_scalar $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+<$F> for $type + // where + // $type: $crate::euclidean::wrap::Wrapped, + // //$type::Unwrapped: $($trait)::+, + { + type Output = ::WrappedOutput; + fn $fn(self, t: $F) -> Self::Output { + let a = self.get_guard(); + Self::wrap(a.get_view().$fn(t)) + } + } + + impl<'a, $($qual)*> $($trait)::+<$F> for &'a $type + // where + // $type: $crate::euclidean::wrap::Wrapped, + // //$type::Unwrapped: $($trait)::+, + { + type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput; + fn $fn(self, t: $F) -> Self::Output { + let a = self.get_guard(); + <$type>::wrap(a.get_view().$fn(t)) + } + } + + }; + (impl_scalar_lhs $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+<$type> for $F + // where + // $type: $crate::euclidean::wrap::Wrapped, + // // where + // // $F: $($trait)::+<$type::Unwrapped>, + { + type Output = <$type as $crate::euclidean::wrap::Wrapped>::WrappedOutput; + fn $fn(self, rhs: $type) -> Self::Output { + let b = rhs.get_guard(); + <$type>::wrap(self.$fn(b.get_view())) + } + } + }; + (impl_binary_mut $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+<$type> for $type { + fn $fn(&mut self, rhs: $type) { + let mut a = self.get_guard_mut(); + let b = rhs.get_guard(); + a.get_view_mut().$fn(b.get_view()) + } + } + + impl<'b, $($qual)*> $($trait)::+<&'b $type> for $type { + fn $fn(&mut self, rhs: &'b $type) { + let mut a = self.get_guard_mut(); + let b = rhs.get_guard(); + a.get_view_mut().$fn(b.get_view()) + } + } + }; + (impl_scalar_mut $F:ty, $type:ty, $($trait:ident)::+, $fn:ident where $($qual:tt)*) => { + impl<$($qual)*> $($trait)::+<$F> for $type + // where + // $type: $crate::euclidean::wrap::Wrapped, + // // where + // // $type::UnwrappedMut: $($trait)::+<$($trait)::+>, + { + fn $fn(&mut self, t: $F) { + let mut a = self.get_guard_mut(); + a.get_view_mut().$fn(t) + } + } + }; + // ($type:ty) => { + // $crate::wrap!(imp<> do $type); + // }; + ($F:ty; $type:ty where $($qual:tt)*) => { + + $crate::wrap!(impl_unary $type, std::ops::Neg, neg where $($qual)*); + $crate::wrap!(impl_binary $type, std::ops::Add, add where $($qual)*); + $crate::wrap!(impl_binary $type, std::ops::Sub, sub where $($qual)*); + $crate::wrap!(impl_scalar $F, $type, std::ops::Mul, mul where $($qual)*); + $crate::wrap!(impl_scalar $F, $type, std::ops::Div, div where $($qual)*); + $crate::wrap!(impl_scalar_lhs $F, $type, std::ops::Mul, mul where $($qual)*); + $crate::wrap!(impl_binary_mut $type, std::ops::AddAssign, add_assign where $($qual)*); + $crate::wrap!(impl_binary_mut $type, std::ops::SubAssign, sub_assign where $($qual)*); + $crate::wrap!(impl_scalar_mut $F, $type, std::ops::MulAssign, mul_assign where $($qual)*); + $crate::wrap!(impl_scalar_mut $F, $type, std::ops::DivAssign, div_assign where $($qual)*); + + $crate::self_ownable!($type where $($qual)*); + + impl<$($qual)*> $crate::norms::Norm<$crate::norms::L2, $F> for $type + { + fn norm(&self, p : $crate::norms::L2) -> $F { + let a = self.get_guard(); + $crate::norms::Norm::norm(&a.get_view(), p) + } + } + + impl<$($qual)*> $crate::norms::Dist<$crate::norms::L2, $F> for $type + { + fn dist>(&self, other : I, p : $crate::norms::L2) -> $F { + other.eval_ref(|other| { + let a = self.get_guard(); + let b = other.get_guard(); + a.get_view().dist(b.get_view(), p) + }) + } + } + + impl<$($qual)*> $crate::norms::Normed<$F> for $type { + type NormExp = $crate::norms::L2; + + fn norm_exponent(&self) -> Self::NormExp { + $crate::norms::L2 + } + } + + impl<$($qual)*> $crate::norms::HasDual<$F> for $type { + type DualSpace = Self; + + fn dual_origin(&self) -> Self { + $crate::linops::VectorSpace::similar_origin(self) + } + } + + impl<$($qual)*> $crate::euclidean::Euclidean<$F> for $type + // where + // Self: $crate::euclidean::wrap::Wrapped + // + Sized + // + std::ops::Mul::Owned> + // + std::ops::MulAssign + // + std::ops::Div::Owned> + // + std::ops::DivAssign + // + std::ops::Add::Owned> + // + std::ops::Sub::Owned> + // + for<'b> std::ops::Add<&'b Self, Output = ::Owned> + // + for<'b> std::ops::Sub<&'b Self, Output = ::Owned> + // + std::ops::AddAssign + // + for<'b> std::ops::AddAssign<&'b Self> + // + std::ops::SubAssign + // + for<'b> std::ops::SubAssign<&'b Self> + // + std::ops::Neg::Owned>, + { + type PrincipalE = Self; + + fn dot>(&self, other: I) -> $F { + other.eval_decompose(|other| { + let a = self.get_guard(); + let b = other.get_guard(); + a.get_view().dot(&b.get_view()) + }) + } + + fn norm2_squared(&self) -> $F { + let a = self.get_guard(); + a.get_view().norm2_squared() + } + + fn dist2_squared>(&self, other: I) -> $F { + other.eval_decompose(|other| { + let a = self.get_guard(); + let b = other.get_guard(); + a.get_view().dist2_squared(b.get_view()) + }) + } + } + + impl<$($qual)*> $crate::linops::VectorSpace for $type + // where + // Self : $crate::euclidean::wrap::Wrapped, + // Self::Unwrapped : $crate::linops::AXPY, + // Self: std::ops::MulAssign + std::ops::DivAssign, + // Self::Unwrapped: std::ops::MulAssign + std::ops::DivAssign, + { + type Field = $F; + type PrincipalV = Self; + + /// Return a similar zero as `self`. + fn similar_origin(&self) -> Self::PrincipalV { + let a = self.get_guard(); + Self::wrap(a.get_view().similar_origin()) + } + } + + impl<$($qual)*> $crate::linops::AXPY for $type + // where + // Self : $crate::euclidean::wrap::Wrapped, + // Self::Unwrapped : $crate::linops::AXPY, + // Self: std::ops::MulAssign + std::ops::DivAssign, + // Self::Unwrapped: std::ops::MulAssign + std::ops::DivAssign, + { + fn axpy>(&mut self, α: $F, x: I, β: $F) { + x.eval_decompose(|other| { + let mut a = self.get_guard_mut(); + let b = other.get_guard(); + $crate::linops::AXPY::axpy(&mut a.get_view_mut(), α, b.get_view(), β) + }) + } + + fn copy_from>(&mut self, x: I) { + x.eval_decompose(|other| { + let mut a = self.get_guard_mut(); + let b = other.get_guard(); + $crate::linops::AXPY::copy_from(&mut a.get_view_mut(), b.get_view()) + }) + } + + fn scale_from>(&mut self, α: $F, x: I) { + x.eval_decompose(|other| { + let mut a = self.get_guard_mut(); + let b = other.get_guard(); + $crate::linops::AXPY::scale_from(&mut a.get_view_mut(), α, b.get_view()) + }) + } + + /// Set self to zero. + fn set_zero(&mut self) { + let mut a = self.get_guard_mut(); + a.get_view_mut().set_zero() + } + } + + impl<$($qual)*> $crate::instance::Space for $type { + type Decomp = $crate::instance::BasicDecomposition; + type Principal = Self; + } + }; +} diff -r 1f19c6bbf07b -r 3868555d135c src/fe_model/p2_local_model.rs --- a/src/fe_model/p2_local_model.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/fe_model/p2_local_model.rs Fri May 15 14:46:30 2026 -0500 @@ -2,21 +2,21 @@ Second order polynomical (P2) models on real intervals and planar 2D simplices. */ -use crate::types::*; -use crate::loc::Loc; -use crate::sets::{Set,NPolygon,SpannedHalfspace}; -use crate::linsolve::*; +use super::base::{LocalModel, RealLocalModel}; use crate::euclidean::Euclidean; use crate::instance::Instance; -use super::base::{LocalModel,RealLocalModel}; +use crate::linsolve::*; +use crate::loc::Loc; use crate::sets::Cube; +use crate::sets::{NPolygon, Set, SpannedHalfspace}; +use crate::types::*; use numeric_literals::replace_float_literals; /// Type for simplices of arbitrary dimension `N`. /// /// The type parameter `D` indicates the number of nodes. (Rust's const generics do not currently /// allow its automatic calculation from `N`.) -pub struct Simplex(pub [Loc; D]); +pub struct Simplex(pub [Loc; D]); /// A two-dimensional planar simplex pub type PlanarSimplex = Simplex; /// A real interval @@ -25,26 +25,30 @@ /// Calculates (a+b)/2 #[inline] #[replace_float_literals(F::cast_from(literal))] -pub(crate) fn midpoint(a : &Loc, b : &Loc) -> Loc { - (a+b)/2.0 +pub(crate) fn midpoint(a: &Loc, b: &Loc) -> Loc { + (a + b) / 2.0 } -impl<'a, F : Float> Set> for RealInterval { +impl<'a, F: Float> Set> for RealInterval { #[inline] - fn contains>>(&self, z : I) -> bool { - let &Loc([x]) = z.ref_instance(); - let &[Loc([x0]), Loc([x1])] = &self.0; - (x0 < x && x < x1) || (x1 < x && x < x0) + fn contains>>(&self, z: I) -> bool { + z.eval_ref(|&Loc([x])| { + let &[Loc([x0]), Loc([x1])] = &self.0; + (x0 < x && x < x1) || (x1 < x && x < x0) + }) } } -impl<'a, F : Float> Set> for PlanarSimplex { +impl<'a, F: Float> Set> for PlanarSimplex { #[inline] - fn contains>>(&self, z : I) -> bool { + fn contains>>(&self, z: I) -> bool { let &[x0, x1, x2] = &self.0; - NPolygon([[x0, x1].spanned_halfspace(), - [x1, x2].spanned_halfspace(), - [x2, x0].spanned_halfspace()]).contains(z) + NPolygon([ + [x0, x1].spanned_halfspace(), + [x1, x2].spanned_halfspace(), + [x2, x0].spanned_halfspace(), + ]) + .contains(z) } } @@ -58,121 +62,118 @@ } #[replace_float_literals(F::cast_from(literal))] -impl P2Powers for Loc { - type Output = Loc; - type Full = Loc; - type Diff = Loc, 1>; +impl P2Powers for Loc<1, F> { + type Output = Loc<1, F>; + type Full = Loc<3, F>; + type Diff = Loc<1, Loc<1, F>>; #[inline] fn p2powers(&self) -> Self::Output { let &Loc([x0]) = self; - [x0*x0].into() + [x0 * x0].into() } #[inline] fn p2powers_full(&self) -> Self::Full { let &Loc([x0]) = self; - [1.0, x0, x0*x0].into() + [1.0, x0, x0 * x0].into() } #[inline] fn p2powers_diff(&self) -> Self::Diff { let &Loc([x0]) = self; - [[x0+x0].into()].into() + [[x0 + x0].into()].into() } } #[replace_float_literals(F::cast_from(literal))] -impl P2Powers for Loc { - type Output = Loc; - type Full = Loc; - type Diff = Loc, 2>; +impl P2Powers for Loc<2, F> { + type Output = Loc<3, F>; + type Full = Loc<6, F>; + type Diff = Loc<2, Loc<3, F>>; #[inline] fn p2powers(&self) -> Self::Output { let &Loc([x0, x1]) = self; - [x0*x0, x0*x1, x1*x1].into() + [x0 * x0, x0 * x1, x1 * x1].into() } #[inline] fn p2powers_full(&self) -> Self::Full { let &Loc([x0, x1]) = self; - [1.0, x0, x1, x0*x0, x0*x1, x1*x1].into() + [1.0, x0, x1, x0 * x0, x0 * x1, x1 * x1].into() } #[inline] fn p2powers_diff(&self) -> Self::Diff { let &Loc([x0, x1]) = self; - [[x0+x0, x1, 0.0].into(), [0.0, x0, x1+x1].into()].into() + [[x0 + x0, x1, 0.0].into(), [0.0, x0, x1 + x1].into()].into() } } /// A trait for generating second order polynomial model of dimension `N` on `Self´. /// /// `Self` should present a subset aset of elements of the type [`Loc`]``. -pub trait P2Model { +pub trait P2Model { /// Implementation type of the second order polynomical model. /// Typically a [`P2LocalModel`]. - type Model : LocalModel,F>; + type Model: LocalModel, F>; /// Generates a second order polynomial model of the function `g` on `Self`. - fn p2_model) -> F>(&self, g : G) -> Self::Model; + fn p2_model) -> F>(&self, g: G) -> Self::Model; } /// A local second order polynomical model of dimension `N` with `E` edges -pub struct P2LocalModel { - a0 : F, - a1 : Loc, - a2 : Loc, - //node_values : Loc, - //edge_values : Loc, +pub struct P2LocalModel< + F: Num, + const N: usize, + const E: usize, /*, const V : usize, const Q : usize*/ +> { + a0: F, + a1: Loc, + a2: Loc, + //node_values : Loc, + //edge_values : Loc, } // // 1D planar model construction // -impl RealInterval { +impl RealInterval { #[inline] - fn midpoints(&self) -> [Loc; 1] { + pub fn midpoints(&self) -> [Loc<1, F>; 1] { let [ref n0, ref n1] = &self.0; let n01 = midpoint(n0, n1); [n01] } } -impl P2LocalModel { +impl P2LocalModel { /// Creates a new 1D second order polynomical model based on three nodal coordinates and /// corresponding function values. #[inline] - pub fn new( - &[n0, n1, n01] : &[Loc; 3], - &[v0, v1, v01] : &[F; 3], - ) -> Self { - let p = move |x : &Loc, v : F| { + pub fn new(&[n0, n1, n01]: &[Loc<1, F>; 3], &[v0, v1, v01]: &[F; 3]) -> Self { + let p = move |x: &Loc<1, F>, v: F| { let Loc([c, d, e]) = x.p2powers_full(); [c, d, e, v] }; - let [a0, a1, a11] = linsolve([ - p(&n0, v0), - p(&n1, v1), - p(&n01, v01) - ]); + let [a0, a1, a11] = linsolve([p(&n0, v0), p(&n1, v1), p(&n01, v01)]); P2LocalModel { - a0 : a0, - a1 : [a1].into(), - a2 : [a11].into(), + a0: a0, + a1: [a1].into(), + a2: [a11].into(), //node_values : [v0, v1].into(), //edge_values: [].into(), } } } -impl P2Model for RealInterval { - type Model = P2LocalModel; +impl P2Model for RealInterval { + type Model = P2LocalModel; #[inline] - fn p2_model) -> F>(&self, g : G) -> Self::Model { - let [n01] = self.midpoints(); + fn p2_model) -> F>(&self, g: G) -> Self::Model { + let [n01] = self.midpoints(); let [n0, n1] = self.0; let vals = [g(&n0), g(&n1), g(&n01)]; let nodes = [n0, n1, n01]; @@ -184,10 +185,10 @@ // 2D planar model construction // -impl PlanarSimplex { +impl PlanarSimplex { #[inline] /// Returns the midpoints of all the edges of the simplex - fn midpoints(&self) -> [Loc; 3] { + pub fn midpoints(&self) -> [Loc<2, F>; 3] { let [ref n0, ref n1, ref n2] = &self.0; let n01 = midpoint(n0, n1); let n12 = midpoint(n1, n2); @@ -196,15 +197,15 @@ } } -impl P2LocalModel { +impl P2LocalModel { /// Creates a new 2D second order polynomical model based on six nodal coordinates and /// corresponding function values. #[inline] pub fn new( - &[n0, n1, n2, n01, n12, n20] : &[Loc; 6], - &[v0, v1, v2, v01, v12, v20] : &[F; 6], + &[n0, n1, n2, n01, n12, n20]: &[Loc<2, F>; 6], + &[v0, v1, v2, v01, v12, v20]: &[F; 6], ) -> Self { - let p = move |x : &Loc, v :F| { + let p = move |x: &Loc<2, F>, v: F| { let Loc([c, d, e, f, g, h]) = x.p2powers_full(); [c, d, e, f, g, h, v] }; @@ -217,20 +218,20 @@ p(&n20, v20), ]); P2LocalModel { - a0 : a0, - a1 : [a1, a2].into(), - a2 : [a11, a12, a22].into(), + a0: a0, + a1: [a1, a2].into(), + a2: [a11, a12, a22].into(), //node_values : [v0, v1, v2].into(), //edge_values: [v01, v12, v20].into(), } } } -impl P2Model for PlanarSimplex { - type Model = P2LocalModel; +impl P2Model for PlanarSimplex { + type Model = P2LocalModel; #[inline] - fn p2_model) -> F>(&self, g : G) -> Self::Model { + fn p2_model) -> F>(&self, g: G) -> Self::Model { let midpoints = self.midpoints(); let [ref n0, ref n1, ref n2] = self.0; let [ref n01, ref n12, ref n20] = midpoints; @@ -242,125 +243,132 @@ macro_rules! impl_local_model { ($n:literal, $e:literal, $v:literal, $q:literal) => { - impl LocalModel, F> for P2LocalModel { + impl LocalModel, F> for P2LocalModel { #[inline] - fn value(&self, x : &Loc) -> F { + fn value(&self, x: &Loc<$n, F>) -> F { self.a0 + x.dot(&self.a1) + x.p2powers().dot(&self.a2) } #[inline] - fn differential(&self, x : &Loc) -> Loc { + fn differential(&self, x: &Loc<$n, F>) -> Loc<$n, F> { self.a1 + x.p2powers_diff().map(|di| di.dot(&self.a2)) } } - } + }; } impl_local_model!(1, 1, 2, 0); impl_local_model!(2, 3, 3, 3); - // // Minimisation // #[replace_float_literals(F::cast_from(literal))] -impl P2LocalModel { +impl P2LocalModel { /// Minimises the model along the edge `[x0, x1]`. #[inline] - fn minimise_edge(&self, x0 : Loc, x1 : Loc) -> (Loc, F) { - let &P2LocalModel{ - a1 : Loc([a1]), - a2 : Loc([a11]), + fn minimise_edge(&self, x0: Loc<1, F>, x1: Loc<1, F>) -> (Loc<1, F>, F) { + let &P2LocalModel { + a1: Loc([a1]), + a2: Loc([a11]), //node_values : Loc([v0, v1]), .. - } = self; + } = self; // We do this in cases, first trying for an interior solution, then edges. // For interior solution, first check determinant; no point trying if non-positive if a11 > 0.0 { // An interior solution x[1] has to satisfy // 2a₁₁*x[1] + a₁ =0 // This gives - let t = -a1/(2.0*a11); + let t = -a1 / (2.0 * a11); let (Loc([t0]), Loc([t1])) = (x0, x1); if (t0 <= t && t <= t1) || (t1 <= t && t <= t0) { let x = [t].into(); let v = self.value(&x); - return (x, v) + return (x, v); } } let v0 = self.value(&x0); let v1 = self.value(&x1); - if v0 < v1 { (x0, v0) } else { (x1, v1) } + if v0 < v1 { + (x0, v0) + } else { + (x1, v1) + } } } -impl<'a, F : Float> RealLocalModel,Loc,F> -for P2LocalModel { +impl<'a, F: Float> RealLocalModel, Loc<1, F>, F> + for P2LocalModel +{ #[inline] - fn minimise(&self, &Simplex([x0, x1]) : &RealInterval) -> (Loc, F) { + fn minimise(&self, &Simplex([x0, x1]): &RealInterval) -> (Loc<1, F>, F) { self.minimise_edge(x0, x1) } } #[replace_float_literals(F::cast_from(literal))] -impl P2LocalModel { +impl P2LocalModel { /// Minimise the 2D model along the edge `[x0, x1] = {x0 + t(x1 - x0) | t ∈ [0, 1] }`. #[inline] - fn minimise_edge(&self, x0 : &Loc, x1 : &Loc/*, v0 : F, v1 : F*/) -> (Loc, F) { - let &P2LocalModel { - a0, - a1 : Loc([a1, a2]), - a2 : Loc([a11, a12, a22]), - .. - } = self; + fn minimise_edge( + &self, + x0: &Loc<2, F>, + x1: &Loc<2, F>, /*, v0 : F, v1 : F*/ + ) -> (Loc<2, F>, F) { + let &P2LocalModel { a0, a1: Loc([a1, a2]), a2: Loc([a11, a12, a22]), .. } = self; let &Loc([x00, x01]) = x0; - let d@Loc([d0, d1]) = x1 - x0; - let b0 = a0 + a1*x00 + a2*x01 + a11*x00*x00 + a12*x00*x01 + a22*x01*x01; - let b1 = a1*d0 + a2*d1 + 2.0*a11*d0*x00 + a12*(d0*x01 + d1*x00) + 2.0*a22*d1*x01; - let b11 = a11*d0*d0 + a12*d0*d1 + a22*d1*d1; + let d @ Loc([d0, d1]) = x1 - x0; + let b0 = a0 + a1 * x00 + a2 * x01 + a11 * x00 * x00 + a12 * x00 * x01 + a22 * x01 * x01; + let b1 = a1 * d0 + + a2 * d1 + + 2.0 * a11 * d0 * x00 + + a12 * (d0 * x01 + d1 * x00) + + 2.0 * a22 * d1 * x01; + let b11 = a11 * d0 * d0 + a12 * d0 * d1 + a22 * d1 * d1; let edge_1d_model = P2LocalModel { - a0 : b0, - a1 : Loc([b1]), - a2 : Loc([b11]), + a0: b0, + a1: Loc([b1]), + a2: Loc([b11]), //node_values : Loc([v0, v1]), }; let (Loc([t]), v) = edge_1d_model.minimise_edge(0.0.into(), 1.0.into()); - (x0 + d*t, v) + (x0 + d * t, v) } } #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float> RealLocalModel,Loc,F> -for P2LocalModel { +impl<'a, F: Float> RealLocalModel, Loc<2, F>, F> + for P2LocalModel +{ #[inline] - fn minimise(&self, el : &PlanarSimplex) -> (Loc, F) { + fn minimise(&self, el: &PlanarSimplex) -> (Loc<2, F>, F) { let &P2LocalModel { - a1 : Loc([a1, a2]), - a2 : Loc([a11, a12, a22]), + a1: Loc([a1, a2]), + a2: Loc([a11, a12, a22]), //node_values : Loc([v0, v1, v2]), .. - } = self; + } = self; // We do this in cases, first trying for an interior solution, then edges. // For interior solution, first check determinant; no point trying if non-positive - let r = 2.0*(a11*a22-a12*a12); + let r = 2.0 * (a11 * a22 - a12 * a12); if r > 0.0 { // An interior solution (x[1], x[2]) has to satisfy // 2a₁₁*x[1] + 2a₁₂*x[2]+a₁ =0 and 2a₂₂*x[1] + 2a₁₂*x[1]+a₂=0 // This gives - let x = [(a22*a1-a12*a2)/r, (a12*a1-a11*a2)/r].into(); + let x = [(a22 * a1 - a12 * a2) / r, (a12 * a1 - a11 * a2) / r].into(); if el.contains(&x) { - return (x, self.value(&x)) + return (x, self.value(&x)); } } let &[ref x0, ref x1, ref x2] = &el.0; let mut min_edge = self.minimise_edge(x0, x1); - let more_edge = [self.minimise_edge(x1, x2), - self.minimise_edge(x2, x0)]; - + let more_edge = [self.minimise_edge(x1, x2), self.minimise_edge(x2, x0)]; + for edge in more_edge { if edge.1 < min_edge.1 { min_edge = edge; @@ -372,35 +380,38 @@ } #[replace_float_literals(F::cast_from(literal))] -impl<'a, F : Float> RealLocalModel,Loc,F> -for P2LocalModel { +impl<'a, F: Float> RealLocalModel, Loc<2, F>, F> + for P2LocalModel +{ #[inline] - fn minimise(&self, el : &Cube) -> (Loc, F) { + fn minimise(&self, el: &Cube<2, F>) -> (Loc<2, F>, F) { let &P2LocalModel { - a1 : Loc([a1, a2]), - a2 : Loc([a11, a12, a22]), + a1: Loc([a1, a2]), + a2: Loc([a11, a12, a22]), //node_values : Loc([v0, v1, v2]), .. - } = self; + } = self; // We do this in cases, first trying for an interior solution, then edges. // For interior solution, first check determinant; no point trying if non-positive - let r = 2.0*(a11*a22-a12*a12); + let r = 2.0 * (a11 * a22 - a12 * a12); if r > 0.0 { // An interior solution (x[1], x[2]) has to satisfy // 2a₁₁*x[1] + 2a₁₂*x[2]+a₁ =0 and 2a₂₂*x[1] + 2a₁₂*x[1]+a₂=0 // This gives - let x = [(a22*a1-a12*a2)/r, (a12*a1-a11*a2)/r].into(); + let x = [(a22 * a1 - a12 * a2) / r, (a12 * a1 - a11 * a2) / r].into(); if el.contains(&x) { - return (x, self.value(&x)) + return (x, self.value(&x)); } } let [x0, x1, x2, x3] = el.corners(); let mut min_edge = self.minimise_edge(&x0, &x1); - let more_edge = [self.minimise_edge(&x1, &x2), - self.minimise_edge(&x2, &x3), - self.minimise_edge(&x3, &x0)]; + let more_edge = [ + self.minimise_edge(&x1, &x2), + self.minimise_edge(&x2, &x3), + self.minimise_edge(&x3, &x0), + ]; for edge in more_edge { if edge.1 < min_edge.1 { @@ -422,7 +433,7 @@ let domain = Simplex(vertices); // A simple quadratic function for which the approximation is exact on reals, // and appears exact on f64 as well. - let f = |&Loc([x]) : &Loc| x*x + x + 1.0; + let f = |&Loc([x]): &Loc<1, f64>| x * x + x + 1.0; let model = domain.p2_model(f); let xs = [Loc([0.5]), Loc([0.25])]; @@ -439,7 +450,7 @@ let domain = Simplex(vertices); // A simple quadratic function for which the approximation is exact on reals, // and appears exact on f64 as well. - let f = |&Loc([x, y]) : &Loc| - (x*x + x*y + x - 2.0 * y + 1.0); + let f = |&Loc([x, y]): &Loc<2, f64>| -(x * x + x * y + x - 2.0 * y + 1.0); let model = domain.p2_model(f); let xs = [Loc([0.5, 0.5]), Loc([0.25, 0.25])]; diff -r 1f19c6bbf07b -r 3868555d135c src/instance.rs --- a/src/instance.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/instance.rs Fri May 15 14:46:30 2026 -0500 @@ -23,32 +23,252 @@ } } -impl<'b, X> MyCow<'b, X> { +/// Trait for ownable-by-consumption objects +pub trait Ownable { + type OwnedVariant: Clone; + + /// Returns an owned instance, possibly consuming the original, + /// avoiding cloning when possible. + fn into_owned(self) -> Self::OwnedVariant; + + /// Returns an owned instance of a reference. + fn clone_owned(&self) -> Self::OwnedVariant; + + /// Returns an owned instance or a reference to one. + fn cow_owned<'b>(self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b; + + /// Returns an owned instance or a reference to one. + fn ref_cow_owned<'b>(&'b self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b; +} + +impl<'a, X: Ownable> Ownable for &'a X { + type OwnedVariant = X::OwnedVariant; + + #[inline] + fn into_owned(self) -> Self::OwnedVariant { + X::clone_owned(self) + } + + #[inline] + fn clone_owned(&self) -> Self::OwnedVariant { + X::clone_owned(self) + } + #[inline] - pub fn into_owned(self) -> X where X : Clone { + fn cow_owned<'b>(self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + X::ref_cow_owned(self) + } + + fn ref_cow_owned<'b>(&self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + X::ref_cow_owned(self) + } +} + +impl<'a, X: Ownable> Ownable for &'a mut X { + type OwnedVariant = X::OwnedVariant; + + #[inline] + fn into_owned(self) -> Self::OwnedVariant { + X::clone_owned(self) + } + + #[inline] + fn clone_owned(&self) -> Self::OwnedVariant { + X::clone_owned(self) + } + + #[inline] + fn cow_owned<'b>(self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + X::ref_cow_owned(self) + } + + fn ref_cow_owned<'b>(&'b self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + X::ref_cow_owned(self) + } +} + +impl<'a, A, B> Ownable for EitherDecomp +where + A: Ownable, + B: Ownable, +{ + type OwnedVariant = A::OwnedVariant; + + #[inline] + fn into_owned(self) -> Self::OwnedVariant { match self { - EitherDecomp::Owned(x) => x, - EitherDecomp::Borrowed(x) => x.clone(), + EitherDecomp::Owned(a) => A::into_owned(a), + EitherDecomp::Borrowed(b) => B::into_owned(b), + } + } + + #[inline] + fn clone_owned(&self) -> Self::OwnedVariant { + match self { + EitherDecomp::Owned(a) => A::clone_owned(a), + EitherDecomp::Borrowed(b) => B::clone_owned(b), + } + } + + #[inline] + fn cow_owned<'b>(self) -> MyCow<'b, Self::OwnedVariant> + where + A: 'b, + B: 'b, + { + match self { + EitherDecomp::Owned(a) => A::cow_owned(a), + EitherDecomp::Borrowed(b) => B::cow_owned(b), + } + } + + #[inline] + fn ref_cow_owned<'b>(&'b self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + match self { + EitherDecomp::Owned(a) => A::ref_cow_owned(a), + EitherDecomp::Borrowed(b) => B::ref_cow_owned(b), } } } +#[macro_export] +macro_rules! self_ownable { + ($type:ty where $($qual:tt)*) => { + impl<$($qual)*> $crate::instance::Ownable for $type { + type OwnedVariant = Self; + + #[inline] + fn into_owned(self) -> Self::OwnedVariant { + self + } + + fn clone_owned(&self) -> Self::OwnedVariant { + self.clone() + } + + fn cow_owned<'b>(self) -> $crate::instance::MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + $crate::instance::MyCow::Owned(self) + } + + fn ref_cow_owned<'b>(&'b self) -> $crate::instance::MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + $crate::instance::MyCow::Borrowed(self) + } + } + }; +} + +self_ownable!(Vec where T : Clone); + +impl<'a, T: Clone> Ownable for &'a [T] { + type OwnedVariant = Vec; + + #[inline] + fn into_owned(self) -> Self::OwnedVariant { + Vec::from(self) + } + + #[inline] + fn clone_owned(&self) -> Self::OwnedVariant { + Vec::from(*self) + } + + #[inline] + fn cow_owned<'b>(self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + MyCow::Owned(Vec::from(self)) + } + + fn ref_cow_owned<'b>(&'b self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + MyCow::Owned(Vec::from(*self)) + } +} + /// Trait for abitrary mathematical spaces. -pub trait Space : Instance { +pub trait Space: Ownable + Sized { + /// Principal, typically owned realisation of the space. + type Principal: ClosedSpace; + /// Default decomposition for the space - type Decomp : Decomposition; + type Decomp: Decomposition; +} + +mod private { + pub trait Sealed {} } +/// Helper trait for working with own types. +pub trait Owned: Ownable + private::Sealed {} +impl> private::Sealed for X {} +impl> Owned for X {} + +/// Helper trait for working with closed spaces, operations in which should +/// return members of the same space +pub trait ClosedSpace: Space + Owned + Instance {} +impl + Owned + Instance> ClosedSpace for X {} + #[macro_export] macro_rules! impl_basic_space { - ($($type:ty)*) => { $( - impl $crate::instance::Space for $type { + ($($type:ty)*) => { + $( $crate::impl_basic_space!($type where ); )* + }; + ($type:ty where $($where:tt)*) => { + impl<$($where)*> $crate::instance::Space for $type { + type Principal = Self; type Decomp = $crate::instance::BasicDecomposition; } - )* }; - ($type:ty where $($where:tt)*) => { - impl<$($where)*> $crate::instance::Space for $type { - type Decomp = $crate::instance::BasicDecomposition; + + impl<$($where)*> $crate::instance::Ownable for $type { + type OwnedVariant = Self; + + #[inline] + fn into_owned(self) -> Self::OwnedVariant { + self + } + + #[inline] + fn clone_owned(&self) -> Self::OwnedVariant { + *self + } + + #[inline] + fn cow_owned<'b>(self) -> MyCow<'b, Self::OwnedVariant> where Self : 'b { + MyCow::Owned(self) + } + + #[inline] + fn ref_cow_owned<'b>(&self) -> MyCow<'b, Self::OwnedVariant> where Self : 'b { + MyCow::Owned(*self) + } } }; } @@ -58,17 +278,21 @@ f32 f64); /// Marker type for decompositions to be used with [`Instance`]. -pub trait Decomposition : Sized { +pub trait Decomposition: Sized { /// Possibly owned form of the decomposition - type Decomposition<'b> : Instance where X : 'b; + type Decomposition<'b>: Instance + where + X: 'b; /// Unlikely owned form of the decomposition. /// Type for a lightweight intermediate conversion that does not own the original variable. /// Usually this is just a reference, but may also be a lightweight structure that /// contains references; see the implementation for [`crate::direct_product::Pair`]. - type Reference<'b> : Instance + Copy where X : 'b; + type Reference<'b>: Instance + Copy + where + X: 'b; - /// Left the lightweight reference type into a full decomposition type. - fn lift<'b>(r : Self::Reference<'b>) -> Self::Decomposition<'b>; + /// Lift the lightweight reference type into a full decomposition type. + fn lift<'b>(r: Self::Reference<'b>) -> Self::Decomposition<'b>; } /// Most common [`Decomposition`] (into `Either`) that allows working with owned @@ -76,12 +300,18 @@ #[derive(Copy, Clone, Debug)] pub struct BasicDecomposition; -impl Decomposition for BasicDecomposition { - type Decomposition<'b> = MyCow<'b, X> where X : 'b; - type Reference<'b> = &'b X where X : 'b; +impl Decomposition for BasicDecomposition { + type Decomposition<'b> + = MyCow<'b, X> + where + X: 'b; + type Reference<'b> + = &'b X + where + X: 'b; #[inline] - fn lift<'b>(r : Self::Reference<'b>) -> Self::Decomposition<'b> { + fn lift<'b>(r: Self::Reference<'b>) -> Self::Decomposition<'b> { MyCow::Borrowed(r) } } @@ -91,192 +321,225 @@ /// generalises [`std::borrow::ToOwned`], [`std::borrow::Borrow`], and [`std::borrow::Cow`]. /// /// This is used, for example, by [`crate::mapping::Mapping::apply`]. -pub trait Instance::Decomp> : Sized where D : Decomposition { - /// Decomposes self according to `decomposer`. - fn decompose<'b>(self) -> D::Decomposition<'b> - where Self : 'b, X : 'b; - - /// Returns a lightweight instance of `self`. - fn ref_instance(&self) -> D::Reference<'_>; +pub trait Instance::Decomp>: Sized +where + X: Space, + D: Decomposition, +{ + /// Decomposes self according to `decomposer`, and evaluate `f` on the result. + /// Consumes self. + #[inline] + fn eval_decompose<'b, R>(self, f: impl FnOnce(D::Decomposition<'b>) -> R) -> R + where + X: 'b, + Self: 'b, + { + f(self.decompose()) + } + + /// Does a light decomposition of self `decomposer`, and evaluates `f` on the result. + /// Does not consume self. + fn eval_ref<'b, R>(&'b self, f: impl FnOnce(D::Reference<'b>) -> R) -> R + where + X: 'b, + Self: 'b; /// Returns an owned instance of `X`, cloning or converting non-true instances when necessary. - fn own(self) -> X; + fn own(self) -> X::Principal; - // ************** automatically implemented methods below from here ************** + fn decompose<'b>(self) -> D::Decomposition<'b> + where + Self: 'b; /// Returns an owned instance or reference to `X`, converting non-true instances when necessary. /// /// Default implementation uses [`Self::own`]. Consumes the input. - fn cow<'b>(self) -> MyCow<'b, X> where Self : 'b { - MyCow::Owned(self.own()) - } - + fn cow<'b>(self) -> MyCow<'b, X::Principal> + where + Self: 'b; + #[inline] /// Evaluates `f` on a reference to self. /// /// Default implementation uses [`Self::cow`]. Consumes the input. - fn eval<'b, R>(self, f : impl FnOnce(&X) -> R) -> R - where X : 'b, Self : 'b + fn eval<'b, R>(self, f: impl FnOnce(&X::Principal) -> R) -> R + where + X: 'b, + Self: 'b, { f(&*self.cow()) } - - #[inline] - /// Evaluates `f` or `g` depending on whether a reference or owned value is available. - /// - /// Default implementation uses [`Self::cow`]. Consumes the input. - fn either<'b, R>( - self, - f : impl FnOnce(X) -> R, - g : impl FnOnce(&X) -> R - ) -> R - where Self : 'b - { - match self.cow() { - EitherDecomp::Owned(x) => f(x), - EitherDecomp::Borrowed(x) => g(x), - } - } } - -impl Instance for X { +impl Instance for X { #[inline] - fn decompose<'b>(self) -> >::Decomposition<'b> - where Self : 'b, X : 'b + fn eval_ref<'b, R>(&'b self, f: impl FnOnce(&'b X) -> R) -> R + where + X: 'b, + Self: 'b, { - MyCow::Owned(self) + f(self) } #[inline] - fn own(self) -> X { - self + fn own(self) -> X::Principal { + self.into_owned() + } + + #[inline] + fn cow<'b>(self) -> MyCow<'b, X::Principal> + where + Self: 'b, + { + self.cow_owned() } #[inline] - fn cow<'b>(self) -> MyCow<'b, X> where Self : 'b { + fn decompose<'b>(self) -> MyCow<'b, X> + where + Self: 'b, + { MyCow::Owned(self) } +} + +impl<'a, X: Space> Instance for &'a X { + #[inline] + fn eval_ref<'b, R>(&'b self, f: impl FnOnce(&'b X) -> R) -> R + where + X: 'b, + Self: 'b, + { + f(*self) + } #[inline] - fn ref_instance(&self) -> >::Reference<'_> { - self + fn own(self) -> X::Principal { + self.into_owned() } -} -impl<'a, X : Space + Clone> Instance for &'a X { #[inline] - fn decompose<'b>(self) -> >::Decomposition<'b> - where Self : 'b, X : 'b + fn cow<'b>(self) -> MyCow<'b, X::Principal> + where + Self: 'b, + { + self.cow_owned() + } + + #[inline] + fn decompose<'b>(self) -> MyCow<'b, X> + where + Self: 'b, { MyCow::Borrowed(self) } +} +impl<'a, X: Space> Instance for &'a mut X { #[inline] - fn own(self) -> X { - self.clone() - } - - #[inline] - fn cow<'b>(self) -> MyCow<'b, X> where Self : 'b { - MyCow::Borrowed(self) + fn eval_ref<'b, R>(&'b self, f: impl FnOnce(&'b X) -> R) -> R + where + X: 'b, + Self: 'b, + { + f(*self) } #[inline] - fn ref_instance(&self) -> >::Reference<'_> { - *self - } -} - -impl<'a, X : Space + Clone> Instance for &'a mut X { - #[inline] - fn decompose<'b>(self) -> >::Decomposition<'b> - where Self : 'b, X : 'b - { - EitherDecomp::Borrowed(self) + fn own(self) -> X::Principal { + self.into_owned() } #[inline] - fn own(self) -> X { - self.clone() + fn cow<'b>(self) -> MyCow<'b, X::Principal> + where + Self: 'b, + { + self.cow_owned() } #[inline] - fn cow<'b>(self) -> MyCow<'b, X> where Self : 'b, X : Clone { - EitherDecomp::Borrowed(self) - } - - #[inline] - fn ref_instance(&self) -> >::Reference<'_> { - *self + fn decompose<'b>(self) -> MyCow<'b, X> + where + Self: 'b, + { + MyCow::Borrowed(self) } } -impl<'a, X : Space + Clone> Instance for MyCow<'a, X> { - +impl<'a, X: Space> Instance for MyCow<'a, X> { #[inline] - fn decompose<'b>(self) -> >::Decomposition<'b> - where Self : 'b, X : 'b + fn eval_ref<'b, R>(&'b self, f: impl FnOnce(&'b X) -> R) -> R + where + X: 'b, + Self: 'b, { - self - } - - #[inline] - fn own(self) -> X { match self { - MyCow::Borrowed(a) => a.own(), - MyCow::Owned(b) => b.own() + MyCow::Borrowed(a) => f(a), + MyCow::Owned(b) => f(&b), } } #[inline] - fn cow<'b>(self) -> MyCow<'b, X> where Self : 'b { - match self { - MyCow::Borrowed(a) => a.cow(), - MyCow::Owned(b) => b.cow() - } + fn own(self) -> X::Principal { + self.into_owned() } #[inline] - fn ref_instance(&self) -> >::Reference<'_> { - match self { - MyCow::Borrowed(a) => a, - MyCow::Owned(b) => &b, - } + fn cow<'b>(self) -> MyCow<'b, X::Principal> + where + Self: 'b, + { + self.cow_owned() } -} - -/// Marker type for mutable decompositions to be used with [`InstanceMut`]. -pub trait DecompositionMut : Sized { - type ReferenceMut<'b> : InstanceMut where X : 'b; -} - -/// Helper trait for functions to work with mutable references. -pub trait InstanceMut::Decomp> : Sized where D : DecompositionMut { - /// Returns a mutable decomposition of self. - fn ref_instance_mut(&mut self) -> D::ReferenceMut<'_>; -} - -impl DecompositionMut for BasicDecomposition { - type ReferenceMut<'b> = &'b mut X where X : 'b; -} - -/// This impl may seem pointless, but allows throwaway mutable scratch variables -impl<'a, X : Space> InstanceMut for X { #[inline] - fn ref_instance_mut(&mut self) - -> >::ReferenceMut<'_> + fn decompose<'b>(self) -> MyCow<'b, X> + where + Self: 'b, { self } } -impl<'a, X : Space> InstanceMut for &'a mut X { +/// Marker type for mutable decompositions to be used with [`InstanceMut`]. +pub trait DecompositionMut: Sized { + type ReferenceMut<'b>: InstanceMut + where + X: 'b; +} + +/// Helper trait for functions to work with mutable references. +pub trait InstanceMut::Decomp>: Sized +where + D: DecompositionMut, +{ + /// Returns a mutable decomposition of self. + fn ref_instance_mut(&mut self) -> D::ReferenceMut<'_>; +} + +impl DecompositionMut for BasicDecomposition { + type ReferenceMut<'b> + = &'b mut X + where + X: 'b; +} + +/// This impl may seem pointless, but allows throwaway mutable scratch variables +impl<'a, X: Space> InstanceMut for X { #[inline] - fn ref_instance_mut(&mut self) - -> >::ReferenceMut<'_> - { + fn ref_instance_mut( + &mut self, + ) -> >::ReferenceMut<'_> { self } } + +impl<'a, X: Space> InstanceMut for &'a mut X { + #[inline] + fn ref_instance_mut( + &mut self, + ) -> >::ReferenceMut<'_> { + self + } +} diff -r 1f19c6bbf07b -r 3868555d135c src/iterate.rs --- a/src/iterate.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/iterate.rs Fri May 15 14:46:30 2026 -0500 @@ -56,40 +56,54 @@ ``` */ -use colored::{Colorize, ColoredString}; +use crate::logger::*; +use crate::types::*; +use colored::{ColoredString, Colorize}; use core::fmt::Debug; -use serde::{Serialize, Deserialize}; use cpu_time::ProcessTime; +use serde::{Deserialize, Serialize}; +use std::cell::RefCell; +use std::error::Error; use std::marker::PhantomData; +use std::rc::Rc; use std::time::Duration; -use std::error::Error; -use std::cell::RefCell; -use std::rc::Rc; -use crate::types::*; -use crate::logger::*; /// Create the displayed presentation for log items. -pub trait LogRepr : Debug { - fn logrepr(&self) -> ColoredString { format!("« {self:?} »").as_str().into() } +pub trait LogRepr: Debug { + fn logrepr(&self) -> ColoredString { + format!("« {self:?} »").as_str().into() + } } impl LogRepr for str { - fn logrepr(&self) -> ColoredString { self.into() } + fn logrepr(&self) -> ColoredString { + self.into() + } } impl LogRepr for String { - fn logrepr(&self) -> ColoredString { self.as_str().into() } + fn logrepr(&self) -> ColoredString { + self.as_str().into() + } } -impl LogRepr for T where T : Num { - fn logrepr(&self) -> ColoredString { format!("J={self}").as_str().into() } +impl LogRepr for T +where + T: Num, +{ + fn logrepr(&self) -> ColoredString { + format!("J={self}").as_str().into() + } } -impl LogRepr for Option where V : LogRepr { +impl LogRepr for Option +where + V: LogRepr, +{ fn logrepr(&self) -> ColoredString { match self { - None => { "===missing value===".red() } - Some(v) => { v.logrepr() } + None => "===missing value===".red(), + Some(v) => v.logrepr(), } } } @@ -97,29 +111,33 @@ /// Helper struct for returning results annotated with an additional string to /// [`if_verbose`][AlgIteratorState::if_verbose]. The [`LogRepr`] implementation will /// display that string when so decided by the specific [`AlgIterator`] in use. -#[derive(Debug,Clone)] +#[derive(Debug, Clone)] pub struct Annotated(pub F, pub String); -impl LogRepr for Annotated where V : LogRepr { +impl LogRepr for Annotated +where + V: LogRepr, +{ fn logrepr(&self) -> ColoredString { - format!("{}\t| {}", self.0.logrepr(), self.1).as_str().into() + format!("{}\t| {}", self.0.logrepr(), self.1) + .as_str() + .into() } } - /// Basic log item. #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] pub struct LogItem { - pub iter : usize, + pub iter: usize, // This causes [`csv`] to crash. //#[serde(flatten)] - pub data : V + pub data: V, } impl LogItem { /// Creates a new log item - fn new(iter : usize, data : V) -> Self { - LogItem{ iter, data } + fn new(iter: usize, data: V) -> Self { + LogItem { iter, data } } } @@ -127,7 +145,7 @@ /// /// This is the parameter obtained by the closure passed to [`AlgIterator::iterate`] or /// [`AlgIteratorFactory::iterate`]. -pub trait AlgIteratorState : Sized { +pub trait AlgIteratorState: Sized { /// Call `call_objective` if this is a verbose iteration. /// /// Verbosity depends on the [`AlgIterator`] that produced this state. @@ -135,7 +153,7 @@ /// The closure `calc_objective` should return an arbitrary value of type `V`, to be inserted /// into the log, or whatever is deemed by the [`AlgIterator`]. For usage instructions see the /// [module documentation][self]. - fn if_verbose(self, calc_objective : impl FnOnce() -> V) -> Step; + fn if_verbose(self, calc_objective: impl FnOnce() -> V) -> Step; /// Returns the current iteration count. fn iteration(&self) -> usize; @@ -146,7 +164,7 @@ /// Result of a step of an [`AlgIterator`] #[derive(Debug, Serialize)] -pub enum Step { +pub enum Step { /// Iteration should be terminated Terminated, /// Iteration should be terminated due to failure @@ -157,19 +175,19 @@ Result(V, S), } -impl Step { +impl Step { /// Maps the value contained within the `Step`, if any, by the closure `f`. - pub fn map(self, mut f : impl FnMut(V) -> U) -> Step { + pub fn map(self, mut f: impl FnMut(V) -> U) -> Step { match self { - Step::Result(v, s) => Step::Result(f(v), s), + Step::Result(v, s) => Step::Result(f(v), s), Step::Failure(e) => Step::Failure(e), - Step::Quiet => Step::Quiet, + Step::Quiet => Step::Quiet, Step::Terminated => Step::Terminated, } } } -impl Default for Step { +impl Default for Step { fn default() -> Self { Step::Quiet } @@ -179,29 +197,34 @@ /// /// Typically not accessed directly, but transparently produced by an [`AlgIteratorFactory`]. /// Every [`AlgIteratorFactory`] has to implement a corresponding `AlgIterator`. -pub trait AlgIterator : Sized { +pub trait AlgIterator: Sized { /// The state type - type State : AlgIteratorState; + type State: AlgIteratorState; /// The output type for [`Self::poststep`] and [`Self::step`]. type Output; /// The input type for [`Self::poststep`]. type Input; /// Advance the iterator, performing `step_fn` with the state - fn step(&mut self, step_fn : &mut F) -> Step - where F : FnMut(Self::State) -> Step, - E : Error { - self.prestep().map_or(Step::Terminated, - |state| self.poststep(step_fn(state))) + fn step(&mut self, step_fn: &mut F) -> Step + where + F: FnMut(Self::State) -> Step, + E: Error, + { + self.prestep() + .map_or(Step::Terminated, |state| self.poststep(step_fn(state))) } /// Initial stage of advancing the iterator, before the actual step fn prestep(&mut self) -> Option; /// Handle step result - fn poststep(&mut self, result : Step) - -> Step - where E : Error; + fn poststep( + &mut self, + result: Step, + ) -> Step + where + E: Error; /// Return current iteration count. fn iteration(&self) -> usize { @@ -213,16 +236,18 @@ /// Iterate the `AlgIterator` until termination, erforming `step_fn` on each step. /// - /// Returns either `()` or an error if the step closure terminated in [`Step::Failure´]. + /// Returns either `()` or an error if the step closure terminated in [`Step::Failure`]. #[inline] - fn iterate(&mut self, mut step_fn : F) -> Result<(), E> - where F : FnMut(Self::State) -> Step, - E : Error { + fn iterate(&mut self, mut step_fn: F) -> Result<(), E> + where + F: FnMut(Self::State) -> Step, + E: Error, + { loop { match self.step(&mut step_fn) { Step::Terminated => return Ok(()), Step::Failure(e) => return Err(e), - _ => {}, + _ => {} } } } @@ -231,12 +256,12 @@ /// A factory for producing an [`AlgIterator`]. /// /// For usage instructions see the [module documentation][self]. -pub trait AlgIteratorFactory : Sized { - type Iter : AlgIterator; +pub trait AlgIteratorFactory: Sized { + type Iter: AlgIterator; /// The state type of the corresponding [`AlgIterator`]. /// A reference to this is passed to the closures passed to methods such as [`Self::iterate`]. - type State : AlgIteratorState; + type State: AlgIteratorState; /// The output type of the corresponding [`AlgIterator`]. /// This is the output of the closures passed to methods such as [`Self::iterate`] after /// mappings performed by each [`AlgIterator`] implementation. @@ -254,9 +279,11 @@ /// /// This method is equivalent to [`Self::prepare`] followed by [`AlgIterator::iterate`]. #[inline] - fn iterate_fallible(self, step : F) -> Result<(), E> - where F : FnMut(Self::State) -> Step, - E : Error { + fn iterate_fallible(self, step: F) -> Result<(), E> + where + F: FnMut(Self::State) -> Step, + E: Error, + { self.prepare().iterate(step) } @@ -271,8 +298,10 @@ /// This method is equivalent to [`Self::prepare`] followed by [`AlgIterator::iterate`] /// with the error type `E=`[`std::convert::Infallible`]. #[inline] - fn iterate(self, step : F) - where F : FnMut(Self::State) -> Step { + fn iterate(self, step: F) + where + F: FnMut(Self::State) -> Step, + { self.iterate_fallible(step).unwrap_or_default() } @@ -288,13 +317,16 @@ /// /// For usage instructions see the [module documentation][self]. #[inline] - fn iterate_data_fallible(self, mut datasource : I, mut step : F) - -> Result<(), E> - where F : FnMut(Self::State, D) -> Step, - I : Iterator, - E : Error { + fn iterate_data_fallible(self, mut datasource: I, mut step: F) -> Result<(), E> + where + F: FnMut(Self::State, D) -> Step, + I: Iterator, + E: Error, + { self.prepare().iterate(move |state| { - datasource.next().map_or(Step::Terminated, |d| step(state, d)) + datasource + .next() + .map_or(Step::Terminated, |d| step(state, d)) }) } @@ -309,10 +341,13 @@ /// /// For usage instructions see the [module documentation][self]. #[inline] - fn iterate_data(self, datasource : I, step : F) - where F : FnMut(Self::State, D) -> Step, - I : Iterator { - self.iterate_data_fallible(datasource, step).unwrap_or_default() + fn iterate_data(self, datasource: I, step: F) + where + F: FnMut(Self::State, D) -> Step, + I: Iterator, + { + self.iterate_data_fallible(datasource, step) + .unwrap_or_default() } // fn make_iterate<'own>(self) @@ -345,79 +380,87 @@ /// }) /// }) /// ``` - fn into_log<'log>(self, logger : &'log mut Logger) - -> LoggingIteratorFactory<'log, Self::Output, Self> - where Self : Sized { - LoggingIteratorFactory { - base_options : self, - logger, - } + fn into_log<'log>( + self, + logger: &'log mut Logger, + ) -> LoggingIteratorFactory<'log, Self::Output, Self> + where + Self: Sized, + { + LoggingIteratorFactory { base_options: self, logger } } /// Map the output of the iterator produced by the factory. /// /// Returns a new factory. - fn mapped(self, map : G) - -> MappingIteratorFactory - where Self : Sized, - G : Fn(usize, Self::Output) -> U { - MappingIteratorFactory { - base_options : self, - map - } + fn mapped(self, map: G) -> MappingIteratorFactory + where + Self: Sized, + G: Fn(usize, Self::Output) -> U, + { + MappingIteratorFactory { base_options: self, map } } /// Adds iteration number to the output. /// /// Returns a new factory. /// Typically followed by [`Self::into_log`]. - fn with_iteration_number(self) - -> MappingIteratorFactory LogItem, Self> - where Self : Sized { + fn with_iteration_number( + self, + ) -> MappingIteratorFactory LogItem, Self> + where + Self: Sized, + { self.mapped(LogItem::new) } /// Add timing to the iterator produced by the factory. fn timed(self) -> TimingIteratorFactory - where Self : Sized { + where + Self: Sized, + { TimingIteratorFactory(self) } /// Add value stopping threshold to the iterator produce by the factory - fn stop_target(self, target : Self::Output) -> ValueIteratorFactory - where Self : Sized, - Self::Output : Num { - ValueIteratorFactory { base_options : self, target : target } + fn stop_target(self, target: Self::Output) -> ValueIteratorFactory + where + Self: Sized, + Self::Output: Num, + { + ValueIteratorFactory { base_options: self, target: target } } /// Add stall stopping to the iterator produce by the factory - fn stop_stall(self, stall : Self::Output) -> StallIteratorFactory - where Self : Sized, - Self::Output : Num { - StallIteratorFactory { base_options : self, stall : stall } + fn stop_stall(self, stall: Self::Output) -> StallIteratorFactory + where + Self: Sized, + Self::Output: Num, + { + StallIteratorFactory { base_options: self, stall: stall } } /// Is the iterator quiet, i.e., on-verbose? - fn is_quiet(&self) -> bool { false } + fn is_quiet(&self) -> bool { + false + } /// Returns an an [`std::iter::Iterator`] that can be used in a `for`-loop. fn iter(self) -> AlgIteratorIterator { - AlgIteratorIterator { - algi : Rc::new(RefCell::new(self.prepare())), - } + AlgIteratorIterator { algi: Rc::new(RefCell::new(self.prepare())) } } /// Returns an an [`std::iter::Iterator`] that can be used in a `for`-loop, /// also inputting an initial iteration status calculated by `f` if needed. - fn iter_init(self, f : impl FnOnce() -> ::Input) - -> AlgIteratorIterator { + fn iter_init( + self, + f: impl FnOnce() -> ::Input, + ) -> AlgIteratorIterator { let mut i = self.prepare(); let st = i.state(); - let step : Step<::Input, Self::State> = st.if_verbose(f); + let step: Step<::Input, Self::State> = st.if_verbose(f); i.poststep(step); - AlgIteratorIterator { - algi : Rc::new(RefCell::new(i)), - } + AlgIteratorIterator { algi: Rc::new(RefCell::new(i)) } } } @@ -431,12 +474,12 @@ #[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)] pub struct AlgIteratorOptions { /// Maximum number of iterations - pub max_iter : usize, + pub max_iter: usize, /// Number of iterations between verbose iterations that display state. - pub verbose_iter : Verbose, + pub verbose_iter: Verbose, /// Whether verbose iterations are displayed, or just passed onwards to a containing /// `AlgIterator`. - pub quiet : bool, + pub quiet: bool, } #[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -444,7 +487,7 @@ /// Be verbose every $n$ iterations. Every(usize), /// Be verbose every $n$ iterations and initial $m$ iterations. - EveryAndInitial{ every : usize, initial : usize }, + EveryAndInitial { every: usize, initial: usize }, /// Be verbose if iteration number $n$ divides by $b^{\text{floor}(\log_b(n))}$, where /// $b$ is indicated logarithmic base. So, with $b=10$, /// * every iteration for first 10 iterations, @@ -455,24 +498,22 @@ /// is the given `cap`. For example, with `base=10` and `cap=2`, the first ten iterations /// will be output, then every tenth iteration, and after 100 iterations, every 100th iteration, /// without further logarithmic progression. - LogarithmicCap{ base : usize, cap : u32 }, + LogarithmicCap { base: usize, cap: u32 }, } impl Verbose { /// Indicates whether given iteration number is verbose - pub fn is_verbose(&self, iter : usize) -> bool { + pub fn is_verbose(&self, iter: usize) -> bool { match self { - &Verbose::Every(every) => { - every != 0 && iter % every == 0 - }, - &Verbose::EveryAndInitial{ every, initial } => { + &Verbose::Every(every) => every != 0 && iter % every == 0, + &Verbose::EveryAndInitial { every, initial } => { iter <= initial || (every != 0 && iter % every == 0) - }, + } &Verbose::Logarithmic(base) => { let every = base.pow((iter as float).log(base as float).floor() as u32); iter % every == 0 } - &Verbose::LogarithmicCap{base, cap} => { + &Verbose::LogarithmicCap { base, cap } => { let every = base.pow(((iter as float).log(base as float).floor() as u32).min(cap)); iter % every == 0 } @@ -482,41 +523,41 @@ impl Default for AlgIteratorOptions { fn default() -> AlgIteratorOptions { - AlgIteratorOptions{ - max_iter : 1000, - verbose_iter : Verbose::EveryAndInitial { every : 100, initial : 10 }, - quiet : false + AlgIteratorOptions { + max_iter: 1000, + verbose_iter: Verbose::EveryAndInitial { every: 100, initial: 10 }, + quiet: false, } } } /// State of a `BasicAlgIterator` -#[derive(Clone,Copy,Debug,Serialize,Eq,PartialEq)] +#[derive(Clone, Copy, Debug, Serialize, Eq, PartialEq)] pub struct BasicState { /// Current iteration - iter : usize, + iter: usize, /// Whether the iteration is verbose, i.e., results should be displayed. /// Requires `calc` to be `true`. - verbose : bool, + verbose: bool, /// Whether results should be calculated. - calc : bool, + calc: bool, /// Indicates whether the iteration is quiet - quiet : bool, + quiet: bool, } /// [`AlgIteratorFactory`] for [`BasicAlgIterator`] -#[derive(Clone,Debug)] +#[derive(Clone, Debug)] pub struct BasicAlgIteratorFactory { - options : AlgIteratorOptions, - _phantoms : PhantomData, + options: AlgIteratorOptions, + _phantoms: PhantomData, } /// The simplest [`AlgIterator`], created by [`BasicAlgIteratorFactory`] -#[derive(Clone,Debug)] +#[derive(Clone, Debug)] pub struct BasicAlgIterator { - options : AlgIteratorOptions, - iter : usize, - _phantoms : PhantomData, + options: AlgIteratorOptions, + iter: usize, + _phantoms: PhantomData, } impl AlgIteratorOptions { @@ -524,25 +565,20 @@ /// however, due to type inference issues, it may become convenient to instantiate /// it to a specific return type for the inner step function. This method does that. pub fn instantiate(&self) -> BasicAlgIteratorFactory { - BasicAlgIteratorFactory { - options : self.clone(), - _phantoms : PhantomData - } + BasicAlgIteratorFactory { options: self.clone(), _phantoms: PhantomData } } } impl AlgIteratorFactory for AlgIteratorOptions -where V : LogRepr { +where + V: LogRepr, +{ type State = BasicState; type Iter = BasicAlgIterator; type Output = V; fn prepare(self) -> Self::Iter { - BasicAlgIterator{ - options : self, - iter : 0, - _phantoms : PhantomData, - } + BasicAlgIterator { options: self, iter: 0, _phantoms: PhantomData } } #[inline] @@ -552,17 +588,15 @@ } impl AlgIteratorFactory for BasicAlgIteratorFactory -where V : LogRepr { +where + V: LogRepr, +{ type State = BasicState; type Iter = BasicAlgIterator; type Output = V; fn prepare(self) -> Self::Iter { - BasicAlgIterator { - options : self.options, - iter : 0, - _phantoms : PhantomData - } + BasicAlgIterator { options: self.options, iter: 0, _phantoms: PhantomData } } #[inline] @@ -572,7 +606,9 @@ } impl AlgIterator for BasicAlgIterator -where V : LogRepr { +where + V: LogRepr, +{ type State = BasicState; type Output = V; type Input = V; @@ -587,14 +623,17 @@ } } - fn poststep(&mut self, res : Step) -> Step { + fn poststep(&mut self, res: Step) -> Step { if let Step::Result(ref val, ref state) = res { if state.verbose && !self.options.quiet { - println!("{}{}/{} {}{}", "".dimmed(), - state.iter, - self.options.max_iter, - val.logrepr(), - "".clear()); + println!( + "{}{}/{} {}{}", + "".dimmed(), + state.iter, + self.options.max_iter, + val.logrepr(), + "".clear() + ); } } res @@ -609,18 +648,13 @@ fn state(&self) -> BasicState { let iter = self.iter; let verbose = self.options.verbose_iter.is_verbose(iter); - BasicState { - iter : iter, - verbose : verbose, - calc : verbose, - quiet : self.options.quiet - } + BasicState { iter: iter, verbose: verbose, calc: verbose, quiet: self.options.quiet } } } impl AlgIteratorState for BasicState { #[inline] - fn if_verbose(self, calc_objective : impl FnOnce() -> V) -> Step { + fn if_verbose(self, calc_objective: impl FnOnce() -> V) -> Step { if self.calc { Step::Result(calc_objective(), self) } else { @@ -647,34 +681,35 @@ /// /// We define stall as $(v_{k+n}-v_k)/v_k ≤ θ$, where $n$ the distance between /// [`Step::Result`] iterations, and $θ$ is the provided `stall` parameter. -#[derive(Clone,Copy,Debug,Serialize,Eq,PartialEq)] -pub struct StallIteratorFactory { +#[derive(Clone, Copy, Debug, Serialize, Eq, PartialEq)] +pub struct StallIteratorFactory { /// An [`AlgIteratorFactory`] on which to build on - pub base_options : BaseFactory, + pub base_options: BaseFactory, /// Stalling threshold $θ$. - pub stall : U, + pub stall: U, } /// Iterator produced by [`StallIteratorFactory`]. -pub struct StallIterator { - base_iterator : BaseIterator, - stall : U, - previous_value : Option, +pub struct StallIterator { + base_iterator: BaseIterator, + stall: U, + previous_value: Option, } -impl AlgIteratorFactory -for StallIteratorFactory -where BaseFactory : AlgIteratorFactory, - U : SignedNum + PartialOrd { +impl AlgIteratorFactory for StallIteratorFactory +where + BaseFactory: AlgIteratorFactory, + U: SignedNum + PartialOrd, +{ type Iter = StallIterator; type State = BaseFactory::State; type Output = BaseFactory::Output; fn prepare(self) -> Self::Iter { StallIterator { - base_iterator : self.base_options.prepare(), - stall : self.stall, - previous_value : None, + base_iterator: self.base_options.prepare(), + stall: self.stall, + previous_value: None, } } @@ -683,10 +718,11 @@ } } -impl AlgIterator -for StallIterator -where BaseIterator : AlgIterator, - U : SignedNum + PartialOrd { +impl AlgIterator for StallIterator +where + BaseIterator: AlgIterator, + U: SignedNum + PartialOrd, +{ type State = BaseIterator::State; type Output = U; type Input = BaseIterator::Input; @@ -697,8 +733,10 @@ } #[inline] - fn poststep(&mut self, res : Step) -> Step - where E : Error { + fn poststep(&mut self, res: Step) -> Step + where + E: Error, + { match self.base_iterator.poststep(res) { Step::Result(nv, state) => { let previous_v = self.previous_value; @@ -707,7 +745,7 @@ Some(pv) if (nv - pv).abs() <= self.stall * pv.abs() => Step::Terminated, _ => Step::Result(nv, state), } - }, + } val => val, } } @@ -725,33 +763,31 @@ /// An [`AlgIteratorFactory`] for an [`AlgIterator`] that detect whether step function /// return value is less than `target`, and terminates if it is. -#[derive(Clone,Copy,Debug,Serialize,Eq,PartialEq)] -pub struct ValueIteratorFactory { +#[derive(Clone, Copy, Debug, Serialize, Eq, PartialEq)] +pub struct ValueIteratorFactory { /// An [`AlgIteratorFactory`] on which to build on - pub base_options : BaseFactory, + pub base_options: BaseFactory, /// Target value - pub target : U, + pub target: U, } /// Iterator produced by [`ValueIteratorFactory`]. -pub struct ValueIterator { - base_iterator : BaseIterator, - target : U, +pub struct ValueIterator { + base_iterator: BaseIterator, + target: U, } -impl AlgIteratorFactory -for ValueIteratorFactory -where BaseFactory : AlgIteratorFactory, - U : SignedNum + PartialOrd { +impl AlgIteratorFactory for ValueIteratorFactory +where + BaseFactory: AlgIteratorFactory, + U: SignedNum + PartialOrd, +{ type Iter = ValueIterator; type State = BaseFactory::State; type Output = BaseFactory::Output; fn prepare(self) -> Self::Iter { - ValueIterator { - base_iterator : self.base_options.prepare(), - target : self.target - } + ValueIterator { base_iterator: self.base_options.prepare(), target: self.target } } fn is_quiet(&self) -> bool { @@ -759,10 +795,11 @@ } } -impl AlgIterator -for ValueIterator -where BaseIterator : AlgIterator, - U : SignedNum + PartialOrd { +impl AlgIterator for ValueIterator +where + BaseIterator: AlgIterator, + U: SignedNum + PartialOrd, +{ type State = BaseIterator::State; type Output = U; type Input = BaseIterator::Input; @@ -773,15 +810,18 @@ } #[inline] - fn poststep(&mut self, res : Step) -> Step where E : Error{ + fn poststep(&mut self, res: Step) -> Step + where + E: Error, + { match self.base_iterator.poststep(res) { Step::Result(v, state) => { - if v <= self.target { + if v <= self.target { Step::Terminated - } else { + } else { Step::Result(v, state) - } - }, + } + } val => val, } } @@ -808,31 +848,29 @@ #[derive(Debug)] pub struct LoggingIteratorFactory<'log, U, BaseFactory> { /// Base [`AlgIteratorFactory`] on which to build - base_options : BaseFactory, + base_options: BaseFactory, /// The `Logger` to use. - logger : &'log mut Logger, + logger: &'log mut Logger, } /// Iterator produced by `LoggingIteratorFactory`. pub struct LoggingIterator<'log, U, BaseIterator> { - base_iterator : BaseIterator, - logger : &'log mut Logger, + base_iterator: BaseIterator, + logger: &'log mut Logger, } - impl<'log, V, BaseFactory> AlgIteratorFactory -for LoggingIteratorFactory<'log, BaseFactory::Output, BaseFactory> -where BaseFactory : AlgIteratorFactory, - BaseFactory::Output : 'log { + for LoggingIteratorFactory<'log, BaseFactory::Output, BaseFactory> +where + BaseFactory: AlgIteratorFactory, + BaseFactory::Output: 'log, +{ type State = BaseFactory::State; type Iter = LoggingIterator<'log, BaseFactory::Output, BaseFactory::Iter>; type Output = (); fn prepare(self) -> Self::Iter { - LoggingIterator { - base_iterator : self.base_options.prepare(), - logger : self.logger, - } + LoggingIterator { base_iterator: self.base_options.prepare(), logger: self.logger } } #[inline] @@ -841,10 +879,11 @@ } } -impl<'log, BaseIterator> AlgIterator -for LoggingIterator<'log, BaseIterator::Output, BaseIterator> -where BaseIterator : AlgIterator, - BaseIterator::Output : 'log { +impl<'log, BaseIterator> AlgIterator for LoggingIterator<'log, BaseIterator::Output, BaseIterator> +where + BaseIterator: AlgIterator, + BaseIterator::Output: 'log, +{ type State = BaseIterator::State; type Output = (); type Input = BaseIterator::Input; @@ -855,12 +894,15 @@ } #[inline] - fn poststep(&mut self, res : Step) -> Step<(), Self::State, E> where E : Error { + fn poststep(&mut self, res: Step) -> Step<(), Self::State, E> + where + E: Error, + { match self.base_iterator.poststep(res) { Step::Result(v, _) => { self.logger.log(v); Step::Quiet - }, + } Step::Quiet => Step::Quiet, Step::Terminated => Step::Terminated, Step::Failure(e) => Step::Failure(e), @@ -884,32 +926,29 @@ #[derive(Debug)] pub struct MappingIteratorFactory { /// Base [`AlgIteratorFactory`] on which to build - base_options : BaseFactory, + base_options: BaseFactory, /// A closure `G : Fn(usize, BaseFactory::Output) -> U` that gets the current iteration /// and the output of the base factory as input, and produces a new output. - map : G, + map: G, } /// [`AlgIterator`] produced by [`MappingIteratorFactory`]. pub struct MappingIterator { - base_iterator : BaseIterator, - map : G, + base_iterator: BaseIterator, + map: G, } - -impl AlgIteratorFactory -for MappingIteratorFactory -where BaseFactory : AlgIteratorFactory, - G : Fn(usize, BaseFactory::Output) -> U { +impl AlgIteratorFactory for MappingIteratorFactory +where + BaseFactory: AlgIteratorFactory, + G: Fn(usize, BaseFactory::Output) -> U, +{ type State = BaseFactory::State; type Iter = MappingIterator; type Output = U; fn prepare(self) -> Self::Iter { - MappingIterator { - base_iterator : self.base_options.prepare(), - map : self.map - } + MappingIterator { base_iterator: self.base_options.prepare(), map: self.map } } #[inline] @@ -918,10 +957,11 @@ } } -impl AlgIterator -for MappingIterator -where BaseIterator : AlgIterator, - G : Fn(usize, BaseIterator::Output) -> U { +impl AlgIterator for MappingIterator +where + BaseIterator: AlgIterator, + G: Fn(usize, BaseIterator::Output) -> U, +{ type State = BaseIterator::State; type Output = U; type Input = BaseIterator::Input; @@ -932,7 +972,13 @@ } #[inline] - fn poststep(&mut self, res : Step) -> Step where E : Error { + fn poststep( + &mut self, + res: Step, + ) -> Step + where + E: Error, + { match self.base_iterator.poststep(res) { Step::Result(v, state) => Step::Result((self.map)(self.iteration(), v), state), Step::Quiet => Step::Quiet, @@ -963,41 +1009,47 @@ /// Iterator produced by [`TimingIteratorFactory`] #[derive(Debug)] pub struct TimingIterator { - base_iterator : BaseIterator, - start_time : ProcessTime, + base_iterator: BaseIterator, + start_time: ProcessTime, } /// Data `U` with production time attached #[derive(Copy, Clone, Debug, Serialize)] pub struct Timed { /// CPU time taken - pub cpu_time : Duration, + pub cpu_time: Duration, /// Iteration number - pub iter : usize, + pub iter: usize, /// User data //#[serde(flatten)] - pub data : U + pub data: U, } -impl LogRepr for Timed where T : LogRepr { +impl LogRepr for Timed +where + T: LogRepr, +{ fn logrepr(&self) -> ColoredString { - format!("[{:.3}s] {}", self.cpu_time.as_secs_f64(), self.data.logrepr()).as_str().into() + format!( + "[{:.3}s] {}", + self.cpu_time.as_secs_f64(), + self.data.logrepr() + ) + .as_str() + .into() } } - -impl AlgIteratorFactory -for TimingIteratorFactory -where BaseFactory : AlgIteratorFactory { +impl AlgIteratorFactory for TimingIteratorFactory +where + BaseFactory: AlgIteratorFactory, +{ type State = BaseFactory::State; type Iter = TimingIterator; type Output = Timed; fn prepare(self) -> Self::Iter { - TimingIterator { - base_iterator : self.0.prepare(), - start_time : ProcessTime::now() - } + TimingIterator { base_iterator: self.0.prepare(), start_time: ProcessTime::now() } } #[inline] @@ -1006,9 +1058,10 @@ } } -impl AlgIterator -for TimingIterator -where BaseIterator : AlgIterator { +impl AlgIterator for TimingIterator +where + BaseIterator: AlgIterator, +{ type State = BaseIterator::State; type Output = Timed; type Input = BaseIterator::Input; @@ -1019,15 +1072,18 @@ } #[inline] - fn poststep(&mut self, res : Step) -> Step where E : Error { + fn poststep( + &mut self, + res: Step, + ) -> Step + where + E: Error, + { match self.base_iterator.poststep(res) { - Step::Result(data, state) => { - Step::Result(Timed{ - cpu_time : self.start_time.elapsed(), - iter : self.iteration(), - data - }, state) - }, + Step::Result(data, state) => Step::Result( + Timed { cpu_time: self.start_time.elapsed(), iter: self.iteration(), data }, + state, + ), Step::Quiet => Step::Quiet, Step::Terminated => Step::Terminated, Step::Failure(e) => Step::Failure(e), @@ -1049,35 +1105,34 @@ // New for-loop interface // -pub struct AlgIteratorIterator { - algi : Rc>, +pub struct AlgIteratorIterator { + algi: Rc>, } -pub struct AlgIteratorIteration { - state : I::State, - algi : Rc>, +pub struct AlgIteratorIteration { + state: I::State, + algi: Rc>, } -impl std::iter::Iterator for AlgIteratorIterator { +impl std::iter::Iterator for AlgIteratorIterator { type Item = AlgIteratorIteration; fn next(&mut self) -> Option { let algi = self.algi.clone(); - RefCell::borrow_mut(&self.algi).prestep().map(|state| AlgIteratorIteration { - state, - algi, - }) + RefCell::borrow_mut(&self.algi) + .prestep() + .map(|state| AlgIteratorIteration { state, algi }) } } /// Types of errors that may occur -#[derive(Debug,PartialEq,Eq)] +#[derive(Debug, PartialEq, Eq)] pub enum IterationError { /// [`AlgIteratorIteration::if_verbose_check`] is not called in iteration order. - ReportingOrderingError + ReportingOrderingError, } -impl AlgIteratorIteration { +impl AlgIteratorIteration { /// Call `call_objective` if this is a verbose iteration. /// /// Verbosity depends on the [`AlgIterator`] that produced this state. @@ -1089,22 +1144,24 @@ /// This function may panic if result reporting is not ordered correctly (an unlikely mistake /// if using this facility correctly). For a version that propagates errors, see /// [`Self::if_verbose_check`]. - pub fn if_verbose(self, calc_objective : impl FnOnce() -> I::Input) { + pub fn if_verbose(self, calc_objective: impl FnOnce() -> I::Input) { self.if_verbose_check(calc_objective).unwrap() } /// Version of [`Self::if_verbose`] that propagates errors instead of panicking. - pub fn if_verbose_check(self, calc_objective : impl FnOnce() -> I::Input) - -> Result<(), IterationError> { + pub fn if_verbose_check( + self, + calc_objective: impl FnOnce() -> I::Input, + ) -> Result<(), IterationError> { let mut algi = match RefCell::try_borrow_mut(&self.algi) { Err(_) => return Err(IterationError::ReportingOrderingError), - Ok(algi) => algi + Ok(algi) => algi, }; if self.state.iteration() != algi.iteration() { Err(IterationError::ReportingOrderingError) } else { - let res : Step - = self.state.if_verbose(calc_objective); + let res: Step = + self.state.if_verbose(calc_objective); algi.poststep(res); Ok(()) } @@ -1131,10 +1188,10 @@ use crate::logger::Logger; #[test] fn iteration() { - let options = AlgIteratorOptions{ - max_iter : 10, - verbose_iter : Verbose::Every(3), - .. Default::default() + let options = AlgIteratorOptions { + max_iter: 10, + verbose_iter: Verbose::Every(3), + ..Default::default() }; { @@ -1149,31 +1206,35 @@ { let mut start = 1 as int; let mut log = Logger::new(); - let factory = options.instantiate() - .with_iteration_number() - .into_log(&mut log); + let factory = options + .instantiate() + .with_iteration_number() + .into_log(&mut log); factory.iterate(|state| { start = start * 2; state.if_verbose(|| start) }); assert_eq!(start, (2 as int).pow(10)); - assert_eq!(log.data() - .iter() - .map(|LogItem{ data : v, iter : _ }| v.clone()) - .collect::>(), - (1..10).map(|i| (2 as int).pow(i)) - .skip(2) - .step_by(3) - .collect::>()) + assert_eq!( + log.data() + .iter() + .map(|LogItem { data: v, iter: _ }| v.clone()) + .collect::>(), + (1..10) + .map(|i| (2 as int).pow(i)) + .skip(2) + .step_by(3) + .collect::>() + ) } } #[test] fn iteration_for_loop() { - let options = AlgIteratorOptions{ - max_iter : 10, - verbose_iter : Verbose::Every(3), - .. Default::default() + let options = AlgIteratorOptions { + max_iter: 10, + verbose_iter: Verbose::Every(3), + ..Default::default() }; { @@ -1188,23 +1249,26 @@ { let mut start = 1 as int; let mut log = Logger::new(); - let factory = options.instantiate() - .with_iteration_number() - .into_log(&mut log); + let factory = options + .instantiate() + .with_iteration_number() + .into_log(&mut log); for state in factory.iter() { start = start * 2; state.if_verbose(|| start) } assert_eq!(start, (2 as int).pow(10)); - assert_eq!(log.data() - .iter() - .map(|LogItem{ data : v, iter : _ }| v.clone()) - .collect::>(), - (1..10).map(|i| (2 as int).pow(i)) - .skip(2) - .step_by(3) - .collect::>()) + assert_eq!( + log.data() + .iter() + .map(|LogItem { data: v, iter: _ }| v.clone()) + .collect::>(), + (1..10) + .map(|i| (2 as int).pow(i)) + .skip(2) + .step_by(3) + .collect::>() + ) } } - } diff -r 1f19c6bbf07b -r 3868555d135c src/lib.rs --- a/src/lib.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/lib.rs Fri May 15 14:46:30 2026 -0500 @@ -12,15 +12,12 @@ nightly, feature( maybe_uninit_array_assume_init, - maybe_uninit_slice, float_minimum_maximum, get_mut_unchecked, - cow_is_borrowed ) )] #[macro_use] -pub(crate) mod metaprogramming; pub mod collection; pub mod error; pub mod euclidean; @@ -34,6 +31,7 @@ #[macro_use] pub mod loc; pub mod bisection_tree; +pub mod bounds; pub mod coefficients; pub mod convex; pub mod direct_product; diff -r 1f19c6bbf07b -r 3868555d135c src/lingrid.rs --- a/src/lingrid.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/lingrid.rs Fri May 15 14:46:30 2026 -0500 @@ -15,46 +15,54 @@ iteration over the grid. Additional utility functions are in the [`Grid`] trait. */ -use crate::types::*; +use crate::iter::{RestartableIterator, StatefulIterator}; use crate::loc::Loc; +use crate::maputil::{map2, map4}; use crate::sets::Cube; -use crate::iter::{RestartableIterator, StatefulIterator}; -use crate::maputil::{map2, map4}; -use serde::{Serialize, Deserialize}; +use crate::types::*; +use serde::{Deserialize, Serialize}; // TODO: rewrite this using crate::sets::Cube. /// An abstraction of possibly multi-dimensional linear grids. /// /// `U` is typically a `F` for a `Float` `F` for one-dimensional grids created by `linspace`, -/// or [`Loc`]`` for multi-dimensional grids created by `lingrid`. +/// or [`Loc`]`` for multi-dimensional grids created by `lingrid`. /// In the first case `count` of nodes is `usize`, and in the second case `[usize; N]`. #[derive(Clone, Copy, Debug, Serialize, Deserialize, Eq, PartialEq)] pub struct LinSpace { - pub start : U, - pub end : U, - pub count : I, + pub start: U, + pub end: U, + pub count: I, } /// A `N`-dimensional interval divided into an indicated number of equally-spaced nodes along /// each dimension. #[allow(type_alias_bounds)] // Need it to access F::CompatibleSize. -pub type LinGrid = LinSpace, [usize; N]>; +pub type LinGrid = LinSpace, [usize; N]>; /// Creates a [`LinSpace`] on the real line. -pub fn linspace(start : F, end : F, count : usize) -> LinSpace { - LinSpace{ start : start, end : end, count : count } +pub fn linspace(start: F, end: F, count: usize) -> LinSpace { + LinSpace { + start: start, + end: end, + count: count, + } } /// Creates a multi-dimensional linear grid. /// /// The first and last point in each dimension are the boundaries of the corresponding /// dimensions of `cube`, and there are `count` nodes along each dimension. -pub fn lingrid( - cube : &Cube, - count : &[usize; N] -) -> LinSpace, [usize; N]> { - LinSpace{ start : cube.span_start(), end : cube.span_end(), count : *count } +pub fn lingrid( + cube: &Cube, + count: &[usize; N], +) -> LinSpace, [usize; N]> { + LinSpace { + start: cube.span_start(), + end: cube.span_end(), + count: *count, + } } /// Create a multi-dimensional linear grid with centered nodes. @@ -63,30 +71,33 @@ /// inside `cube`. Thus, if $w\_i$ is the width of the cube along dimension $i$, and $n_i$ the number /// of nodes, the width of the subcube along this dimension is $h_i = w\_i/(n\_i+1)$, and the first /// and last nodes are at a distance $h\_i/2$ from the closest boundary. -pub fn lingrid_centered( - cube : &Cube, - count : &[usize; N] -) -> LinSpace, [usize; N]> { +pub fn lingrid_centered( + cube: &Cube, + count: &[usize; N], +) -> LinSpace, [usize; N]> { let h_div_2 = map2(cube.width(), count, |w, &n| w / F::cast_from(2 * (n + 1))); let span_start = map2(cube.span_start(), &h_div_2, |a, &t| a + t).into(); - let span_end = map2(cube.span_end(), &h_div_2, |b, &t| b - t).into(); - LinSpace{ start : span_start, end : span_end, count : *count } + let span_end = map2(cube.span_end(), &h_div_2, |b, &t| b - t).into(); + LinSpace { + start: span_start, + end: span_end, + count: *count, + } } - /// Iterator over a `LinSpace`. #[derive(Clone, Debug)] pub struct LinSpaceIterator { - lingrid : LinSpace, - current : Option, + lingrid: LinSpace, + current: Option, } /// Abstraction of a linear grid over space `U` with multi-dimensional index set `I`. pub trait Grid { /// Converts a linear index `i` into a grid point. - fn entry_linear_unchecked(&self, i : usize) -> U; + fn entry_linear_unchecked(&self, i: usize) -> U; // Converts a multi-dimensional index `i` into a grid point. - fn entry_unchecked(&self, i : &I) -> U; + fn entry_unchecked(&self, i: &I) -> U; // fn entry(&self, i : I) -> Option } @@ -97,7 +108,7 @@ fn next_index(&mut self) -> Option; } -impl, I : Unsigned> Grid for LinSpace { +impl, I: Unsigned> Grid for LinSpace { /*fn entry(&self, i : I) -> Option { if i < self.count { Some(self.entry_unchecked(i)) @@ -107,37 +118,40 @@ }*/ #[inline] - fn entry_linear_unchecked(&self, i : usize) -> F { + fn entry_linear_unchecked(&self, i: usize) -> F { self.entry_unchecked(&I::cast_from(i)) } #[inline] - fn entry_unchecked(&self, i : &I) -> F { + fn entry_unchecked(&self, i: &I) -> F { let idx = F::cast_from(*i); - let scale = F::cast_from(self.count-I::ONE); - self.start + (self.end-self.start)*idx/scale + let scale = F::cast_from(self.count - I::ONE); + self.start + (self.end - self.start) * idx / scale } } -impl, I : Unsigned> GridIteration -for LinSpaceIterator { +impl, I: Unsigned> GridIteration for LinSpaceIterator { #[inline] fn next_index(&mut self) -> Option { match self.current { - None if I::ZERO < self.lingrid.count - => { self.current = Some(I::ZERO); self.current } - Some(v) if v+I::ONE < self.lingrid.count - => { self.current = Some(v+I::ONE); self.current } - _ - => { None } + None if I::ZERO < self.lingrid.count => { + self.current = Some(I::ZERO); + self.current + } + Some(v) if v + I::ONE < self.lingrid.count => { + self.current = Some(v + I::ONE); + self.current + } + _ => None, } } } -impl, I : Unsigned, const N : usize> Grid, [I; N]> -for LinSpace, [I; N]> { +impl, I: Unsigned, const N: usize> Grid, [I; N]> + for LinSpace, [I; N]> +{ #[inline] - fn entry_linear_unchecked(&self, i_ : usize) -> Loc { + fn entry_linear_unchecked(&self, i_: usize) -> Loc { let mut i = I::cast_from(i_); let mut tmp = [I::ZERO; N]; for k in 0..N { @@ -148,58 +162,66 @@ } #[inline] - fn entry_unchecked(&self, i : &[I; N]) -> Loc { - let LinSpace{ start, end, count } = self; + fn entry_unchecked(&self, i: &[I; N]) -> Loc { + let LinSpace { start, end, count } = self; map4(i, start, end, count, |&ik, &sk, &ek, &ck| { let idx = F::cast_from(ik); - let scale = F::cast_from(ck-I::ONE); + let scale = F::cast_from(ck - I::ONE); sk + (ek - sk) * idx / scale - }).into() + }) + .into() } } -impl, I : Unsigned, const N : usize> GridIteration, [I; N]> -for LinSpaceIterator, [I; N]> { - +impl, I: Unsigned, const N: usize> GridIteration, [I; N]> + for LinSpaceIterator, [I; N]> +{ #[inline] fn next_index(&mut self) -> Option<[I; N]> { match self.current { - None if self.lingrid.count.iter().all(|v| I::ZERO < *v) => { + None if self.lingrid.count.iter().all(|v| I::ZERO < *v) => { self.current = Some([I::ZERO; N]); self.current - }, + } Some(ref mut v) => { for k in 0..N { let a = v[k] + I::ONE; if a < self.lingrid.count[k] { v[k] = a; - return self.current + return self.current; } else { v[k] = I::ZERO; } } None - }, - _ => None + } + _ => None, } } } -impl IntoIterator for LinSpace -where LinSpace : Grid, - LinSpaceIterator : GridIteration { +impl IntoIterator for LinSpace +where + LinSpace: Grid, + LinSpaceIterator: GridIteration, +{ type Item = F; - type IntoIter = LinSpaceIterator; + type IntoIter = LinSpaceIterator; #[inline] fn into_iter(self) -> Self::IntoIter { - LinSpaceIterator { lingrid : self, current : None } + LinSpaceIterator { + lingrid: self, + current: None, + } } } -impl Iterator for LinSpaceIterator -where LinSpace : Grid, - LinSpaceIterator : GridIteration { +impl Iterator for LinSpaceIterator +where + LinSpace: Grid, + LinSpaceIterator: GridIteration, +{ type Item = F; #[inline] fn next(&mut self) -> Option { @@ -207,19 +229,24 @@ } } -impl StatefulIterator for LinSpaceIterator -where LinSpace : Grid, - LinSpaceIterator : GridIteration { +impl StatefulIterator for LinSpaceIterator +where + LinSpace: Grid, + LinSpaceIterator: GridIteration, +{ #[inline] fn current(&self) -> Option { - self.current.as_ref().map(|c| self.lingrid.entry_unchecked(c)) + self.current + .as_ref() + .map(|c| self.lingrid.entry_unchecked(c)) } } - -impl RestartableIterator for LinSpaceIterator -where LinSpace : Grid, - LinSpaceIterator : GridIteration { +impl RestartableIterator for LinSpaceIterator +where + LinSpace: Grid, + LinSpaceIterator: GridIteration, +{ #[inline] fn restart(&mut self) -> Option { self.current = None; diff -r 1f19c6bbf07b -r 3868555d135c src/linops.rs --- a/src/linops.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/linops.rs Fri May 15 14:46:30 2026 -0500 @@ -2,81 +2,143 @@ Abstract linear operators. */ -use numeric_literals::replace_float_literals; -use std::marker::PhantomData; -use serde::Serialize; +use crate::direct_product::Pair; +use crate::error::DynResult; +use crate::euclidean::StaticEuclidean; +use crate::instance::Instance; +pub use crate::mapping::{ClosedSpace, Composition, DifferentiableImpl, Mapping, Space}; +use crate::norms::{HasDual, Linfinity, NormExponent, PairNorm, L1, L2}; use crate::types::*; -pub use crate::mapping::{Mapping, Space, Composition}; -use crate::direct_product::Pair; -use crate::instance::Instance; -use crate::norms::{NormExponent, PairNorm, L1, L2, Linfinity, Norm}; +use numeric_literals::replace_float_literals; +use serde::Serialize; +use std::marker::PhantomData; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; /// Trait for linear operators on `X`. -pub trait Linear : Mapping -{ } +pub trait Linear: Mapping {} + +// impl> DifferentiableImpl for A { +// type Derivative = >::Codomain; + +// /// Compute the differential of `self` at `x`, consuming the input. +// fn differential_impl>(&self, x: I) -> Self::Derivative { +// self.apply(x) +// } +// } + +/// Vector spaces +#[replace_float_literals(Self::Field::cast_from(literal))] +pub trait VectorSpace: + Space + + Mul + + Div + + Add + + Add + + Sub + + Sub + + Neg + + for<'b> Add<&'b Self, Output = ::PrincipalV> + + for<'b> Sub<&'b Self, Output = ::PrincipalV> +{ + /// Underlying scalar field + type Field: Num; + + /// Principal form of the space; always equal to [`Space::Principal`], but with + /// more traits guaranteed. + /// + /// `PrincipalV` is only assumed to be `AXPY` for itself, as [`AXPY`] + /// uses [`Instance`] to apply all other variants and avoid problems + /// of choosing multiple implementations of the trait. + type PrincipalV: ClosedSpace + + AXPY< + Self::PrincipalV, + Field = Self::Field, + PrincipalV = Self::PrincipalV, + OwnedVariant = Self::PrincipalV, + Principal = Self::PrincipalV, + >; + + /// Return a similar zero as `self`. + fn similar_origin(&self) -> Self::PrincipalV; + // { + // self.make_origin_generator().make_origin() + // } + + /// Return a similar zero as `x`. + fn similar_origin_inst>(x: I) -> Self::PrincipalV { + x.eval(|xr| xr.similar_origin()) + } +} /// Efficient in-place summation. -#[replace_float_literals(F::cast_from(literal))] -pub trait AXPY : Space + std::ops::MulAssign +#[replace_float_literals(Self::Field::cast_from(literal))] +pub trait AXPY: + VectorSpace + + MulAssign + + DivAssign + + AddAssign + + AddAssign + + SubAssign + + SubAssign + + for<'b> AddAssign<&'b Self> + + for<'b> SubAssign<&'b Self> where - F : Num, - X : Space, + X: Space, { - type Owned : AXPY; - /// Computes `y = βy + αx`, where `y` is `Self`. - fn axpy>(&mut self, α : F, x : I, β : F); + fn axpy>(&mut self, α: Self::Field, x: I, β: Self::Field); /// Copies `x` to `self`. - fn copy_from>(&mut self, x : I) { + fn copy_from>(&mut self, x: I) { self.axpy(1.0, x, 0.0) } /// Computes `y = αx`, where `y` is `Self`. - fn scale_from>(&mut self, α : F, x : I) { + fn scale_from>(&mut self, α: Self::Field, x: I) { self.axpy(α, x, 0.0) } - /// Return a similar zero as `self`. - fn similar_origin(&self) -> Self::Owned; - /// Set self to zero. fn set_zero(&mut self); } +pub trait ClosedVectorSpace: Instance + VectorSpace {} +impl + VectorSpace> ClosedVectorSpace for X {} + /// Efficient in-place application for [`Linear`] operators. #[replace_float_literals(F::cast_from(literal))] -pub trait GEMV>::Codomain> : Linear { +pub trait GEMV>::Codomain>: Linear { /// Computes `y = αAx + βy`, where `A` is `Self`. - fn gemv>(&self, y : &mut Y, α : F, x : I, β : F); + fn gemv>(&self, y: &mut Y, α: F, x: I, β: F); #[inline] /// Computes `y = Ax`, where `A` is `Self` - fn apply_mut>(&self, y : &mut Y, x : I){ + fn apply_mut>(&self, y: &mut Y, x: I) { self.gemv(y, 1.0, x, 0.0) } #[inline] /// Computes `y += Ax`, where `A` is `Self` - fn apply_add>(&self, y : &mut Y, x : I){ + fn apply_add>(&self, y: &mut Y, x: I) { self.gemv(y, 1.0, x, 1.0) } } - /// Bounded linear operators -pub trait BoundedLinear : Linear +pub trait BoundedLinear: Linear where - F : Num, - X : Space + Norm, - XExp : NormExponent, - CodExp : NormExponent + F: Num, + X: Space, + XExp: NormExponent, + CodExp: NormExponent, { /// A bound on the operator norm $\|A\|$ for the linear operator $A$=`self`. /// This is not expected to be the norm, just any bound on it that can be /// reasonably implemented. The [`NormExponent`] `xexp` indicates the norm /// in `X`, and `codexp` in the codomain. - fn opnorm_bound(&self, xexp : XExp, codexp : CodExp) -> F; + /// + /// This may fail with an error if the bound is for some reason incalculable. + fn opnorm_bound(&self, xexp: XExp, codexp: CodExp) -> DynResult; } // Linear operator application into mutable target. The [`AsRef`] bound @@ -90,18 +152,45 @@ }*/ /// Trait for forming the adjoint operator of `Self`. -pub trait Adjointable : Linear +pub trait Adjointable: Linear where - X : Space, - Yʹ : Space, + X: Space, + Yʹ: Space, { - type AdjointCodomain : Space; - type Adjoint<'a> : Linear where Self : 'a; + /// Codomain of the adjoint operator. + type AdjointCodomain: ClosedSpace; + /// Type of the adjoint operator. + type Adjoint<'a>: Linear + where + Self: 'a; /// Form the adjoint operator of `self`. fn adjoint(&self) -> Self::Adjoint<'_>; } +/// Variant of [`Adjointable`] where the adjoint does not depend on a lifetime parameter. +/// This exists due to restrictions of Rust's type system: if `A :: Adjointable`, and we make +/// further restrictions on the adjoint operator, through, e.g. +/// ``` +/// for<'a> A::Adjoint<'a> : GEMV, +/// ``` +/// Then `'static` lifetime is forced on `X`. Having `A::SimpleAdjoint` not depend on `'a` +/// avoids this, but makes it impossible for the adjoint to be just a light wrapper around the +/// original operator. +pub trait SimplyAdjointable: Linear +where + X: Space, + Yʹ: Space, +{ + /// Codomain of the adjoint operator. + type AdjointCodomain: ClosedSpace; + /// Type of the adjoint operator. + type SimpleAdjoint: Linear; + + /// Form the adjoint operator of `self`. + fn adjoint(&self) -> Self::SimpleAdjoint; +} + /// Trait for forming a preadjoint of an operator. /// /// For an operator $A$ this is an operator $A\_\*$ @@ -112,404 +201,667 @@ /// We do not make additional restrictions on `Self::Preadjoint` (in particular, it /// does not have to be adjointable) to allow `X` to be a subspace yet the preadjoint /// have the full space as the codomain, etc. -pub trait Preadjointable : Linear { - type PreadjointCodomain : Space; - type Preadjoint<'a> : Linear< - Ypre, Codomain=Self::PreadjointCodomain - > where Self : 'a; +pub trait Preadjointable>::Codomain>: + Linear +{ + type PreadjointCodomain: ClosedSpace; + type Preadjoint<'a>: Linear + where + Self: 'a; /// Form the adjoint operator of `self`. fn preadjoint(&self) -> Self::Preadjoint<'_>; } -/// Adjointable operators $A: X → Y$ between reflexive spaces $X$ and $Y$. -pub trait SimplyAdjointable : Adjointable>::Codomain> {} -impl<'a,X : Space, T> SimplyAdjointable for T -where T : Adjointable>::Codomain> {} - /// The identity operator -#[derive(Clone,Copy,Debug,Serialize,Eq,PartialEq)] -pub struct IdOp (PhantomData); +#[derive(Clone, Copy, Debug, Serialize, Eq, PartialEq)] +pub struct IdOp(PhantomData); impl IdOp { - pub fn new() -> IdOp { IdOp(PhantomData) } + pub fn new() -> IdOp { + IdOp(PhantomData) + } } -impl Mapping for IdOp { - type Codomain = X; +impl Mapping for IdOp { + type Codomain = X::Principal; - fn apply>(&self, x : I) -> X { + fn apply>(&self, x: I) -> Self::Codomain { x.own() } } -impl Linear for IdOp -{ } +impl Linear for IdOp {} #[replace_float_literals(F::cast_from(literal))] -impl GEMV for IdOp +impl GEMV for IdOp where - Y : AXPY, - X : Clone + Space + Y: AXPY, + X: Space, { // Computes `y = αAx + βy`, where `A` is `Self`. - fn gemv>(&self, y : &mut Y, α : F, x : I, β : F) { + fn gemv>(&self, y: &mut Y, α: F, x: I, β: F) { y.axpy(α, x, β) } - fn apply_mut>(&self, y : &mut Y, x : I){ + fn apply_mut>(&self, y: &mut Y, x: I) { y.copy_from(x); } } impl BoundedLinear for IdOp where - X : Space + Clone + Norm, - F : Num, - E : NormExponent + X: Space + Clone, + F: Num, + E: NormExponent, { - fn opnorm_bound(&self, _xexp : E, _codexp : E) -> F { F::ONE } -} - -impl Adjointable for IdOp { - type AdjointCodomain=X; - type Adjoint<'a> = IdOp where X : 'a; - - fn adjoint(&self) -> Self::Adjoint<'_> { IdOp::new() } + fn opnorm_bound(&self, _xexp: E, _codexp: E) -> DynResult { + Ok(F::ONE) + } } -impl Preadjointable for IdOp { - type PreadjointCodomain=X; - type Preadjoint<'a> = IdOp where X : 'a; +impl Adjointable for IdOp { + type AdjointCodomain = X::Principal; + type Adjoint<'a> + = IdOp + where + X: 'a; - fn preadjoint(&self) -> Self::Preadjoint<'_> { IdOp::new() } + fn adjoint(&self) -> Self::Adjoint<'_> { + IdOp::new() + } } +impl SimplyAdjointable for IdOp { + type AdjointCodomain = X::Principal; + type SimpleAdjoint = IdOp; -/// The zero operator -#[derive(Clone,Copy,Debug,Serialize,Eq,PartialEq)] -pub struct ZeroOp<'a, X, XD, Y, F> { - zero : &'a Y, // TODO: don't pass this in `new`; maybe not even store. - dual_or_predual_zero : XD, - _phantoms : PhantomData<(X, Y, F)>, -} - -// TODO: Need to make Zero in Instance. - -impl<'a, F : Num, X : Space, XD, Y : Space + Clone> ZeroOp<'a, X, XD, Y, F> { - pub fn new(zero : &'a Y, dual_or_predual_zero : XD) -> Self { - ZeroOp{ zero, dual_or_predual_zero, _phantoms : PhantomData } + fn adjoint(&self) -> Self::SimpleAdjoint { + IdOp::new() } } -impl<'a, F : Num, X : Space, XD, Y : AXPY + Clone> Mapping for ZeroOp<'a, X, XD, Y, F> { - type Codomain = Y; +impl Preadjointable for IdOp { + type PreadjointCodomain = X::Principal; + type Preadjoint<'a> + = IdOp + where + X: 'a; - fn apply>(&self, _x : I) -> Y { - self.zero.clone() + fn preadjoint(&self) -> Self::Preadjoint<'_> { + IdOp::new() } } -impl<'a, F : Num, X : Space, XD, Y : AXPY + Clone> Linear for ZeroOp<'a, X, XD, Y, F> -{ } +/// The zero operator from a space to itself +#[derive(Clone, Copy, Debug, Serialize, Eq, PartialEq)] +pub struct SimpleZeroOp; + +impl Mapping for SimpleZeroOp { + type Codomain = X::PrincipalV; + + fn apply>(&self, x: I) -> X::PrincipalV { + X::similar_origin_inst(x) + } +} + +impl Linear for SimpleZeroOp {} #[replace_float_literals(F::cast_from(literal))] -impl<'a, F, X, XD, Y> GEMV for ZeroOp<'a, X, XD, Y, F> +impl GEMV for SimpleZeroOp where - F : Num, - Y : AXPY + Clone, - X : Space + F: Num, + Y: AXPY, + X: VectorSpace + Instance, { // Computes `y = αAx + βy`, where `A` is `Self`. - fn gemv>(&self, y : &mut Y, _α : F, _x : I, β : F) { + fn gemv>(&self, y: &mut Y, _α: F, _x: I, β: F) { *y *= β; } - fn apply_mut>(&self, y : &mut Y, _x : I){ + fn apply_mut>(&self, y: &mut Y, _x: I) { y.set_zero(); } } -impl<'a, F, X, XD, Y, E1, E2> BoundedLinear for ZeroOp<'a, X, XD, Y, F> +impl BoundedLinear for SimpleZeroOp +where + F: Num, + X: VectorSpace, + E1: NormExponent, + E2: NormExponent, +{ + fn opnorm_bound(&self, _xexp: E1, _codexp: E2) -> DynResult { + Ok(F::ZERO) + } +} + +impl Adjointable for SimpleZeroOp where - X : Space + Norm, - Y : AXPY + Clone + Norm, - F : Num, - E1 : NormExponent, - E2 : NormExponent, + F: Num, + X: VectorSpace + HasDual, + X::DualSpace: ClosedVectorSpace, { - fn opnorm_bound(&self, _xexp : E1, _codexp : E2) -> F { F::ZERO } + type AdjointCodomain = X::DualSpace; + type Adjoint<'b> + = SimpleZeroOp + where + Self: 'b; + // () means not (pre)adjointable. + + fn adjoint(&self) -> Self::Adjoint<'_> { + SimpleZeroOp + } +} + +pub trait OriginGenerator { + type Ref<'b>: OriginGenerator + where + Self: 'b; + + fn origin(&self) -> Y::PrincipalV; + fn as_ref(&self) -> Self::Ref<'_>; } -impl<'a, F : Num, X, XD, Y, Yprime : Space> Adjointable for ZeroOp<'a, X, XD, Y, F> -where - X : Space, - Y : AXPY + Clone + 'static, - XD : AXPY + Clone + 'static, +#[derive(Copy, Clone, Debug)] +pub struct StaticEuclideanOriginGenerator; + +impl, F: Float> OriginGenerator + for StaticEuclideanOriginGenerator { - type AdjointCodomain = XD; - type Adjoint<'b> = ZeroOp<'b, Yprime, (), XD, F> where Self : 'b; - // () means not (pre)adjointable. + type Ref<'b> + = Self + where + Self: 'b; + + #[inline] + fn origin(&self) -> Y::PrincipalV { + return Y::origin(); + } + + #[inline] + fn as_ref(&self) -> Self::Ref<'_> { + *self + } +} + +impl OriginGenerator for Y { + type Ref<'b> + = &'b Y + where + Self: 'b; + + #[inline] + fn origin(&self) -> Y::PrincipalV { + return self.similar_origin(); + } + + #[inline] + fn as_ref(&self) -> Self::Ref<'_> { + self + } +} - fn adjoint(&self) -> Self::Adjoint<'_> { - ZeroOp::new(&self.dual_or_predual_zero, ()) +impl<'b, Y: VectorSpace> OriginGenerator for &'b Y { + type Ref<'c> + = Self + where + Self: 'c; + + #[inline] + fn origin(&self) -> Y::PrincipalV { + return self.similar_origin(); + } + + #[inline] + fn as_ref(&self) -> Self::Ref<'_> { + self + } +} + +/// A zero operator that can be eitherh dualised or predualised (once). +/// This is achieved by storing an oppropriate zero. +pub struct ZeroOp, OY: OriginGenerator, O, F: Float = f64> { + codomain_origin_generator: OY, + other_origin_generator: O, + _phantoms: PhantomData<(X, Y, F)>, +} + +impl ZeroOp +where + OY: OriginGenerator, + X: VectorSpace, + Y: VectorSpace, + F: Float, +{ + pub fn new(y_og: OY) -> Self { + ZeroOp { + codomain_origin_generator: y_og, + other_origin_generator: (), + _phantoms: PhantomData, + } } } -impl<'a, F, X, XD, Y, Ypre> Preadjointable for ZeroOp<'a, X, XD, Y, F> +impl ZeroOp +where + OY: OriginGenerator, + OXprime: OriginGenerator, + X: HasDual, + Y: HasDual, + F: Float, + Xprime: VectorSpace, + Xprime::PrincipalV: AXPY, + Yprime: Space + Instance, +{ + pub fn new_dualisable(y_og: OY, xprime_og: OXprime) -> Self { + ZeroOp { + codomain_origin_generator: y_og, + other_origin_generator: xprime_og, + _phantoms: PhantomData, + } + } +} + +impl Mapping for ZeroOp +where + X: Space, + Y: VectorSpace, + F: Float, + OY: OriginGenerator, +{ + type Codomain = Y::PrincipalV; + + fn apply>(&self, _x: I) -> Y::PrincipalV { + self.codomain_origin_generator.origin() + } +} + +impl Linear for ZeroOp +where + X: Space, + Y: VectorSpace, + F: Float, + OY: OriginGenerator, +{ +} + +#[replace_float_literals(F::cast_from(literal))] +impl GEMV for ZeroOp +where + X: Space, + Y: AXPY, + F: Float, + OY: OriginGenerator, +{ + // Computes `y = αAx + βy`, where `A` is `Self`. + fn gemv>(&self, y: &mut Y, _α: F, _x: I, β: F) { + *y *= β; + } + + fn apply_mut>(&self, y: &mut Y, _x: I) { + y.set_zero(); + } +} + +impl BoundedLinear for ZeroOp where - F : Num, - X : Space, - Y : AXPY + Clone, - Ypre : Space, - XD : AXPY + Clone + 'static, + X: Space + Instance, + Y: VectorSpace, + Y::PrincipalV: Clone, + F: Float, + E1: NormExponent, + E2: NormExponent, + OY: OriginGenerator, +{ + fn opnorm_bound(&self, _xexp: E1, _codexp: E2) -> DynResult { + Ok(F::ZERO) + } +} + +impl<'b, X, Y, OY, OXprime, Xprime, Yprime, F> Adjointable + for ZeroOp +where + X: HasDual, + Y: HasDual, + F: Float, + Xprime: ClosedVectorSpace, + //Xprime::Owned: AXPY, + Yprime: ClosedSpace, + OY: OriginGenerator, + OXprime: OriginGenerator, { - type PreadjointCodomain = XD; - type Preadjoint<'b> = ZeroOp<'b, Ypre, (), XD, F> where Self : 'b; - // () means not (pre)adjointable. + type AdjointCodomain = Xprime; + type Adjoint<'c> + = ZeroOp, (), F> + where + Self: 'c; + // () means not (pre)adjointable. + + fn adjoint(&self) -> Self::Adjoint<'_> { + ZeroOp { + codomain_origin_generator: self.other_origin_generator.as_ref(), + other_origin_generator: (), + _phantoms: PhantomData, + } + } +} - fn preadjoint(&self) -> Self::Preadjoint<'_> { - ZeroOp::new(&self.dual_or_predual_zero, ()) +impl<'b, X, Y, OY, OXprime, Xprime, Yprime, F> SimplyAdjointable + for ZeroOp +where + X: HasDual, + Y: HasDual, + F: Float, + Xprime: ClosedVectorSpace, + //Xprime::Owned: AXPY, + Yprime: ClosedSpace, + OY: OriginGenerator, + OXprime: OriginGenerator + Clone, +{ + type AdjointCodomain = Xprime; + type SimpleAdjoint = ZeroOp; + // () means not (pre)adjointable. + + fn adjoint(&self) -> Self::SimpleAdjoint { + ZeroOp { + codomain_origin_generator: self.other_origin_generator.clone(), + other_origin_generator: (), + _phantoms: PhantomData, + } } } impl Linear for Composition where - X : Space, - T : Linear, - S : Linear -{ } + X: Space, + T: Linear, + S: Linear, +{ +} impl GEMV for Composition where - F : Num, - X : Space, - T : Linear, - S : GEMV, + F: Num, + X: Space, + T: Linear, + S: GEMV, { - fn gemv>(&self, y : &mut Y, α : F, x : I, β : F) { + fn gemv>(&self, y: &mut Y, α: F, x: I, β: F) { self.outer.gemv(y, α, self.inner.apply(x), β) } /// Computes `y = Ax`, where `A` is `Self` - fn apply_mut>(&self, y : &mut Y, x : I){ + fn apply_mut>(&self, y: &mut Y, x: I) { self.outer.apply_mut(y, self.inner.apply(x)) } /// Computes `y += Ax`, where `A` is `Self` - fn apply_add>(&self, y : &mut Y, x : I){ + fn apply_add>(&self, y: &mut Y, x: I) { self.outer.apply_add(y, self.inner.apply(x)) } } impl BoundedLinear for Composition where - F : Num, - X : Space + Norm, - Z : Space + Norm, - Xexp : NormExponent, - Yexp : NormExponent, - Zexp : NormExponent, - T : BoundedLinear, - S : BoundedLinear, + F: Num, + X: Space, + Z: Space, + Xexp: NormExponent, + Yexp: NormExponent, + Zexp: NormExponent, + T: BoundedLinear, + S: BoundedLinear, { - fn opnorm_bound(&self, xexp : Xexp, yexp : Yexp) -> F { + fn opnorm_bound(&self, xexp: Xexp, yexp: Yexp) -> DynResult { let zexp = self.intermediate_norm_exponent; - self.outer.opnorm_bound(zexp, yexp) * self.inner.opnorm_bound(xexp, zexp) + Ok(self.outer.opnorm_bound(zexp, yexp)? * self.inner.opnorm_bound(xexp, zexp)?) } } /// “Row operator” $(S, T)$; $(S, T)(x, y)=Sx + Ty$. +#[derive(Clone, Copy, Debug, Serialize, Eq, PartialEq)] pub struct RowOp(pub S, pub T); -use std::ops::Add; - impl Mapping> for RowOp where - A : Space, - B : Space, - S : Mapping, - T : Mapping, - S::Codomain : Add, - >::Output : Space, - + A: Space, + B: Space, + S: Mapping, + T: Mapping, + S::Codomain: Add, + >::Output: ClosedSpace, { type Codomain = >::Output; - fn apply>>(&self, x : I) -> Self::Codomain { - let Pair(a, b) = x.decompose(); - self.0.apply(a) + self.1.apply(b) + fn apply>>(&self, x: I) -> Self::Codomain { + x.eval_decompose(|Pair(a, b)| self.0.apply(a) + self.1.apply(b)) } } impl Linear> for RowOp where - A : Space, - B : Space, - S : Linear, - T : Linear, - S::Codomain : Add, - >::Output : Space, -{ } - + A: Space, + B: Space, + S: Linear, + T: Linear, + S::Codomain: Add, + >::Output: ClosedSpace, +{ +} impl<'b, F, S, T, Y, U, V> GEMV, Y> for RowOp where - U : Space, - V : Space, - S : GEMV, - T : GEMV, - F : Num, - Self : Linear, Codomain=Y> + U: Space, + V: Space, + S: GEMV, + T: GEMV, + F: Num, + Self: Linear, Codomain = Y>, { - fn gemv>>(&self, y : &mut Y, α : F, x : I, β : F) { - let Pair(u, v) = x.decompose(); - self.0.gemv(y, α, u, β); - self.1.gemv(y, α, v, F::ONE); + fn gemv>>(&self, y: &mut Y, α: F, x: I, β: F) { + x.eval_decompose(|Pair(u, v)| { + self.0.gemv(y, α, u, β); + self.1.gemv(y, α, v, F::ONE); + }) } - fn apply_mut>>(&self, y : &mut Y, x : I) { - let Pair(u, v) = x.decompose(); - self.0.apply_mut(y, u); - self.1.apply_add(y, v); + fn apply_mut>>(&self, y: &mut Y, x: I) { + x.eval_decompose(|Pair(u, v)| { + self.0.apply_mut(y, u); + self.1.apply_add(y, v); + }) } /// Computes `y += Ax`, where `A` is `Self` - fn apply_add>>(&self, y : &mut Y, x : I) { - let Pair(u, v) = x.decompose(); - self.0.apply_add(y, u); - self.1.apply_add(y, v); + fn apply_add>>(&self, y: &mut Y, x: I) { + x.eval_decompose(|Pair(u, v)| { + self.0.apply_add(y, u); + self.1.apply_add(y, v); + }) } } /// “Column operator” $(S; T)$; $(S; T)x=(Sx, Tx)$. +#[derive(Clone, Copy, Debug, Serialize, Eq, PartialEq)] pub struct ColOp(pub S, pub T); impl Mapping for ColOp where - A : Space, - S : Mapping, - T : Mapping, + A: Space, + S: Mapping, + T: Mapping, { type Codomain = Pair; - fn apply>(&self, a : I) -> Self::Codomain { - Pair(self.0.apply(a.ref_instance()), self.1.apply(a)) + fn apply>(&self, a: I) -> Self::Codomain { + Pair(a.eval_ref(|r| self.0.apply(r)), self.1.apply(a)) } } impl Linear for ColOp where - A : Space, - S : Mapping, - T : Mapping, -{ } + A: Space, + S: Mapping, + T: Mapping, +{ +} impl GEMV> for ColOp where - X : Space, - S : GEMV, - T : GEMV, - F : Num, - Self : Linear> + X: Space, + S: GEMV, + T: GEMV, + F: Num, + Self: Linear>, { - fn gemv>(&self, y : &mut Pair, α : F, x : I, β : F) { - self.0.gemv(&mut y.0, α, x.ref_instance(), β); + fn gemv>(&self, y: &mut Pair, α: F, x: I, β: F) { + x.eval_ref(|r| self.0.gemv(&mut y.0, α, r, β)); self.1.gemv(&mut y.1, α, x, β); } - fn apply_mut>(&self, y : &mut Pair, x : I){ - self.0.apply_mut(&mut y.0, x.ref_instance()); + fn apply_mut>(&self, y: &mut Pair, x: I) { + x.eval_ref(|r| self.0.apply_mut(&mut y.0, r)); self.1.apply_mut(&mut y.1, x); } /// Computes `y += Ax`, where `A` is `Self` - fn apply_add>(&self, y : &mut Pair, x : I){ - self.0.apply_add(&mut y.0, x.ref_instance()); + fn apply_add>(&self, y: &mut Pair, x: I) { + x.eval_ref(|r| self.0.apply_add(&mut y.0, r)); self.1.apply_add(&mut y.1, x); } } - -impl Adjointable, Yʹ> for RowOp +impl Adjointable, Yʹ> for RowOp where - A : Space, - B : Space, - Yʹ : Space, - S : Adjointable, - T : Adjointable, - Self : Linear>, + A: Space, + B: Space, + Yʹ: Space, + S: Adjointable, + T: Adjointable, + Self: Linear>, // for<'a> ColOp, T::Adjoint<'a>> : Linear< // Yʹ, // Codomain=Pair // >, { type AdjointCodomain = Pair; - type Adjoint<'a> = ColOp, T::Adjoint<'a>> where Self : 'a; + type Adjoint<'a> + = ColOp, T::Adjoint<'a>> + where + Self: 'a; fn adjoint(&self) -> Self::Adjoint<'_> { ColOp(self.0.adjoint(), self.1.adjoint()) } } -impl Preadjointable, Yʹ> for RowOp +impl SimplyAdjointable, Yʹ> for RowOp where - A : Space, - B : Space, - Yʹ : Space, - S : Preadjointable, - T : Preadjointable, - Self : Linear>, - for<'a> ColOp, T::Preadjoint<'a>> : Linear< - Yʹ, Codomain=Pair, - >, + A: Space, + B: Space, + Yʹ: Space, + S: SimplyAdjointable, + T: SimplyAdjointable, + Self: Linear>, + // for<'a> ColOp, T::Adjoint<'a>> : Linear< + // Yʹ, + // Codomain=Pair + // >, +{ + type AdjointCodomain = Pair; + type SimpleAdjoint = ColOp; + + fn adjoint(&self) -> Self::SimpleAdjoint { + ColOp(self.0.adjoint(), self.1.adjoint()) + } +} + +impl Preadjointable, Yʹ> for RowOp +where + A: Space, + B: Space, + Yʹ: Space, + S: Preadjointable, + T: Preadjointable, + Self: Linear>, + for<'a> ColOp, T::Preadjoint<'a>>: + Linear>, { type PreadjointCodomain = Pair; - type Preadjoint<'a> = ColOp, T::Preadjoint<'a>> where Self : 'a; + type Preadjoint<'a> + = ColOp, T::Preadjoint<'a>> + where + Self: 'a; fn preadjoint(&self) -> Self::Preadjoint<'_> { ColOp(self.0.preadjoint(), self.1.preadjoint()) } } - -impl Adjointable> for ColOp +impl Adjointable> for ColOp where - A : Space, - Xʹ : Space, - Yʹ : Space, - R : Space + ClosedAdd, - S : Adjointable, - T : Adjointable, - Self : Linear, + A: Space, + Xʹ: Space, + Yʹ: Space, + R: ClosedSpace + ClosedAdd, + S: Adjointable, + T: Adjointable, + Self: Linear, // for<'a> RowOp, T::Adjoint<'a>> : Linear< // Pair, // Codomain=R, // >, { type AdjointCodomain = R; - type Adjoint<'a> = RowOp, T::Adjoint<'a>> where Self : 'a; + type Adjoint<'a> + = RowOp, T::Adjoint<'a>> + where + Self: 'a; fn adjoint(&self) -> Self::Adjoint<'_> { RowOp(self.0.adjoint(), self.1.adjoint()) } } -impl Preadjointable> for ColOp +impl SimplyAdjointable> for ColOp where - A : Space, - Xʹ : Space, - Yʹ : Space, - R : Space + ClosedAdd, - S : Preadjointable, - T : Preadjointable, - Self : Linear, - for<'a> RowOp, T::Preadjoint<'a>> : Linear< - Pair, Codomain = R, - >, + A: Space, + Xʹ: Space, + Yʹ: Space, + R: ClosedSpace + ClosedAdd, + S: SimplyAdjointable, + T: SimplyAdjointable, + Self: Linear, + // for<'a> RowOp, T::Adjoint<'a>> : Linear< + // Pair, + // Codomain=R, + // >, +{ + type AdjointCodomain = R; + type SimpleAdjoint = RowOp; + + fn adjoint(&self) -> Self::SimpleAdjoint { + RowOp(self.0.adjoint(), self.1.adjoint()) + } +} + +impl Preadjointable> for ColOp +where + A: Space, + Xʹ: Space, + Yʹ: Space, + R: ClosedSpace + ClosedAdd, + S: Preadjointable, + T: Preadjointable, + Self: Linear, + for<'a> RowOp, T::Preadjoint<'a>>: Linear, Codomain = R>, { type PreadjointCodomain = R; - type Preadjoint<'a> = RowOp, T::Preadjoint<'a>> where Self : 'a; + type Preadjoint<'a> + = RowOp, T::Preadjoint<'a>> + where + Self: 'a; fn preadjoint(&self) -> Self::Preadjoint<'_> { RowOp(self.0.preadjoint(), self.1.preadjoint()) @@ -517,100 +869,126 @@ } /// Diagonal operator +#[derive(Clone, Copy, Debug, Serialize, Eq, PartialEq)] pub struct DiagOp(pub S, pub T); impl Mapping> for DiagOp where - A : Space, - B : Space, - S : Mapping, - T : Mapping, + A: Space, + B: Space, + S: Mapping, + T: Mapping, { type Codomain = Pair; - fn apply>>(&self, x : I) -> Self::Codomain { - let Pair(a, b) = x.decompose(); - Pair(self.0.apply(a), self.1.apply(b)) + fn apply>>(&self, x: I) -> Self::Codomain { + x.eval_decompose(|Pair(a, b)| Pair(self.0.apply(a), self.1.apply(b))) } } impl Linear> for DiagOp where - A : Space, - B : Space, - S : Linear, - T : Linear, -{ } + A: Space, + B: Space, + S: Linear, + T: Linear, +{ +} impl GEMV, Pair> for DiagOp where - A : Space, - B : Space, - U : Space, - V : Space, - S : GEMV, - T : GEMV, - F : Num, - Self : Linear, Codomain=Pair>, + A: Space, + B: Space, + U: Space, + V: Space, + S: GEMV, + T: GEMV, + F: Num, + Self: Linear, Codomain = Pair>, { - fn gemv>>(&self, y : &mut Pair, α : F, x : I, β : F) { - let Pair(u, v) = x.decompose(); - self.0.gemv(&mut y.0, α, u, β); - self.1.gemv(&mut y.1, α, v, β); + fn gemv>>(&self, y: &mut Pair, α: F, x: I, β: F) { + x.eval_decompose(|Pair(u, v)| { + self.0.gemv(&mut y.0, α, u, β); + self.1.gemv(&mut y.1, α, v, β); + }) } - fn apply_mut>>(&self, y : &mut Pair, x : I){ - let Pair(u, v) = x.decompose(); - self.0.apply_mut(&mut y.0, u); - self.1.apply_mut(&mut y.1, v); + fn apply_mut>>(&self, y: &mut Pair, x: I) { + x.eval_decompose(|Pair(u, v)| { + self.0.apply_mut(&mut y.0, u); + self.1.apply_mut(&mut y.1, v); + }) } /// Computes `y += Ax`, where `A` is `Self` - fn apply_add>>(&self, y : &mut Pair, x : I){ - let Pair(u, v) = x.decompose(); - self.0.apply_add(&mut y.0, u); - self.1.apply_add(&mut y.1, v); + fn apply_add>>(&self, y: &mut Pair, x: I) { + x.eval_decompose(|Pair(u, v)| { + self.0.apply_add(&mut y.0, u); + self.1.apply_add(&mut y.1, v); + }) } } -impl Adjointable, Pair> for DiagOp +impl Adjointable, Pair> for DiagOp where - A : Space, - B : Space, + A: Space, + B: Space, Xʹ: Space, - Yʹ : Space, - R : Space, - S : Adjointable, - T : Adjointable, - Self : Linear>, - for<'a> DiagOp, T::Adjoint<'a>> : Linear< - Pair, Codomain=R, - >, + Yʹ: Space, + R: ClosedSpace, + S: Adjointable, + T: Adjointable, + Self: Linear>, + for<'a> DiagOp, T::Adjoint<'a>>: Linear, Codomain = R>, { type AdjointCodomain = R; - type Adjoint<'a> = DiagOp, T::Adjoint<'a>> where Self : 'a; + type Adjoint<'a> + = DiagOp, T::Adjoint<'a>> + where + Self: 'a; fn adjoint(&self) -> Self::Adjoint<'_> { DiagOp(self.0.adjoint(), self.1.adjoint()) } } -impl Preadjointable, Pair> for DiagOp +impl SimplyAdjointable, Pair> for DiagOp where - A : Space, - B : Space, + A: Space, + B: Space, Xʹ: Space, - Yʹ : Space, - R : Space, - S : Preadjointable, - T : Preadjointable, - Self : Linear>, - for<'a> DiagOp, T::Preadjoint<'a>> : Linear< - Pair, Codomain=R, - >, + Yʹ: Space, + R: ClosedSpace, + S: SimplyAdjointable, + T: SimplyAdjointable, + Self: Linear>, + for<'a> DiagOp: Linear, Codomain = R>, +{ + type AdjointCodomain = R; + type SimpleAdjoint = DiagOp; + + fn adjoint(&self) -> Self::SimpleAdjoint { + DiagOp(self.0.adjoint(), self.1.adjoint()) + } +} + +impl Preadjointable, Pair> for DiagOp +where + A: Space, + B: Space, + Xʹ: Space, + Yʹ: Space, + R: ClosedSpace, + S: Preadjointable, + T: Preadjointable, + Self: Linear>, + for<'a> DiagOp, T::Preadjoint<'a>>: Linear, Codomain = R>, { type PreadjointCodomain = R; - type Preadjoint<'a> = DiagOp, T::Preadjoint<'a>> where Self : 'a; + type Preadjoint<'a> + = DiagOp, T::Preadjoint<'a>> + where + Self: 'a; fn preadjoint(&self) -> Self::Preadjoint<'_> { DiagOp(self.0.preadjoint(), self.1.preadjoint()) @@ -620,65 +998,90 @@ /// Block operator pub type BlockOp = ColOp, RowOp>; - macro_rules! pairnorm { ($expj:ty) => { impl - BoundedLinear, PairNorm, ExpR, F> - for RowOp + BoundedLinear, PairNorm, ExpR, F> for RowOp where - F : Float, - A : Space + Norm, - B : Space + Norm, - S : BoundedLinear, - T : BoundedLinear, - S::Codomain : Add, - >::Output : Space, - ExpA : NormExponent, - ExpB : NormExponent, - ExpR : NormExponent, + F: Float, + A: Space, + B: Space, + S: BoundedLinear, + T: BoundedLinear, + S::Codomain: Add, + >::Output: ClosedSpace, + ExpA: NormExponent, + ExpB: NormExponent, + ExpR: NormExponent, { fn opnorm_bound( &self, - PairNorm(expa, expb, _) : PairNorm, - expr : ExpR - ) -> F { + PairNorm(expa, expb, _): PairNorm, + expr: ExpR, + ) -> DynResult { // An application of the triangle inequality bounds the norm by the maximum // of the individual norms. A simple observation shows this to be exact. - let na = self.0.opnorm_bound(expa, expr); - let nb = self.1.opnorm_bound(expb, expr); - na.max(nb) + let na = self.0.opnorm_bound(expa, expr)?; + let nb = self.1.opnorm_bound(expb, expr)?; + Ok(na.max(nb)) } } - - impl - BoundedLinear, F> - for ColOp + + impl BoundedLinear, F> + for ColOp where - F : Float, - A : Space + Norm, - S : BoundedLinear, - T : BoundedLinear, - ExpA : NormExponent, - ExpS : NormExponent, - ExpT : NormExponent, + F: Float, + A: Space, + S: BoundedLinear, + T: BoundedLinear, + ExpA: NormExponent, + ExpS: NormExponent, + ExpT: NormExponent, { fn opnorm_bound( &self, - expa : ExpA, - PairNorm(exps, expt, _) : PairNorm - ) -> F { + expa: ExpA, + PairNorm(exps, expt, _): PairNorm, + ) -> DynResult { // This is based on the rule for RowOp and ‖A^*‖ = ‖A‖, hence, // for A=[S; T], ‖A‖=‖[S^*, T^*]‖ ≤ max{‖S^*‖, ‖T^*‖} = max{‖S‖, ‖T‖} - let ns = self.0.opnorm_bound(expa, exps); - let nt = self.1.opnorm_bound(expa, expt); - ns.max(nt) + let ns = self.0.opnorm_bound(expa, exps)?; + let nt = self.1.opnorm_bound(expa, expt)?; + Ok(ns.max(nt)) } } - } + }; } pairnorm!(L1); pairnorm!(L2); pairnorm!(Linfinity); +/// The simplest linear mapping, scaling by a scalar. +/// +/// TODO: redefined/replace `Weighted` by composition with [`Scaled`]. +pub struct Scaled(pub F); + +impl Mapping for Scaled +where + F: Float, + Domain: Space, + Domain::Principal: Mul, + >::Output: ClosedSpace, +{ + type Codomain = >::Output; + + /// Compute the value of `self` at `x`. + fn apply>(&self, x: I) -> Self::Codomain { + x.own() * self.0 + } +} + +impl Linear for Scaled +where + F: Float, + Domain: Space, + Domain::Principal: Mul, + >::Output: ClosedSpace, +{ +} diff -r 1f19c6bbf07b -r 3868555d135c src/loc.rs --- a/src/loc.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/loc.rs Fri May 15 14:46:30 2026 -0500 @@ -3,44 +3,48 @@ For working with small vectors in $ℝ^2$ or $ℝ^3$. */ -use std::ops::{Add,Sub,AddAssign,SubAssign,Mul,Div,MulAssign,DivAssign,Neg,Index,IndexMut}; -use std::slice::{Iter,IterMut}; +use crate::euclidean::*; +use crate::instance::{BasicDecomposition, Instance}; +use crate::linops::{Linear, Mapping, VectorSpace, AXPY}; +use crate::mapping::Space; +use crate::maputil::{map1, map1_mut, map2, map2_mut, FixedLength, FixedLengthMut}; +use crate::norms::*; +use crate::self_ownable; +use crate::types::{Float, Num, SignedNum}; +use serde::ser::{Serialize, SerializeSeq, Serializer}; use std::fmt::{Display, Formatter}; -use crate::types::{Float,Num,SignedNum}; -use crate::maputil::{FixedLength,FixedLengthMut,map1,map2,map1_mut,map2_mut}; -use crate::euclidean::*; -use crate::norms::*; -use crate::linops::{AXPY, Mapping, Linear}; -use crate::instance::{Instance, BasicDecomposition}; -use crate::mapping::Space; -use serde::ser::{Serialize, Serializer, SerializeSeq}; - +use std::ops::{ + Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign, +}; +use std::slice::{Iter, IterMut}; /// A container type for (short) `N`-dimensional vectors of element type `F`. /// /// Supports basic operations of an [`Euclidean`] space, several [`Norm`]s, and /// fused [`AXPY`] operations, among others. -#[derive(Copy,Clone,Debug,PartialEq,Eq)] -pub struct Loc( +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Loc( /// An array of the elements of the vector - pub [F; N] + pub [F; N], ); -impl Display for Loc{ +self_ownable!(Loc where const N: usize, F: Copy); + +impl Display for Loc { // Required method fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "[")?; - let mut comma = ""; - for e in self.iter() { + write!(f, "[")?; + let mut comma = ""; + for e in self.iter() { write!(f, "{comma}{e}")?; comma = ", "; - } - write!(f, "]") + } + write!(f, "]") } } // Need to manually implement as [F; N] serialisation is provided only for some N. -impl Serialize for Loc +impl Serialize for Loc where F: Serialize, { @@ -56,10 +60,10 @@ } } -impl Loc { +impl Loc { /// Creates a new `Loc` vector from an array. #[inline] - pub fn new(arr : [F; N]) -> Self { + pub fn new(arr: [F; N]) -> Self { Loc(arr) } @@ -76,43 +80,44 @@ } } -impl Loc { +impl Loc { /// Maps `g` over the elements of the vector, returning a new [`Loc`] vector #[inline] - pub fn map(&self, g : impl Fn(F) -> H) -> Loc { + pub fn map(&self, g: impl Fn(F) -> H) -> Loc { Loc::new(map1(self, |u| g(*u))) } /// Maps `g` over pairs of elements of two vectors, retuning a new one. #[inline] - pub fn map2(&self, other : &Self, g : impl Fn(F, F) -> H) -> Loc { + pub fn map2(&self, other: &Self, g: impl Fn(F, F) -> H) -> Loc { Loc::new(map2(self, other, |u, v| g(*u, *v))) } /// Maps `g` over mutable references to elements of the vector. #[inline] - pub fn map_mut(&mut self, g : impl Fn(&mut F)) { + pub fn map_mut(&mut self, g: impl Fn(&mut F)) { map1_mut(self, g) } /// Maps `g` over pairs of mutable references to elements of `self, and elements /// of `other` vector. #[inline] - pub fn map2_mut(&mut self, other : &Self, g : impl Fn(&mut F, F)) { + pub fn map2_mut(&mut self, other: &Self, g: impl Fn(&mut F, F)) { map2_mut(self, other, |u, v| g(u, *v)) } /// Maps `g` over the elements of `self` and returns the product of the results. #[inline] - pub fn product_map(&self, g : impl Fn(F) -> A) -> A { + pub fn product_map(&self, g: impl Fn(F) -> A) -> A { match N { 1 => g(unsafe { *self.0.get_unchecked(0) }), - 2 => g(unsafe { *self.0.get_unchecked(0) }) * - g(unsafe { *self.0.get_unchecked(1) }), - 3 => g(unsafe { *self.0.get_unchecked(0) }) * - g(unsafe { *self.0.get_unchecked(1) }) * - g(unsafe { *self.0.get_unchecked(2) }), - _ => self.iter().fold(A::ONE, |m, &x| m * g(x)) + 2 => g(unsafe { *self.0.get_unchecked(0) }) * g(unsafe { *self.0.get_unchecked(1) }), + 3 => { + g(unsafe { *self.0.get_unchecked(0) }) + * g(unsafe { *self.0.get_unchecked(1) }) + * g(unsafe { *self.0.get_unchecked(2) }) + } + _ => self.iter().fold(A::ONE, |m, &x| m * g(x)), } } } @@ -130,29 +135,28 @@ ($($x:expr),+ $(,)?) => { Loc::new([$($x),+]) } } - -impl From<[F; N]> for Loc { +impl From<[F; N]> for Loc { #[inline] - fn from(other: [F; N]) -> Loc { + fn from(other: [F; N]) -> Loc { Loc(other) } } -/*impl From<&[F; N]> for Loc { +/*impl From<&[F; N]> for Loc { #[inline] - fn from(other: &[F; N]) -> Loc { + fn from(other: &[F; N]) -> Loc { Loc(*other) } }*/ -impl From for Loc { +impl From for Loc<1, F> { #[inline] - fn from(other: F) -> Loc { + fn from(other: F) -> Loc<1, F> { Loc([other]) } } -impl Loc { +impl Loc<1, F> { #[inline] pub fn flatten1d(self) -> F { let Loc([v]) = self; @@ -160,22 +164,21 @@ } } -impl From> for [F; N] { +impl From> for [F; N] { #[inline] - fn from(other : Loc) -> [F; N] { + fn from(other: Loc) -> [F; N] { other.0 } } -/*impl From<&Loc> for [F; N] { +/*impl From<&Loc> for [F; N] { #[inline] - fn from(other : &Loc) -> [F; N] { + fn from(other : &Loc) -> [F; N] { other.0 } }*/ - -impl IntoIterator for Loc { +impl IntoIterator for Loc { type Item = <[F; N] as IntoIterator>::Item; type IntoIter = <[F; N] as IntoIterator>::IntoIter; @@ -187,20 +190,24 @@ // Indexing -impl Index for Loc -where [F; N] : Index { +impl Index for Loc +where + [F; N]: Index, +{ type Output = <[F; N] as Index>::Output; #[inline] - fn index(&self, ix : Ix) -> &Self::Output { + fn index(&self, ix: Ix) -> &Self::Output { self.0.index(ix) } } -impl IndexMut for Loc -where [F; N] : IndexMut { +impl IndexMut for Loc +where + [F; N]: IndexMut, +{ #[inline] - fn index_mut(&mut self, ix : Ix) -> &mut Self::Output { + fn index_mut(&mut self, ix: Ix) -> &mut Self::Output { self.0.index_mut(ix) } } @@ -209,61 +216,61 @@ macro_rules! make_binop { ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => { - impl $trait> for Loc { - type Output = Loc; + impl $trait> for Loc { + type Output = Loc; #[inline] - fn $fn(mut self, other : Loc) -> Self::Output { + fn $fn(mut self, other: Loc) -> Self::Output { self.$fn_assign(other); self } } - impl<'a, F : Num, const N : usize> $trait<&'a Loc> for Loc { - type Output = Loc; + impl<'a, F: Num, const N: usize> $trait<&'a Loc> for Loc { + type Output = Loc; #[inline] - fn $fn(mut self, other : &'a Loc) -> Self::Output { + fn $fn(mut self, other: &'a Loc) -> Self::Output { self.$fn_assign(other); self } } - impl<'b, F : Num, const N : usize> $trait> for &'b Loc { - type Output = Loc; + impl<'b, F: Num, const N: usize> $trait> for &'b Loc { + type Output = Loc; #[inline] - fn $fn(self, other : Loc) -> Self::Output { + fn $fn(self, other: Loc) -> Self::Output { self.map2(&other, |a, b| a.$fn(b)) } } - impl<'a, 'b, F : Num, const N : usize> $trait<&'a Loc> for &'b Loc { - type Output = Loc; + impl<'a, 'b, F: Num, const N: usize> $trait<&'a Loc> for &'b Loc { + type Output = Loc; #[inline] - fn $fn(self, other : &'a Loc) -> Self::Output { + fn $fn(self, other: &'a Loc) -> Self::Output { self.map2(other, |a, b| a.$fn(b)) } } - impl $trait_assign> for Loc { + impl $trait_assign> for Loc { #[inline] - fn $fn_assign(&mut self, other : Loc) { + fn $fn_assign(&mut self, other: Loc) { self.map2_mut(&other, |a, b| a.$fn_assign(b)) } } - impl<'a, F : Num, const N : usize> $trait_assign<&'a Loc> for Loc { + impl<'a, F: Num, const N: usize> $trait_assign<&'a Loc> for Loc { #[inline] - fn $fn_assign(&mut self, other : &'a Loc) { + fn $fn_assign(&mut self, other: &'a Loc) { self.map2_mut(other, |a, b| a.$fn_assign(b)) } } - } + }; } make_binop!(Add, add, AddAssign, add_assign); make_binop!(Sub, sub, SubAssign, sub_assign); -impl std::iter::Sum for Loc { - fn sum>>(mut iter: I) -> Self { +impl std::iter::Sum for Loc { + fn sum>>(mut iter: I) -> Self { match iter.next() { None => Self::ORIGIN, Some(mut v) => { @@ -278,62 +285,61 @@ macro_rules! make_scalarop_rhs { ($trait:ident, $fn:ident, $trait_assign:ident, $fn_assign:ident) => { - impl $trait for Loc { - type Output = Loc; + impl $trait for Loc { + type Output = Loc; #[inline] - fn $fn(self, b : F) -> Self::Output { + fn $fn(self, b: F) -> Self::Output { self.map(|a| a.$fn(b)) } } - impl<'a, F : Num, const N : usize> $trait<&'a F> for Loc { - type Output = Loc; + impl<'a, F: Num, const N: usize> $trait<&'a F> for Loc { + type Output = Loc; #[inline] - fn $fn(self, b : &'a F) -> Self::Output { + fn $fn(self, b: &'a F) -> Self::Output { self.map(|a| a.$fn(*b)) } } - impl<'b, F : Num, const N : usize> $trait for &'b Loc { - type Output = Loc; + impl<'b, F: Num, const N: usize> $trait for &'b Loc { + type Output = Loc; #[inline] - fn $fn(self, b : F) -> Self::Output { + fn $fn(self, b: F) -> Self::Output { self.map(|a| a.$fn(b)) } } - impl<'a, 'b, F : Float, const N : usize> $trait<&'a F> for &'b Loc { - type Output = Loc; + impl<'a, 'b, F: Float, const N: usize> $trait<&'a F> for &'b Loc { + type Output = Loc; #[inline] - fn $fn(self, b : &'a F) -> Self::Output { + fn $fn(self, b: &'a F) -> Self::Output { self.map(|a| a.$fn(*b)) } } - impl $trait_assign for Loc { + impl $trait_assign for Loc { #[inline] - fn $fn_assign(&mut self, b : F) { + fn $fn_assign(&mut self, b: F) { self.map_mut(|a| a.$fn_assign(b)); } } - impl<'a, F : Num, const N : usize> $trait_assign<&'a F> for Loc { + impl<'a, F: Num, const N: usize> $trait_assign<&'a F> for Loc { #[inline] - fn $fn_assign(&mut self, b : &'a F) { + fn $fn_assign(&mut self, b: &'a F) { self.map_mut(|a| a.$fn_assign(*b)); } } - } + }; } - make_scalarop_rhs!(Mul, mul, MulAssign, mul_assign); make_scalarop_rhs!(Div, div, DivAssign, div_assign); macro_rules! make_unaryop { ($trait:ident, $fn:ident) => { - impl $trait for Loc { - type Output = Loc; + impl $trait for Loc { + type Output = Loc; #[inline] fn $fn(mut self) -> Self::Output { self.map_mut(|a| *a = (*a).$fn()); @@ -341,48 +347,48 @@ } } - impl<'a, F : SignedNum, const N : usize> $trait for &'a Loc { - type Output = Loc; + impl<'a, F: SignedNum, const N: usize> $trait for &'a Loc { + type Output = Loc; #[inline] fn $fn(self) -> Self::Output { self.map(|a| a.$fn()) } } - } + }; } make_unaryop!(Neg, neg); macro_rules! make_scalarop_lhs { ($trait:ident, $fn:ident; $($f:ident)+) => { $( - impl $trait> for $f { - type Output = Loc<$f, N>; + impl $trait> for $f { + type Output = Loc; #[inline] - fn $fn(self, v : Loc<$f,N>) -> Self::Output { + fn $fn(self, v : Loc) -> Self::Output { v.map(|b| self.$fn(b)) } } - impl<'a, const N : usize> $trait<&'a Loc<$f,N>> for $f { - type Output = Loc<$f, N>; + impl<'a, const N : usize> $trait<&'a Loc> for $f { + type Output = Loc; #[inline] - fn $fn(self, v : &'a Loc<$f,N>) -> Self::Output { + fn $fn(self, v : &'a Loc) -> Self::Output { v.map(|b| self.$fn(b)) } } - impl<'b, const N : usize> $trait> for &'b $f { - type Output = Loc<$f, N>; + impl<'b, const N : usize> $trait> for &'b $f { + type Output = Loc; #[inline] - fn $fn(self, v : Loc<$f,N>) -> Self::Output { + fn $fn(self, v : Loc) -> Self::Output { v.map(|b| self.$fn(b)) } } - impl<'a, 'b, const N : usize> $trait<&'a Loc<$f,N>> for &'b $f { - type Output = Loc<$f, N>; + impl<'a, 'b, const N : usize> $trait<&'a Loc> for &'b $f { + type Output = Loc; #[inline] - fn $fn(self, v : &'a Loc<$f, N>) -> Self::Output { + fn $fn(self, v : &'a Loc) -> Self::Output { v.map(|b| self.$fn(b)) } } @@ -396,21 +402,21 @@ macro_rules! domination { ($norm:ident, $dominates:ident) => { - impl Dominated> for $norm { + impl Dominated> for $norm { #[inline] - fn norm_factor(&self, _p : $dominates) -> F { + fn norm_factor(&self, _p: $dominates) -> F { F::ONE } #[inline] - fn from_norm(&self, p_norm : F, _p : $dominates) -> F { + fn from_norm(&self, p_norm: F, _p: $dominates) -> F { p_norm } } }; ($norm:ident, $dominates:ident, $fn:path) => { - impl Dominated> for $norm { + impl Dominated> for $norm { #[inline] - fn norm_factor(&self, _p : $dominates) -> F { + fn norm_factor(&self, _p: $dominates) -> F { $fn(F::cast_from(N)) } } @@ -429,16 +435,19 @@ domination!(Linfinity, L2); domination!(L2, L1); -impl Euclidean for Loc { - type Output = Self; +impl Euclidean for Loc { + type PrincipalE = Self; /// This implementation is not stabilised as it's meant to be used for very small vectors. /// Use [`nalgebra`] for larger vectors. #[inline] - fn dot>(&self, other : I) -> F { - self.0.iter() - .zip(other.ref_instance().0.iter()) - .fold(F::ZERO, |m, (&v, &w)| m + v * w) + fn dot>(&self, other: I) -> F { + other.eval_ref(|r| { + self.0 + .iter() + .zip(r.0.iter()) + .fold(F::ZERO, |m, (&v, &w)| m + v * w) + }) } /// This implementation is not stabilised as it's meant to be used for very small vectors. @@ -448,16 +457,19 @@ self.iter().fold(F::ZERO, |m, &v| m + v * v) } - fn dist2_squared>(&self, other : I) -> F { - self.iter() - .zip(other.ref_instance().iter()) - .fold(F::ZERO, |m, (&v, &w)| { let d = v - w; m + d * d }) + fn dist2_squared>(&self, other: I) -> F { + other.eval_ref(|r| { + self.iter().zip(r.iter()).fold(F::ZERO, |m, (&v, &w)| { + let d = v - w; + m + d * d + }) + }) } #[inline] fn norm2(&self) -> F { // Optimisation for N==1 that avoids squaring and square rooting. - if N==1 { + if N == 1 { unsafe { self.0.get_unchecked(0) }.abs() } else { self.norm2_squared().sqrt() @@ -465,39 +477,38 @@ } #[inline] - fn dist2>(&self, other : I) -> F { + fn dist2>(&self, other: I) -> F { // Optimisation for N==1 that avoids squaring and square rooting. - let otherr = other.ref_instance(); - if N==1 { - unsafe { *self.0.get_unchecked(0) - *otherr.0.get_unchecked(0) }.abs() + if N == 1 { + other.eval_ref(|r| unsafe { *self.0.get_unchecked(0) - *r.0.get_unchecked(0) }.abs()) } else { self.dist2_squared(other).sqrt() } } } -impl Loc { +impl Loc { /// Origin point - pub const ORIGIN : Self = Loc([F::ZERO; N]); + pub const ORIGIN: Self = Loc([F::ZERO; N]); } -impl, const N : usize> Loc { +impl, const N: usize> Loc { /// Reflects along the given coordinate - pub fn reflect(mut self, i : usize) -> Self { + pub fn reflect(mut self, i: usize) -> Self { self[i] = -self[i]; self } /// Reflects along the given coordinate sequences - pub fn reflect_many>(mut self, idxs : I) -> Self { + pub fn reflect_many>(mut self, idxs: I) -> Self { for i in idxs { - self[i]=-self[i] + self[i] = -self[i] } self } } -impl> Loc { +impl> Loc<2, F> { /// Reflect x coordinate pub fn reflect_x(self) -> Self { let Loc([x, y]) = self; @@ -511,18 +522,17 @@ } } -impl Loc { +impl Loc<2, F> { /// Rotate by angle φ - pub fn rotate(self, φ : F) -> Self { + pub fn rotate(self, φ: F) -> Self { let Loc([x, y]) = self; let sin_φ = φ.sin(); let cos_φ = φ.cos(); - [cos_φ * x - sin_φ * y, - sin_φ * x + cos_φ * y].into() + [cos_φ * x - sin_φ * y, sin_φ * x + cos_φ * y].into() } } -impl> Loc { +impl> Loc<3, F> { /// Reflect x coordinate pub fn reflect_x(self) -> Self { let Loc([x, y, z]) = self; @@ -542,39 +552,32 @@ } } -impl Loc { +impl Loc<3, F> { /// Rotate by angles (yaw, pitch, roll) - pub fn rotate(self, Loc([φ, ψ, θ]) : Self) -> Self { + pub fn rotate(self, Loc([φ, ψ, θ]): Self) -> Self { let Loc([mut x, mut y, mut z]) = self; let sin_φ = φ.sin(); let cos_φ = φ.cos(); - [x, y, z] = [cos_φ * x - sin_φ *y, - sin_φ * x + cos_φ * y, - z]; + [x, y, z] = [cos_φ * x - sin_φ * y, sin_φ * x + cos_φ * y, z]; let sin_ψ = ψ.sin(); let cos_ψ = ψ.cos(); - [x, y, z] = [cos_ψ * x + sin_ψ * z, - y, - -sin_ψ * x + cos_ψ * z]; + [x, y, z] = [cos_ψ * x + sin_ψ * z, y, -sin_ψ * x + cos_ψ * z]; let sin_θ = θ.sin(); let cos_θ = θ.cos(); - [x, y, z] = [x, - cos_θ * y - sin_θ * z, - sin_θ * y + cos_θ * z]; + [x, y, z] = [x, cos_θ * y - sin_θ * z, sin_θ * y + cos_θ * z]; [x, y, z].into() } } -impl StaticEuclidean for Loc { +impl StaticEuclidean for Loc { #[inline] fn origin() -> Self { Self::ORIGIN } } - /// The default norm for `Loc` is [`L2`]. -impl Normed for Loc { +impl Normed for Loc { type NormExp = L2; #[inline] @@ -593,22 +596,30 @@ } } -impl HasDual for Loc { +impl HasDual for Loc { type DualSpace = Self; + + fn dual_origin(&self) -> Self::DualSpace { + self.similar_origin() + } } -impl Norm for Loc { +impl Norm for Loc { #[inline] - fn norm(&self, _ : L2) -> F { self.norm2() } + fn norm(&self, _: L2) -> F { + self.norm2() + } } -impl Dist for Loc { +impl Dist for Loc { #[inline] - fn dist>(&self, other : I, _ : L2) -> F { self.dist2(other) } + fn dist>(&self, other: I, _: L2) -> F { + self.dist2(other) + } } /* Implemented automatically as Euclidean. -impl Projection for Loc { +impl Projection for Loc { #[inline] fn proj_ball_mut(&mut self, ρ : F, nrm : L2) { let n = self.norm(nrm); @@ -618,53 +629,65 @@ } }*/ -impl Norm for Loc { +impl Norm for Loc { /// This implementation is not stabilised as it's meant to be used for very small vectors. /// Use [`nalgebra`] for larger vectors. #[inline] - fn norm(&self, _ : L1) -> F { + fn norm(&self, _: L1) -> F { self.iter().fold(F::ZERO, |m, v| m + v.abs()) } } -impl Dist for Loc { +impl Dist for Loc { #[inline] - fn dist>(&self, other : I, _ : L1) -> F { - self.iter() - .zip(other.ref_instance().iter()) - .fold(F::ZERO, |m, (&v, &w)| m + (v-w).abs() ) + fn dist>(&self, other: I, _: L1) -> F { + other.eval_ref(|r| { + self.iter() + .zip(r.iter()) + .fold(F::ZERO, |m, (&v, &w)| m + (v - w).abs()) + }) } } -impl Projection for Loc { +impl Projection for Loc { #[inline] - fn proj_ball_mut(&mut self, ρ : F, _ : Linfinity) { - self.iter_mut().for_each(|v| *v = num_traits::clamp(*v, -ρ, ρ)) + fn proj_ball(mut self, ρ: F, exp: Linfinity) -> Self { + self.proj_ball_mut(ρ, exp); + self } } -impl Norm for Loc { +impl ProjectionMut for Loc { + #[inline] + fn proj_ball_mut(&mut self, ρ: F, _: Linfinity) { + self.iter_mut() + .for_each(|v| *v = num_traits::clamp(*v, -ρ, ρ)) + } +} + +impl Norm for Loc { /// This implementation is not stabilised as it's meant to be used for very small vectors. /// Use [`nalgebra`] for larger vectors. #[inline] - fn norm(&self, _ : Linfinity) -> F { + fn norm(&self, _: Linfinity) -> F { self.iter().fold(F::ZERO, |m, v| m.max(v.abs())) } } -impl Dist for Loc { +impl Dist for Loc { #[inline] - fn dist>(&self, other : I, _ : Linfinity) -> F { - self.iter() - .zip(other.ref_instance().iter()) - .fold(F::ZERO, |m, (&v, &w)| m.max((v-w).abs())) + fn dist>(&self, other: I, _: Linfinity) -> F { + other.eval_ref(|r| { + self.iter() + .zip(r.iter()) + .fold(F::ZERO, |m, (&v, &w)| m.max((v - w).abs())) + }) } } - // Misc. -impl FixedLength for Loc { +impl FixedLength for Loc { type Iter = std::array::IntoIter; type Elem = A; #[inline] @@ -673,15 +696,18 @@ } } -impl FixedLengthMut for Loc { - type IterMut<'a> = std::slice::IterMut<'a, A> where A : 'a; +impl FixedLengthMut for Loc { + type IterMut<'a> + = std::slice::IterMut<'a, A> + where + A: 'a; #[inline] fn fl_iter_mut(&mut self) -> Self::IterMut<'_> { self.iter_mut() } } -impl<'a, A, const N : usize> FixedLength for &'a Loc { +impl<'a, A, const N: usize> FixedLength for &'a Loc { type Iter = std::slice::Iter<'a, A>; type Elem = &'a A; #[inline] @@ -690,43 +716,61 @@ } } - -impl Space for Loc { +impl Space for Loc { + type Principal = Self; type Decomp = BasicDecomposition; } -impl Mapping> for Loc { +impl Mapping> for Loc { type Codomain = F; - fn apply>>(&self, x : I) -> Self::Codomain { - x.eval(|x̃| self.dot(x̃)) + fn apply>>(&self, x: I) -> Self::Codomain { + x.eval_decompose(|x̃| self.dot(x̃)) } } -impl Linear> for Loc { } +impl Linear> for Loc {} + +impl VectorSpace for Loc { + type Field = F; + type PrincipalV = Self; -impl AXPY> for Loc { - type Owned = Self; + // #[inline] + // fn make_origin_generator(&self) -> StaticEuclideanOriginGenerator { + // StaticEuclideanOriginGenerator + // } + + #[inline] + fn similar_origin(&self) -> Self::PrincipalV { + Self::ORIGIN + } #[inline] - fn axpy>>(&mut self, α : F, x : I, β : F) { - x.eval(|x̃| { + fn similar_origin_inst>(_: I) -> Self::PrincipalV { + Self::ORIGIN + } + + // #[inline] + // fn into_owned(self) -> Self::Owned { + // self + // } +} + +impl AXPY> for Loc { + #[inline] + fn axpy>>(&mut self, α: F, x: I, β: F) { + x.eval_ref(|x̃| { if β == F::ZERO { - map2_mut(self, x̃, |yi, xi| { *yi = α * (*xi) }) + map2_mut(self, x̃, |yi, xi| *yi = α * (*xi)) } else { - map2_mut(self, x̃, |yi, xi| { *yi = β * (*yi) + α * (*xi) }) + map2_mut(self, x̃, |yi, xi| *yi = β * (*yi) + α * (*xi)) } }) } #[inline] - fn copy_from>>(&mut self, x : I) { - x.eval(|x̃| map2_mut(self, x̃, |yi, xi| *yi = *xi )) - } - - #[inline] - fn similar_origin(&self) -> Self::Owned { - Self::ORIGIN + fn copy_from>>(&mut self, x: I) { + x.eval_ref(|x̃| map2_mut(self, x̃, |yi, xi| *yi = *xi)) } #[inline] @@ -734,4 +778,3 @@ *self = Self::ORIGIN; } } - diff -r 1f19c6bbf07b -r 3868555d135c src/mapping.rs --- a/src/mapping.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/mapping.rs Fri May 15 14:46:30 2026 -0500 @@ -2,84 +2,94 @@ Traits for mathematical functions. */ -use std::marker::PhantomData; -use std::borrow::Cow; -use crate::types::{Num, Float, ClosedMul}; +use crate::error::DynResult; +use crate::instance::MyCow; +pub use crate::instance::{BasicDecomposition, ClosedSpace, Decomposition, Instance, Space}; use crate::loc::Loc; -pub use crate::instance::{Instance, Decomposition, BasicDecomposition, Space}; use crate::norms::{Norm, NormExponent}; -use crate::operator_arithmetic::{Weighted, Constant}; +use crate::operator_arithmetic::{Constant, Weighted}; +use crate::types::{ClosedMul, Float, Num}; +use std::marker::PhantomData; +use std::ops::Mul; /// A mapping from `Domain` to `Self::Codomain`. -pub trait Mapping { - type Codomain : Space; +pub trait Mapping { + type Codomain: ClosedSpace; /// Compute the value of `self` at `x`. - fn apply>(&self, x : I) -> Self::Codomain; + fn apply>(&self, x: I) -> Self::Codomain; #[inline] /// Form the composition `self ∘ other` - fn compose>(self, other : T) - -> Composition + fn compose>(self, other: T) -> Composition where - Self : Sized + Self: Sized, { - Composition{ outer : self, inner : other, intermediate_norm_exponent : () } + Composition { outer: self, inner: other, intermediate_norm_exponent: () } } - #[inline] /// Form the composition `self ∘ other`, assigning a norm to the inermediate space - fn compose_with_norm( - self, other : T, norm : E - ) -> Composition + fn compose_with_norm(self, other: T, norm: E) -> Composition where - Self : Sized, - X : Space, - T : Mapping, - E : NormExponent, - Domain : Norm, - F : Num + Self: Sized, + X: Space, + T: Mapping, + E: NormExponent, + Domain: Norm, + F: Num, { - Composition{ outer : self, inner : other, intermediate_norm_exponent : norm } + Composition { outer: self, inner: other, intermediate_norm_exponent: norm } } /// Multiply `self` by the scalar `a`. #[inline] - fn weigh(self, a : C) -> Weighted + fn weigh(self, a: C) -> Weighted where - Self : Sized, - C : Constant, - Self::Codomain : ClosedMul, + Self: Sized, + C: Constant, + Self::Codomain: ClosedMul, { - Weighted { weight : a, base_fn : self } + Weighted { weight: a, base_fn: self } } } -/// Automatically implemented shorthand for referring to [`Mapping`]s from [`Loc`] to `F`. -pub trait RealMapping -: Mapping, Codomain = F> {} +/// Automatically implemented shorthand for referring to [`Mapping`]s from [`Loc`] to `F`. +pub trait RealMapping: Mapping, Codomain = F> {} -impl RealMapping for T -where T : Mapping, Codomain = F> {} +impl RealMapping for T where T: Mapping, Codomain = F> {} -/// A helper trait alias for referring to [`Mapping`]s from [`Loc`] to [`Loc`]. -pub trait RealVectorField -: Mapping, Codomain = Loc> {} +/// A helper trait alias for referring to [`Mapping`]s from [`Loc`] to [`Loc`]. +pub trait RealVectorField: + Mapping, Codomain = Loc> +{ +} -impl RealVectorField for T -where T : Mapping, Codomain = Loc> {} +impl RealVectorField for T where + T: Mapping, Codomain = Loc> +{ +} /// A differentiable mapping from `Domain` to [`Mapping::Codomain`], with differentials /// `Differential`. /// /// This is automatically implemented when [`DifferentiableImpl`] is. -pub trait DifferentiableMapping : Mapping { - type DerivativeDomain : Space; - type Differential<'b> : Mapping where Self : 'b; +pub trait DifferentiableMapping: Mapping { + type DerivativeDomain: ClosedSpace; + type Differential<'b>: Mapping + where + Self: 'b; /// Calculate differential at `x` - fn differential>(&self, x : I) -> Self::DerivativeDomain; + fn differential>(&self, x: I) -> Self::DerivativeDomain; + + /// Calculate differential and value at `x` + fn apply_and_differential>( + &self, + x: I, + ) -> (Self::Codomain, Self::DerivativeDomain) { + x.eval_ref(|xo| (self.apply(xo), self.differential(xo))) + } /// Form the differential mapping of `self`. fn diff(self) -> Self::Differential<'static>; @@ -89,51 +99,75 @@ } /// Automatically implemented shorthand for referring to differentiable [`Mapping`]s from -/// [`Loc`] to `F`. -pub trait DifferentiableRealMapping -: DifferentiableMapping, Codomain = F, DerivativeDomain = Loc> {} +/// [`Loc`] to `F`. +pub trait DifferentiableRealMapping: + DifferentiableMapping, Codomain = F, DerivativeDomain = Loc> +{ +} -impl DifferentiableRealMapping for T -where T : DifferentiableMapping, Codomain = F, DerivativeDomain = Loc> {} +impl DifferentiableRealMapping for T where + T: DifferentiableMapping, Codomain = F, DerivativeDomain = Loc> +{ +} /// Helper trait for implementing [`DifferentiableMapping`] -pub trait DifferentiableImpl : Sized { - type Derivative : Space; +pub trait DifferentiableImpl: Sized { + type Derivative: ClosedSpace; /// Compute the differential of `self` at `x`, consuming the input. - fn differential_impl>(&self, x : I) -> Self::Derivative; + fn differential_impl>(&self, x: I) -> Self::Derivative; + + fn apply_and_differential_impl>( + &self, + x: I, + ) -> (Self::Codomain, Self::Derivative) + where + Self: Mapping, + { + x.eval_ref(|xo| (self.apply(xo), self.differential_impl(xo))) + } } impl DifferentiableMapping for T where - Domain : Space, - T : Clone + Mapping + DifferentiableImpl + Domain: Space, + T: Mapping + DifferentiableImpl, { type DerivativeDomain = T::Derivative; - type Differential<'b> = Differential<'b, Domain, Self> where Self : 'b; - + type Differential<'b> + = Differential<'b, Domain, Self> + where + Self: 'b; + #[inline] - fn differential>(&self, x : I) -> Self::DerivativeDomain { + fn differential>(&self, x: I) -> Self::DerivativeDomain { self.differential_impl(x) } + #[inline] + fn apply_and_differential>( + &self, + x: I, + ) -> (T::Codomain, Self::DerivativeDomain) { + self.apply_and_differential_impl(x) + } + fn diff(self) -> Differential<'static, Domain, Self> { - Differential{ g : Cow::Owned(self), _space : PhantomData } + Differential { g: MyCow::Owned(self), _space: PhantomData } } fn diff_ref(&self) -> Differential<'_, Domain, Self> { - Differential{ g : Cow::Borrowed(self), _space : PhantomData } + Differential { g: MyCow::Borrowed(self), _space: PhantomData } } } - /// Container for the differential [`Mapping`] of a [`DifferentiableMapping`]. -pub struct Differential<'a, X, G : Clone> { - g : Cow<'a, G>, - _space : PhantomData +pub struct Differential<'a, X, G> { + g: MyCow<'a, G>, + _space: PhantomData, } -impl<'a, X, G : Clone> Differential<'a, X, G> { +impl<'a, X, G> Differential<'a, X, G> { pub fn base_fn(&self) -> &G { &self.g } @@ -141,65 +175,66 @@ impl<'a, X, G> Mapping for Differential<'a, X, G> where - X : Space, - G : Clone + DifferentiableMapping + X: Space, + G: DifferentiableMapping, { type Codomain = G::DerivativeDomain; #[inline] - fn apply>(&self, x : I) -> Self::Codomain { + fn apply>(&self, x: I) -> Self::Codomain { (*self.g).differential(x) } } /// Container for flattening [`Loc`]`` codomain of a [`Mapping`] to `F`. pub struct FlattenedCodomain { - g : G, - _phantoms : PhantomData<(X, F)> + g: G, + _phantoms: PhantomData<(X, F)>, } -impl Mapping for FlattenedCodomain +impl Mapping for FlattenedCodomain where - X : Space, - G: Mapping> + F: ClosedSpace, + X: Space, + G: Mapping>, { type Codomain = F; #[inline] - fn apply>(&self, x : I) -> Self::Codomain { + fn apply>(&self, x: I) -> Self::Codomain { self.g.apply(x).flatten1d() } } /// An auto-trait for constructing a [`FlattenCodomain`] structure for /// flattening the codomain of a [`Mapping`] from [`Loc`]`` to `F`. -pub trait FlattenCodomain : Mapping> + Sized { +pub trait FlattenCodomain: Mapping> + Sized { /// Flatten the codomain from [`Loc`]`` to `F`. fn flatten_codomain(self) -> FlattenedCodomain { - FlattenedCodomain{ g : self, _phantoms : PhantomData } + FlattenedCodomain { g: self, _phantoms: PhantomData } } } -impl>> FlattenCodomain for G {} +impl>> FlattenCodomain for G {} -/// Container for dimensional slicing [`Loc`]`` codomain of a [`Mapping`] to `F`. -pub struct SlicedCodomain<'a, X, F, G : Clone, const N : usize> { - g : Cow<'a, G>, - slice : usize, - _phantoms : PhantomData<(X, F)> +/// Container for dimensional slicing [`Loc`]`` codomain of a [`Mapping`] to `F`. +pub struct SlicedCodomain<'a, X, F, G, const N: usize> { + g: MyCow<'a, G>, + slice: usize, + _phantoms: PhantomData<(X, F)>, } -impl<'a, X, F, G, const N : usize> Mapping for SlicedCodomain<'a, X, F, G, N> +impl<'a, X, F, G, const N: usize> Mapping for SlicedCodomain<'a, X, F, G, N> where - X : Space, - F : Copy + Space, - G : Mapping> + Clone, + X: Space, + F: Copy + ClosedSpace, + G: Mapping>, { type Codomain = F; #[inline] - fn apply>(&self, x : I) -> Self::Codomain { - let tmp : [F; N] = (*self.g).apply(x).into(); + fn apply>(&self, x: I) -> Self::Codomain { + let tmp: [F; N] = (*self.g).apply(x).into(); // Safety: `slice_codomain` below checks the range. unsafe { *tmp.get_unchecked(self.slice) } } @@ -207,44 +242,106 @@ /// An auto-trait for constructing a [`FlattenCodomain`] structure for /// flattening the codomain of a [`Mapping`] from [`Loc`]`` to `F`. -pub trait SliceCodomain - : Mapping> + Clone + Sized +pub trait SliceCodomain: + Mapping> + Sized { /// Flatten the codomain from [`Loc`]`` to `F`. - fn slice_codomain(self, slice : usize) -> SlicedCodomain<'static, X, F, Self, N> { + fn slice_codomain(self, slice: usize) -> SlicedCodomain<'static, X, F, Self, N> { assert!(slice < N); - SlicedCodomain{ g : Cow::Owned(self), slice, _phantoms : PhantomData } + SlicedCodomain { g: MyCow::Owned(self), slice, _phantoms: PhantomData } } /// Flatten the codomain from [`Loc`]`` to `F`. - fn slice_codomain_ref(&self, slice : usize) -> SlicedCodomain<'_, X, F, Self, N> { + fn slice_codomain_ref(&self, slice: usize) -> SlicedCodomain<'_, X, F, Self, N> { assert!(slice < N); - SlicedCodomain{ g : Cow::Borrowed(self), slice, _phantoms : PhantomData } + SlicedCodomain { g: MyCow::Borrowed(self), slice, _phantoms: PhantomData } } } -impl> + Clone, const N : usize> -SliceCodomain -for G {} - +impl>, const N: usize> + SliceCodomain for G +{ +} /// The composition S ∘ T. `E` is for storing a `NormExponent` for the intermediate space. pub struct Composition { - pub outer : S, - pub inner : T, - pub intermediate_norm_exponent : E + pub outer: S, + pub inner: T, + pub intermediate_norm_exponent: E, } impl Mapping for Composition where - X : Space, - T : Mapping, - S : Mapping + X: Space, + T: Mapping, + S: Mapping, { type Codomain = S::Codomain; #[inline] - fn apply>(&self, x : I) -> Self::Codomain { + fn apply>(&self, x: I) -> Self::Codomain { self.outer.apply(self.inner.apply(x)) } } + +/// Helper trait for implementing [`DifferentiableMapping`] +impl DifferentiableImpl for Composition +where + X: Space, + T: DifferentiableImpl + Mapping, + S: DifferentiableImpl, + E: Copy, + //Composition: Space, + S::Derivative: Mul, + Y: ClosedSpace, +{ + //type Derivative = Composition; + type Derivative = Y; + + /// Compute the differential of `self` at `x`, consuming the input. + fn differential_impl>(&self, x: I) -> Self::Derivative { + // Composition { + // outer: self + // .outer + // .differential_impl(self.inner.apply(x.ref_instance())), + // inner: self.inner.differential_impl(x), + // intermediate_norm_exponent: self.intermediate_norm_exponent, + // } + + self.outer + .differential_impl(x.eval_ref(|r| self.inner.apply(r))) + * self.inner.differential_impl(x) + } +} + +mod dataterm; +pub use dataterm::DataTerm; + +/// Trait for indicating that `Self` is Lipschitz with respect to the (semi)norm `D`. +pub trait Lipschitz { + /// The type of floats + type FloatType: Float; + + /// Returns the Lipschitz factor of `self` with respect to the (semi)norm `D`. + fn lipschitz_factor(&self, seminorm: M) -> DynResult; +} + +/// Helper trait for implementing [`Lipschitz`] for mappings that implement [`DifferentiableImpl`]. +pub trait LipschitzDifferentiableImpl: DifferentiableImpl { + type FloatType: Float; + + /// Compute the lipschitz factor of the derivative of `f`. + fn diff_lipschitz_factor(&self, seminorm: M) -> DynResult; +} + +impl<'b, M, X, A> Lipschitz for Differential<'b, X, A> +where + X: Space, + A: LipschitzDifferentiableImpl, +{ + type FloatType = A::FloatType; + + fn lipschitz_factor(&self, seminorm: M) -> DynResult { + (*self.g).diff_lipschitz_factor(seminorm) + } +} diff -r 1f19c6bbf07b -r 3868555d135c src/mapping/dataterm.rs --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/mapping/dataterm.rs Fri May 15 14:46:30 2026 -0500 @@ -0,0 +1,155 @@ +/*! +General deata terms of the form $g(Ax-b)$ for an operator $A$ +to a [`Euclidean`] space, and a function g on that space. +*/ + +#![allow(non_snake_case)] + +use super::{DifferentiableImpl, DifferentiableMapping, LipschitzDifferentiableImpl, Mapping}; +use crate::convex::ConvexMapping; +use crate::error::DynResult; +use crate::instance::{ClosedSpace, Instance, Space}; +use crate::linops::{BoundedLinear, Linear, Preadjointable}; +use crate::norms::{Normed, L2}; +use crate::types::Float; +use std::ops::Sub; + +//use serde::{Deserialize, Serialize}; + +/// Functions of the form $g(Ax-b)$ for an operator $A$, data $b$, and fidelity $g$. +pub struct DataTerm< + F: Float, + Domain: Space, + A: Mapping, + G: Mapping, +> { + // The operator A + opA: A, + // The data b + b: >::Codomain, + // The outer fidelity + g: G, +} + +// Derive has troubles with `b`. +impl Clone for DataTerm +where + F: Float, + Domain: Space, + A: Mapping + Clone, + >::Codomain: Clone, + G: Mapping + Clone, +{ + fn clone(&self) -> Self { + DataTerm { opA: self.opA.clone(), b: self.b.clone(), g: self.g.clone() } + } +} + +#[allow(non_snake_case)] +impl, G: Mapping> + DataTerm +{ + pub fn new(opA: A, b: A::Codomain, g: G) -> Self { + DataTerm { opA, b, g } + } + + pub fn operator(&self) -> &'_ A { + &self.opA + } + + pub fn data(&self) -> &'_ >::Codomain { + &self.b + } + + pub fn fidelity(&self) -> &'_ G { + &self.g + } + + /// Returns the residual $Ax-b$. + pub fn residual<'a, 'b>(&'b self, x: &'a Domain) -> >::Codomain + where + &'a Domain: Instance, + >::Codomain: + Sub<&'b >::Codomain, Output = >::Codomain>, + { + self.opA.apply(x) - &self.b + } +} + +//+ AdjointProductBoundedBy, P, FloatType = F>, + +impl Mapping for DataTerm +where + F: Float, + X: Space, + A: Mapping, + G: Mapping, + A::Codomain: ClosedSpace + for<'a> Sub<&'a A::Codomain, Output = A::Codomain>, +{ + type Codomain = F; + + fn apply>(&self, x: I) -> F { + // TODO: possibly (if at all more effcient) use GEMV once generalised + // to not require preallocation. However, Rust should be pretty efficient + // at not doing preallocations or anything here, as the result of self.opA.apply() + // can be consumed, so maybe GEMV is no use. + self.g.apply(self.opA.apply(x) - &self.b) + } +} + +impl ConvexMapping for DataTerm +where + F: Float, + X: Normed, + A: Linear, + G: ConvexMapping, + A::Codomain: ClosedSpace + Normed + for<'a> Sub<&'a A::Codomain, Output = A::Codomain>, +{ +} + +impl DifferentiableImpl for DataTerm +where + F: Float, + X: Space, + Y: Space + Instance + for<'a> Sub<&'a Y, Output = Y>, + A: Linear + Preadjointable, + G::DerivativeDomain: Instance, + A::PreadjointCodomain: ClosedSpace, + G: DifferentiableMapping, + Self: Mapping, +{ + type Derivative = A::PreadjointCodomain; + + fn differential_impl>(&self, x: I) -> Self::Derivative { + // TODO: possibly (if at all more effcient) use GEMV once generalised + // to not require preallocation. However, Rust should be pretty efficient + // at not doing preallocations or anything here, as the result of self.opA.apply() + // can be consumed, so maybe GEMV is no use. + //self.opA.preadjoint().apply(self.opA.apply(x) - self.b) + self.opA + .preadjoint() + .apply(self.g.differential(self.opA.apply(x) - &self.b)) + } + + fn apply_and_differential_impl>(&self, x: I) -> (F, Self::Derivative) { + let j = self.opA.apply(x) - &self.b; + let (v, d) = self.g.apply_and_differential(j); + (v, self.opA.preadjoint().apply(d)) + } +} + +impl<'a, F, X, Y, A, G> LipschitzDifferentiableImpl for DataTerm +where + F: Float, + X: Normed, + Y: Normed, + A: BoundedLinear, + G: Mapping + LipschitzDifferentiableImpl, + Self: DifferentiableImpl, +{ + type FloatType = F; + + fn diff_lipschitz_factor(&self, seminorm: X::NormExp) -> DynResult { + Ok(self.opA.opnorm_bound(seminorm, L2)?.powi(2)) + } +} diff -r 1f19c6bbf07b -r 3868555d135c src/metaprogramming.rs --- a/src/metaprogramming.rs Sun Apr 27 20:29:43 2025 -0500 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,54 +0,0 @@ -/*! -Metaprogramming tools -*/ - -/// Reference `x` if so indicated by the first parameter. -/// Typically to be used from another macro. -/// -/// ```ignore -/// maybe_ref!(ref, V) // ➡ &V -/// maybe_ref!(noref, V) // ➡ V -/// ``` -macro_rules! maybe_ref { - (ref, $x:expr) => { - &$x - }; - (noref, $x:expr) => { - $x - }; - (ref, $x:ty) => { - &$x - }; - (noref, $x:ty) => { - $x - }; -} - -/// Choose `a` if first argument is the literal `ref`, otherwise `b`. -// macro_rules! ifref { -// (noref, $a:expr, $b:expr) => { -// $b -// }; -// (ref, $a:expr, $b:expr) => { -// $a -// }; -// } - -/// Annotate `x` with a lifetime if the first parameter -/// Typically to be used from another macro. -/// -/// ```ignore -/// maybe_ref!(ref, &'a V) // ➡ &'a V -/// maybe_ref!(noref, &'a V) // ➡ V -/// ``` -macro_rules! maybe_lifetime { - (ref, $x:ty) => { - $x - }; - (noref, &$lt:lifetime $x:ty) => { - $x - }; - (noref, &$x:ty) => { - $x - }; -} diff -r 1f19c6bbf07b -r 3868555d135c src/nalgebra_support.rs --- a/src/nalgebra_support.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/nalgebra_support.rs Fri May 15 14:46:30 2026 -0500 @@ -8,107 +8,390 @@ [`num_traits`] does. */ -use nalgebra::{ - Matrix, Storage, StorageMut, OMatrix, Dim, DefaultAllocator, Scalar, - ClosedAddAssign, ClosedMulAssign, SimdComplexField, Vector, OVector, RealField, - LpNorm, UniformNorm -}; -use nalgebra::base::constraint::{ - ShapeConstraint, SameNumberOfRows, SameNumberOfColumns -}; -use nalgebra::base::dimension::*; -use nalgebra::base::allocator::Allocator; -use std::ops::Mul; -use num_traits::identities::{Zero, One}; +use crate::euclidean::*; +use crate::instance::{Decomposition, Instance, MyCow, Ownable, Space}; use crate::linops::*; -use crate::euclidean::*; -use crate::mapping::{Space, BasicDecomposition}; +use crate::norms::*; use crate::types::Float; -use crate::norms::*; -use crate::instance::Instance; +use nalgebra::base::allocator::Allocator; +use nalgebra::base::constraint::{DimEq, SameNumberOfColumns, SameNumberOfRows, ShapeConstraint}; +use nalgebra::base::dimension::*; +use nalgebra::{ + ArrayStorage, ClosedAddAssign, ClosedMulAssign, DefaultAllocator, Dim, LpNorm, Matrix, + MatrixView, OMatrix, OVector, RawStorage, RealField, SimdComplexField, Storage, StorageMut, + UniformNorm, VecStorage, Vector, ViewStorage, ViewStorageMut, U1, +}; +use num_traits::identities::Zero; +use std::ops::Mul; + +impl Ownable for Matrix +where + S: Storage, + M: Dim, + N: Dim, + E: Float, + DefaultAllocator: Allocator, +{ + type OwnedVariant = OMatrix; + + #[inline] + fn into_owned(self) -> Self::OwnedVariant { + Matrix::into_owned(self) + } + + /// Returns an owned instance of a reference. + fn clone_owned(&self) -> Self::OwnedVariant { + Matrix::clone_owned(self) + } -impl Space for Matrix + fn cow_owned<'b>(self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + MyCow::Owned(self.into_owned()) + } + + fn ref_cow_owned<'b>(&'b self) -> MyCow<'b, Self::OwnedVariant> + where + Self: 'b, + { + MyCow::Owned(self.clone_owned()) + } +} + +trait StridesOk>::Buffer>: + DimEq + DimEq where - SM: Storage + Clone, - N : Dim, M : Dim, E : Scalar + Zero + One + ClosedAddAssign + ClosedMulAssign, - DefaultAllocator : Allocator, + S: RawStorage, + E: Float, + N: Dim, + M: Dim, + DefaultAllocator: Allocator, { - type Decomp = BasicDecomposition; +} + +impl StridesOk> for ShapeConstraint +where + M: Dim, + E: Float, + DefaultAllocator: Allocator, +{ } -impl Mapping> for Matrix -where SM: Storage, SV: Storage + Clone, - N : Dim, M : Dim, K : Dim, E : Scalar + Zero + One + ClosedAddAssign + ClosedMulAssign, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator { - type Codomain = OMatrix; +impl StridesOk, Const, ArrayStorage> + for ShapeConstraint +where + E: Float, +{ +} + +macro_rules! strides_ok { + ($R:ty, $C:ty where $($qual:tt)*) => { + impl<'a, E, N, M, $($qual)*> StridesOk> for ShapeConstraint + where + N: Dim, + M: Dim, + E: Float, + DefaultAllocator: Allocator, + { + } + impl<'a, E, N, M, $($qual)*> StridesOk> for ShapeConstraint + where + N: Dim, + M: Dim, + E: Float, + DefaultAllocator: Allocator, + { + } + }; +} + +strides_ok!(Dyn, Dyn where ); +strides_ok!(Dyn, Const where const C : usize); +strides_ok!(Const, Dyn where const R : usize); +strides_ok!(Const, Const where const R : usize, const C : usize); + +impl Space for Matrix +where + SM: Storage, + N: Dim, + M: Dim, + E: Float, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ + type Principal = OMatrix; + type Decomp = MatrixDecomposition; +} + +#[derive(Copy, Clone, Debug)] +pub struct MatrixDecomposition; + +impl Decomposition> for MatrixDecomposition +where + S: Storage, + M: Dim, + K: Dim, + E: Float, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ + type Decomposition<'b> + = MyCow<'b, OMatrix> + where + Matrix: 'b; + type Reference<'b> + = MatrixView<'b, E, M, K, Dyn, Dyn> + where + Matrix: 'b; #[inline] - fn apply>>( - &self, x : I - ) -> Self::Codomain { - x.either(|owned| self.mul(owned), |refr| self.mul(refr)) + fn lift<'b>(r: Self::Reference<'b>) -> Self::Decomposition<'b> + where + S: 'b, + { + MyCow::Owned(r.into_owned()) } } +impl Instance, MatrixDecomposition> for Matrix +where + S1: Storage, + S2: Storage, + M: Dim, + K: Dim, + E: Float, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk + StridesOk, +{ + #[inline] + fn eval_ref<'b, R>( + &'b self, + f: impl FnOnce(>>::Reference<'b>) -> R, + ) -> R + where + Self: 'b, + Matrix: 'b, + { + f(self.as_view::()) + } -impl<'a, SM,SV,N,M,K,E> Linear> for Matrix -where SM: Storage, SV: Storage + Clone, - N : Dim, M : Dim, K : Dim, E : Scalar + Zero + One + ClosedAddAssign + ClosedMulAssign, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator { + #[inline] + fn own(self) -> OMatrix { + self.into_owned() + } + + #[inline] + fn cow<'b>(self) -> MyCow<'b, OMatrix> + where + Self: 'b, + { + self.cow_owned() + } + + #[inline] + fn decompose<'b>(self) -> MyCow<'b, OMatrix> + where + Self: 'b, + { + self.cow_owned() + } +} + +impl<'a, S1, S2, M, K, E> Instance, MatrixDecomposition> + for &'a Matrix +where + S1: Storage, + S2: Storage, + M: Dim, + K: Dim, + E: Float, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk + StridesOk, +{ + fn eval_ref<'b, R>( + &'b self, + f: impl FnOnce(>>::Reference<'b>) -> R, + ) -> R + where + Self: 'b, + Matrix: 'b, + { + f((*self).as_view::()) + } + + #[inline] + fn own(self) -> OMatrix { + self.into_owned() + } + + #[inline] + fn cow<'b>(self) -> MyCow<'b, OMatrix> + where + Self: 'b, + { + self.cow_owned() + } + + #[inline] + fn decompose<'b>(self) -> MyCow<'b, OMatrix> + where + Self: 'b, + { + self.cow_owned() + } } -impl GEMV, Matrix> for Matrix -where SM: Storage, SV1: Storage + Clone, SV2: StorageMut, - N : Dim, M : Dim, K : Dim, E : Scalar + Zero + One + Float, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator { +impl<'a, S1, M, SM, K, E> Instance, MatrixDecomposition> + for MyCow<'a, Matrix> +where + S1: Storage, + SM: Storage, + M: Dim, + K: Dim, + E: Float, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk + StridesOk, +{ + #[inline] + fn eval_ref<'b, R>( + &'b self, + f: impl FnOnce(>>::Reference<'b>) -> R, + ) -> R + where + Self: 'b, + Matrix: 'b, + { + f(self.as_view::()) + } + + #[inline] + fn own(self) -> OMatrix { + self.into_owned() + } + + #[inline] + fn cow<'b>(self) -> MyCow<'b, OMatrix> + where + Self: 'b, + { + self.cow_owned() + } #[inline] - fn gemv>>( - &self, y : &mut Matrix, α : E, x : I, β : E + fn decompose<'b>(self) -> MyCow<'b, OMatrix> + where + Self: 'b, + { + self.cow_owned() + } +} + +impl Mapping> for Matrix +where + SM: Storage, + N: Dim, + M: Dim, + K: Dim, + E: Float, + DefaultAllocator: Allocator + Allocator, + ShapeConstraint: StridesOk + StridesOk, +{ + type Codomain = OMatrix; + + #[inline] + fn apply>>(&self, x: I) -> Self::Codomain { + x.eval_ref(|refr| self.mul(refr)) + } +} + +impl<'a, SM, N, M, K, E> Linear> for Matrix +where + SM: Storage, + N: Dim, + M: Dim, + K: Dim, + E: Float + ClosedMulAssign + ClosedAddAssign, + DefaultAllocator: Allocator + Allocator, + ShapeConstraint: StridesOk + StridesOk, +{ +} + +impl GEMV, Matrix> for Matrix +where + SM: Storage, + SV2: StorageMut, + N: Dim, + M: Dim, + K: Dim, + E: Float, + DefaultAllocator: Allocator + Allocator, + ShapeConstraint: StridesOk + StridesOk, +{ + #[inline] + fn gemv>>( + &self, y: &mut Matrix, α: E, x: I, β: E ) { x.eval(|x̃| Matrix::gemm(y, α, self, x̃, β)) } #[inline] - fn apply_mut<'a, I : Instance>>(&self, y : &mut Matrix, x : I) { + fn apply_mut<'a, I: Instance>>(&self, y: &mut Matrix, x: I) { x.eval(|x̃| self.mul_to(x̃, y)) } } -impl AXPY> for Vector -where SM: StorageMut + Clone, SV1: Storage + Clone, - M : Dim, E : Scalar + Zero + One + Float, - DefaultAllocator : Allocator { - type Owned = OVector; +impl VectorSpace for Matrix +where + S: Storage, + M: Dim, + N: Dim, + E: Float, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ + type Field = E; + type PrincipalV = OMatrix; #[inline] - fn axpy>>(&mut self, α : E, x : I, β : E) { - x.eval(|x̃| Matrix::axpy(self, α, x̃, β)) + fn similar_origin(&self) -> Self::PrincipalV { + let (n, m) = self.shape_generic(); + OMatrix::zeros_generic(n, m) + } +} + +// This can only be implemented for the “principal” OMatrix as parameter, as otherwise +// we run into problems of multiple implementations when calling the methods. +impl AXPY> for Matrix +where + S: StorageMut, + M: Dim, + N: Dim, + E: Float, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ + #[inline] + fn axpy>>(&mut self, α: E, x: I, β: E) { + x.eval(|x̃| { + assert_eq!(self.ncols(), x̃.ncols()); + // nalgebra does not implement axpy for matrices, and flattenining + // also seems difficult, so loop over columns. + for (mut y, ỹ) in self.column_iter_mut().zip(x̃.column_iter()) { + Vector::axpy(&mut y, α, &ỹ, β) + } + }) } #[inline] - fn copy_from>>(&mut self, y : I) { - y.eval(|ỹ| Matrix::copy_from(self, ỹ)) + fn copy_from>>(&mut self, y: I) { + y.eval_ref(|ỹ| Matrix::copy_from(self, &ỹ)) } #[inline] fn set_zero(&mut self) { self.iter_mut().for_each(|e| *e = E::ZERO); } - - #[inline] - fn similar_origin(&self) -> Self::Owned { - OVector::zeros_generic(M::from_usize(self.len()), Const) - } } /* Implemented automatically as Euclidean. @@ -125,26 +408,52 @@ } }*/ -impl Projection for Vector -where SM: StorageMut + Clone, - M : Dim, E : Scalar + Zero + One + Float + RealField, - DefaultAllocator : Allocator { +impl Projection for Vector +where + SM: StorageMut, + M: Dim, + E: Float + RealField, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ #[inline] - fn proj_ball_mut(&mut self, ρ : E, _ : Linfinity) { - self.iter_mut().for_each(|v| *v = num_traits::clamp(*v, -ρ, ρ)) + fn proj_ball(self, ρ: E, exp: Linfinity) -> ::Principal { + let mut owned = self.into_owned(); + owned.proj_ball_mut(ρ, exp); + owned } } -impl<'own,SV1,SV2,SM,N,M,K,E> Adjointable, Matrix> -for Matrix -where SM: Storage, SV1: Storage + Clone, SV2: Storage + Clone, - N : Dim, M : Dim, K : Dim, E : Scalar + Zero + One + SimdComplexField, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator, - DefaultAllocator : Allocator { - type AdjointCodomain = OMatrix; - type Adjoint<'a> = OMatrix where SM : 'a; +impl ProjectionMut for Vector +where + SM: StorageMut, + M: Dim, + E: Float + RealField, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ + #[inline] + fn proj_ball_mut(&mut self, ρ: E, _: Linfinity) { + self.iter_mut() + .for_each(|v| *v = num_traits::clamp(*v, -ρ, ρ)) + } +} + +impl<'own, SM, N, M, K, E> Adjointable, OMatrix> for Matrix +where + SM: Storage, + N: Dim, + M: Dim, + K: Dim, + E: Float + RealField, + DefaultAllocator: Allocator + Allocator + Allocator, + ShapeConstraint: StridesOk + StridesOk, +{ + type AdjointCodomain = OMatrix; + type Adjoint<'a> + = OMatrix + where + SM: 'a; #[inline] fn adjoint(&self) -> Self::Adjoint<'_> { @@ -152,6 +461,25 @@ } } +impl<'own, SM, N, M, K, E> SimplyAdjointable, OMatrix> + for Matrix +where + SM: Storage, + N: Dim, + M: Dim, + K: Dim, + E: Float + RealField, + DefaultAllocator: Allocator + Allocator + Allocator, + ShapeConstraint: StridesOk + StridesOk, +{ + type AdjointCodomain = OMatrix; + type SimpleAdjoint = OMatrix; + + #[inline] + fn adjoint(&self) -> Self::SimpleAdjoint { + Matrix::adjoint(self) + } +} /// This function is [`nalgebra::EuclideanNorm::metric_distance`] without the `sqrt`. #[inline] fn metric_distance_squared( @@ -160,7 +488,7 @@ m2: &Matrix, ) -> T::SimdRealField where - T: SimdComplexField, + T: SimdComplexField, R1: Dim, C1: Dim, S1: Storage, @@ -175,40 +503,41 @@ }) } -// TODO: should allow different input storages in `Euclidean`. - -impl Euclidean -for Vector -where M : Dim, - S : StorageMut + Clone, - E : Float + Scalar + Zero + One + RealField, - DefaultAllocator : Allocator { - - type Output = OVector; +impl Euclidean for Matrix +where + M: Dim, + N: Dim, + S: Storage, + E: Float + RealField, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ + type PrincipalE = OMatrix; #[inline] - fn dot>(&self, other : I) -> E { - Vector::::dot(self, other.ref_instance()) + fn dot>(&self, other: I) -> E { + other.eval_ref(|ref r| Matrix::::dot(self, r)) } #[inline] fn norm2_squared(&self) -> E { - Vector::::norm_squared(self) + Matrix::::norm_squared(self) } #[inline] - fn dist2_squared>(&self, other : I) -> E { - metric_distance_squared(self, other.ref_instance()) + fn dist2_squared>(&self, other: I) -> E { + other.eval_ref(|ref r| metric_distance_squared(self, r)) } } -impl StaticEuclidean -for Vector -where M : DimName, - S : StorageMut + Clone, - E : Float + Scalar + Zero + One + RealField, - DefaultAllocator : Allocator { - +impl StaticEuclidean for Vector +where + M: DimName, + S: Storage, + E: Float + RealField, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ #[inline] fn origin() -> OVector { OVector::zeros() @@ -216,13 +545,15 @@ } /// The default norm for `Vector` is [`L2`]. -impl Normed -for Vector -where M : Dim, - S : Storage + Clone, - E : Float + Scalar + Zero + One + RealField, - DefaultAllocator : Allocator { - +impl Normed for Matrix +where + M: Dim, + N: Dim, + S: Storage, + E: Float + RealField, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ type NormExp = L2; #[inline] @@ -232,92 +563,106 @@ #[inline] fn is_zero(&self) -> bool { - Vector::::norm_squared(self) == E::ZERO + Matrix::::norm_squared(self) == E::ZERO } } -impl HasDual -for Vector -where M : Dim, - S : Storage + Clone, - E : Float + Scalar + Zero + One + RealField, - DefaultAllocator : Allocator { - // TODO: Doesn't work with different storage formats. - type DualSpace = Self; +impl HasDual for Matrix +where + M: Dim, + N: Dim, + S: Storage, + E: Float + RealField, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ + type DualSpace = OMatrix; + + fn dual_origin(&self) -> OMatrix { + let (m, n) = self.shape_generic(); + OMatrix::zeros_generic(m, n) + } } -impl Norm -for Vector -where M : Dim, - S : Storage, - E : Float + Scalar + Zero + One + RealField, - DefaultAllocator : Allocator { - +impl Norm for Vector +where + M: Dim, + S: Storage, + E: Float + RealField, +{ #[inline] - fn norm(&self, _ : L1) -> E { + fn norm(&self, _: L1) -> E { nalgebra::Norm::norm(&LpNorm(1), self) } } -impl Dist -for Vector -where M : Dim, - S : Storage + Clone, - E : Float + Scalar + Zero + One + RealField, - DefaultAllocator : Allocator { +impl Dist for Vector +where + M: Dim, + S: Storage, + E: Float + RealField, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ #[inline] - fn dist>(&self, other : I, _ : L1) -> E { - nalgebra::Norm::metric_distance(&LpNorm(1), self, other.ref_instance()) + fn dist>(&self, other: I, _: L1) -> E { + other.eval_ref(|ref r| nalgebra::Norm::metric_distance(&LpNorm(1), self, r)) } } -impl Norm -for Vector -where M : Dim, - S : Storage, - E : Float + Scalar + Zero + One + RealField, - DefaultAllocator : Allocator { - +impl Norm for Matrix +where + M: Dim, + N: Dim, + S: Storage, + E: Float + RealField, +{ #[inline] - fn norm(&self, _ : L2) -> E { + fn norm(&self, _: L2) -> E { nalgebra::Norm::norm(&LpNorm(2), self) } } -impl Dist -for Vector -where M : Dim, - S : Storage + Clone, - E : Float + Scalar + Zero + One + RealField, - DefaultAllocator : Allocator { +impl Dist for Matrix +where + M: Dim, + N: Dim, + S: Storage, + E: Float + RealField, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ #[inline] - fn dist>(&self, other : I, _ : L2) -> E { - nalgebra::Norm::metric_distance(&LpNorm(2), self, other.ref_instance()) + fn dist>(&self, other: I, _: L2) -> E { + other.eval_ref(|ref r| nalgebra::Norm::metric_distance(&LpNorm(2), self, r)) } } -impl Norm -for Vector -where M : Dim, - S : Storage, - E : Float + Scalar + Zero + One + RealField, - DefaultAllocator : Allocator { - +impl Norm for Matrix +where + M: Dim, + N: Dim, + S: Storage, + E: Float + RealField, +{ #[inline] - fn norm(&self, _ : Linfinity) -> E { + fn norm(&self, _: Linfinity) -> E { nalgebra::Norm::norm(&UniformNorm, self) } } -impl Dist -for Vector -where M : Dim, - S : Storage + Clone, - E : Float + Scalar + Zero + One + RealField, - DefaultAllocator : Allocator { +impl Dist for Matrix +where + M: Dim, + N: Dim, + S: Storage, + E: Float + RealField, + DefaultAllocator: Allocator, + ShapeConstraint: StridesOk, +{ #[inline] - fn dist>(&self, other : I, _ : Linfinity) -> E { - nalgebra::Norm::metric_distance(&UniformNorm, self, other.ref_instance()) + fn dist>(&self, other: I, _: Linfinity) -> E { + other.eval_ref(|ref r| nalgebra::Norm::metric_distance(&UniformNorm, self, r)) } } @@ -329,15 +674,15 @@ /// from [`nalgebra`] conflicting with them. Only when absolutely necessary to work with /// nalgebra, one can convert to the nalgebra view of the same type using the methods of /// this trait. -pub trait ToNalgebraRealField : Float { +pub trait ToNalgebraRealField: Float { /// The nalgebra type corresponding to this type. Usually same as `Self`. /// /// This type only carries `nalgebra` traits. - type NalgebraType : RealField; + type NalgebraType: RealField; /// The “mixed” type corresponding to this type. Usually same as `Self`. /// /// This type carries both `num_traits` and `nalgebra` traits. - type MixedType : RealField + Float; + type MixedType: RealField + Float; /// Convert to the nalgebra view of `self`. fn to_nalgebra(self) -> Self::NalgebraType; @@ -346,10 +691,10 @@ fn to_nalgebra_mixed(self) -> Self::MixedType; /// Convert from the nalgebra view of `self`. - fn from_nalgebra(t : Self::NalgebraType) -> Self; + fn from_nalgebra(t: Self::NalgebraType) -> Self; /// Convert from the mixed (nalgebra and num_traits) view to `self`. - fn from_nalgebra_mixed(t : Self::MixedType) -> Self; + fn from_nalgebra_mixed(t: Self::MixedType) -> Self; } impl ToNalgebraRealField for f32 { @@ -357,17 +702,24 @@ type MixedType = f32; #[inline] - fn to_nalgebra(self) -> Self::NalgebraType { self } - - #[inline] - fn to_nalgebra_mixed(self) -> Self::MixedType { self } + fn to_nalgebra(self) -> Self::NalgebraType { + self + } #[inline] - fn from_nalgebra(t : Self::NalgebraType) -> Self { t } + fn to_nalgebra_mixed(self) -> Self::MixedType { + self + } #[inline] - fn from_nalgebra_mixed(t : Self::MixedType) -> Self { t } + fn from_nalgebra(t: Self::NalgebraType) -> Self { + t + } + #[inline] + fn from_nalgebra_mixed(t: Self::MixedType) -> Self { + t + } } impl ToNalgebraRealField for f64 { @@ -375,15 +727,22 @@ type MixedType = f64; #[inline] - fn to_nalgebra(self) -> Self::NalgebraType { self } + fn to_nalgebra(self) -> Self::NalgebraType { + self + } #[inline] - fn to_nalgebra_mixed(self) -> Self::MixedType { self } + fn to_nalgebra_mixed(self) -> Self::MixedType { + self + } #[inline] - fn from_nalgebra(t : Self::NalgebraType) -> Self { t } + fn from_nalgebra(t: Self::NalgebraType) -> Self { + t + } #[inline] - fn from_nalgebra_mixed(t : Self::MixedType) -> Self { t } + fn from_nalgebra_mixed(t: Self::MixedType) -> Self { + t + } } - diff -r 1f19c6bbf07b -r 3868555d135c src/norms.rs --- a/src/norms.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/norms.rs Fri May 15 14:46:30 2026 -0500 @@ -2,79 +2,84 @@ Norms, projections, etc. */ -use serde::Serialize; -use std::marker::PhantomData; +use crate::euclidean::*; +use crate::instance::Ownable; +use crate::linops::{ClosedVectorSpace, VectorSpace}; +use crate::mapping::{Instance, Mapping, Space}; use crate::types::*; -use crate::euclidean::*; -use crate::mapping::{Mapping, Space, Instance}; +use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; // // Abstract norms // -#[derive(Copy,Clone,Debug)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] /// Helper structure to convert a [`NormExponent`] into a [`Mapping`] -pub struct NormMapping{ - pub(crate) exponent : E, - _phantoms : PhantomData +pub struct NormMapping { + pub(crate) exponent: E, + _phantoms: PhantomData, } /// An exponent for norms. /// // Just a collection of desirable attributes for a marker type -pub trait NormExponent : Copy + Send + Sync + 'static { +pub trait NormExponent: Copy { /// Return the norm as a mappin - fn as_mapping(self) -> NormMapping { - NormMapping{ exponent : self, _phantoms : PhantomData } + fn as_mapping(self) -> NormMapping { + NormMapping { exponent: self, _phantoms: PhantomData } } } /// Exponent type for the 1-[`Norm`]. -#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)] +#[derive(Copy, Debug, Clone, Serialize, Eq, PartialEq)] pub struct L1; impl NormExponent for L1 {} /// Exponent type for the 2-[`Norm`]. -#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)] +#[derive(Copy, Debug, Clone, Serialize, Eq, PartialEq)] pub struct L2; impl NormExponent for L2 {} /// Exponent type for the ∞-[`Norm`]. -#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)] +#[derive(Copy, Debug, Clone, Serialize, Eq, PartialEq)] pub struct Linfinity; impl NormExponent for Linfinity {} /// Exponent type for 2,1-[`Norm`]. /// (1-norm over a domain Ω, 2-norm of a vector at each point of the domain.) -#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)] +#[derive(Copy, Debug, Clone, Serialize, Eq, PartialEq)] pub struct L21; impl NormExponent for L21 {} /// Norms for pairs (a, b). ‖(a,b)‖ = ‖(‖a‖_A, ‖b‖_B)‖_J /// For use with [`crate::direct_product::Pair`] -#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)] +#[derive(Copy, Debug, Clone, Serialize, Eq, PartialEq)] pub struct PairNorm(pub A, pub B, pub J); impl NormExponent for PairNorm -where A : NormExponent, B : NormExponent, J : NormExponent {} - +where + A: NormExponent, + B: NormExponent, + J: NormExponent, +{ +} /// A Huber/Moreau–Yosida smoothed [`L1`] norm. (Not a norm itself.) /// /// The parameter γ of this type is the smoothing factor. Zero means no smoothing, and higher /// values more smoothing. Behaviour with γ < 0 is undefined. -#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)] -pub struct HuberL1(pub F); -impl NormExponent for HuberL1 {} +#[derive(Copy, Debug, Clone, Serialize, Eq, PartialEq)] +pub struct HuberL1(pub F); +impl NormExponent for HuberL1 {} /// A Huber/Moreau–Yosida smoothed [`L21`] norm. (Not a norm itself.) /// /// The parameter γ of this type is the smoothing factor. Zero means no smoothing, and higher /// values more smoothing. Behaviour with γ < 0 is undefined. -#[derive(Copy,Debug,Clone,Serialize,Eq,PartialEq)] -pub struct HuberL21(pub F); -impl NormExponent for HuberL21 {} - +#[derive(Copy, Debug, Clone, Serialize, Eq, PartialEq)] +pub struct HuberL21(pub F); +impl NormExponent for HuberL21 {} /// A normed space (type) with exponent or other type `Exponent` for the norm. /// @@ -86,27 +91,27 @@ /// /// println!("{}, {} {}", x.norm(L1), x.norm(L2), x.norm(Linfinity)) /// ``` -pub trait Norm { +pub trait Norm { /// Calculate the norm. - fn norm(&self, _p : Exponent) -> F; + fn norm(&self, _p: Exponent) -> F; } /// Indicates that the `Self`-[`Norm`] is dominated by the `Exponent`-`Norm` on the space /// `Elem` with the corresponding field `F`. -pub trait Dominated { +pub trait Dominated { /// Indicates the factor $c$ for the inequality $‖x‖ ≤ C ‖x‖_p$. - fn norm_factor(&self, p : Exponent) -> F; + fn norm_factor(&self, p: Exponent) -> F; /// Given a norm-value $‖x‖_p$, calculates $C‖x‖_p$ such that $‖x‖ ≤ C‖x‖_p$ #[inline] - fn from_norm(&self, p_norm : F, p : Exponent) -> F { + fn from_norm(&self, p_norm: F, p: Exponent) -> F { p_norm * self.norm_factor(p) } } /// Trait for distances with respect to a norm. -pub trait Dist : Norm + Space { +pub trait Dist: Norm + Space { /// Calculate the distance - fn dist>(&self, other : I, _p : Exponent) -> F; + fn dist>(&self, other: I, _p: Exponent) -> F; } /// Trait for Euclidean projections to the `Exponent`-[`Norm`]-ball. @@ -119,44 +124,48 @@ /// /// println!("{:?}, {:?}", x.proj_ball(1.0, L2), x.proj_ball(0.5, Linfinity)); /// ``` -pub trait Projection : Norm + Sized -where F : Float { +pub trait Projection: Ownable + Norm { /// Projection of `self` to the `q`-norm-ball of radius ρ. - fn proj_ball(mut self, ρ : F, q : Exponent) -> Self { - self.proj_ball_mut(ρ, q); - self - } - - /// In-place projection of `self` to the `q`-norm-ball of radius ρ. - fn proj_ball_mut(&mut self, ρ : F, q : Exponent); + fn proj_ball(self, ρ: F, q: Exponent) -> Self::OwnedVariant; } -/*impl> Norm for E { +pub trait ProjectionMut: Projection { + /// In-place projection of `self` to the `q`-norm-ball of radius ρ. + fn proj_ball_mut(&mut self, ρ: F, q: Exponent); +} + +/*impl> Norm for E { #[inline] fn norm(&self, _p : L2) -> F { self.norm2() } fn dist(&self, other : &Self, _p : L2) -> F { self.dist2(other) } }*/ -impl + Norm> Projection for E { +impl + Norm> Projection for E { #[inline] - fn proj_ball(self, ρ : F, _p : L2) -> Self { self.proj_ball2(ρ) } - - #[inline] - fn proj_ball_mut(&mut self, ρ : F, _p : L2) { self.proj_ball2_mut(ρ) } + fn proj_ball(self, ρ: F, _p: L2) -> Self::OwnedVariant { + self.proj_ball2(ρ) + } } -impl HuberL1 { - fn apply(self, xnsq : F) -> F { +impl + Norm> ProjectionMut for E { + #[inline] + fn proj_ball_mut(&mut self, ρ: F, _p: L2) { + self.proj_ball2_mut(ρ) + } +} + +impl HuberL1 { + fn apply(self, xnsq: F) -> F { let HuberL1(γ) = self; let xn = xnsq.sqrt(); if γ == F::ZERO { xn } else { if xn > γ { - xn-γ / F::TWO - } else if xn<(-γ) { - -xn-γ / F::TWO + xn - γ / F::TWO + } else if xn < (-γ) { + -xn - γ / F::TWO } else { xnsq / (F::TWO * γ) } @@ -164,25 +173,25 @@ } } -impl> Norm> for E { - fn norm(&self, huber : HuberL1) -> F { +impl + Normed> Norm, F> for E { + fn norm(&self, huber: HuberL1) -> F { huber.apply(self.norm2_squared()) } } -impl> Dist> for E { - fn dist>(&self, other : I, huber : HuberL1) -> F { +impl + Normed> Dist, F> for E { + fn dist>(&self, other: I, huber: HuberL1) -> F { huber.apply(self.dist2_squared(other)) } } -// impl> Norm for Vec { +// impl> Norm for Vec { // fn norm(&self, _l21 : L21) -> F { // self.iter().map(|e| e.norm(L2)).sum() // } // } -// impl> Dist for Vec { +// impl> Dist for Vec { // fn dist>(&self, other : I, _l21 : L21) -> F { // self.iter().zip(other.iter()).map(|(e, g)| e.dist(g, L2)).sum() // } @@ -190,20 +199,21 @@ impl Mapping for NormMapping where - F : Float, - E : NormExponent, - Domain : Space + Norm, + F: Float, + E: NormExponent, + Domain: Space, + Domain::Principal: Norm, { type Codomain = F; #[inline] - fn apply>(&self, x : I) -> F { + fn apply>(&self, x: I) -> F { x.eval(|r| r.norm(self.exponent)) } } -pub trait Normed : Space + Norm { - type NormExp : NormExponent; +pub trait Normed: Space + Norm { + type NormExp: NormExponent; fn norm_exponent(&self) -> Self::NormExp; @@ -214,33 +224,38 @@ // fn similar_origin(&self) -> Self; - fn is_zero(&self) -> bool; + fn is_zero(&self) -> bool { + self.norm_() == F::ZERO + } } -pub trait HasDual : Normed { - type DualSpace : Normed; +pub trait HasDual: Normed + VectorSpace { + type DualSpace: Normed + ClosedVectorSpace; + + fn dual_origin(&self) -> ::PrincipalV; } /// Automatically implemented trait for reflexive spaces -pub trait Reflexive : HasDual +pub trait Reflexive: HasDual where - Self::DualSpace : HasDual -{ } + Self::DualSpace: HasDual, +{ +} -impl> Reflexive for X -where - X::DualSpace : HasDual -{ } +impl> Reflexive for X where + X::DualSpace: HasDual +{ +} -pub trait HasDualExponent : NormExponent { - type DualExp : NormExponent; +pub trait HasDualExponent: NormExponent { + type DualExp: NormExponent; fn dual_exponent(&self) -> Self::DualExp; } impl HasDualExponent for L2 { type DualExp = L2; - + #[inline] fn dual_exponent(&self) -> Self::DualExp { L2 @@ -249,17 +264,16 @@ impl HasDualExponent for L1 { type DualExp = Linfinity; - + #[inline] fn dual_exponent(&self) -> Self::DualExp { Linfinity } } - impl HasDualExponent for Linfinity { type DualExp = L1; - + #[inline] fn dual_exponent(&self) -> Self::DualExp { L1 @@ -271,49 +285,50 @@ ($exponent : ty) => { impl Norm> for D where - F : Float, - D : Norm, - C : Constant, + F: Float, + D: Norm<$exponent, F>, + C: Constant, { - fn norm(&self, e : Weighted<$exponent, C>) -> F { + fn norm(&self, e: Weighted<$exponent, C>) -> F { let v = e.weight.value(); assert!(v > F::ZERO); v * self.norm(e.base_fn) } } - impl NormExponent for Weighted<$exponent, C> {} + impl NormExponent for Weighted<$exponent, C> {} - impl HasDualExponent for Weighted<$exponent, C> - where $exponent : HasDualExponent { + impl HasDualExponent for Weighted<$exponent, C> + where + $exponent: HasDualExponent, + { type DualExp = Weighted<<$exponent as HasDualExponent>::DualExp, C::Type>; fn dual_exponent(&self) -> Self::DualExp { Weighted { - weight : C::Type::ONE / self.weight.value(), - base_fn : self.base_fn.dual_exponent() + weight: C::Type::ONE / self.weight.value(), + base_fn: self.base_fn.dual_exponent(), } } } - impl Projection> for T + impl Projection> for T where - T : Projection, - F : Float, - C : Constant, + T: Projection, + F: Float, + C: Constant, { - fn proj_ball(self, ρ : F, q : Weighted<$exponent , C>) -> Self { + fn proj_ball(self, ρ: F, q: Weighted<$exponent, C>) -> Self { self.proj_ball(ρ / q.weight.value(), q.base_fn) } - fn proj_ball_mut(&mut self, ρ : F, q : Weighted<$exponent , C>) { + fn proj_ball_mut(&mut self, ρ: F, q: Weighted<$exponent, C>) { self.proj_ball_mut(ρ / q.weight.value(), q.base_fn) } } - } + }; } //impl_weighted_norm!(L1); //impl_weighted_norm!(L2); //impl_weighted_norm!(Linfinity); - diff -r 1f19c6bbf07b -r 3868555d135c src/operator_arithmetic.rs --- a/src/operator_arithmetic.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/operator_arithmetic.rs Fri May 15 14:46:30 2026 -0500 @@ -2,72 +2,74 @@ Arithmetic of [`Mapping`]s. */ -use serde::Serialize; +use crate::instance::{ClosedSpace, Instance, Space}; +use crate::mapping::{DifferentiableImpl, DifferentiableMapping, Mapping}; use crate::types::*; -use crate::instance::{Space, Instance}; -use crate::mapping::{Mapping, DifferentiableImpl, DifferentiableMapping}; +use serde::Serialize; /// A trait for encoding constant [`Float`] values -pub trait Constant : Copy + Sync + Send + 'static + std::fmt::Debug + Into { +pub trait Constant: Copy + Sync + Send + 'static + std::fmt::Debug + Into { /// The type of the value - type Type : Float; + type Type: Float; /// Returns the value of the constant fn value(&self) -> Self::Type; } -impl Constant for F { +impl Constant for F { type Type = F; #[inline] - fn value(&self) -> F { *self } + fn value(&self) -> F { + *self + } } /// Weighting of a [`Mapping`] by scalar multiplication. -#[derive(Copy,Clone,Debug,Serialize)] -pub struct Weighted { +#[derive(Copy, Clone, Debug, Serialize)] +pub struct Weighted { /// The weight - pub weight : C, + pub weight: C, /// The base [`Mapping`] being weighted. - pub base_fn : T, + pub base_fn: T, } impl Weighted where - C : Constant, + C: Constant, { /// Construct from an iterator. - pub fn new(weight : C, base_fn : T) -> Self { - Weighted{ weight, base_fn } + pub fn new(weight: C, base_fn: T) -> Self { + Weighted { weight, base_fn } } } -impl<'a, T, V, D, F, C> Mapping for Weighted +impl<'a, T, D, F, C> Mapping for Weighted where - F : Float, - D : Space, - T : Mapping, - V : Space + ClosedMul, - C : Constant + F: Float, + D: Space, + T: Mapping, + T::Codomain: ClosedMul, + C: Constant, { - type Codomain = V; + type Codomain = T::Codomain; #[inline] - fn apply>(&self, x : I) -> Self::Codomain { + fn apply>(&self, x: I) -> Self::Codomain { self.base_fn.apply(x) * self.weight.value() } } impl<'a, T, V, D, F, C> DifferentiableImpl for Weighted where - F : Float, - D : Space, - T : DifferentiableMapping, - V : Space + std::ops::Mul, - C : Constant + F: Float, + D: Space, + T: DifferentiableMapping, + V: ClosedSpace + std::ops::Mul, + C: Constant, { type Derivative = V; #[inline] - fn differential_impl>(&self, x : I) -> Self::Derivative { + fn differential_impl>(&self, x: I) -> Self::Derivative { self.base_fn.differential(x) * self.weight.value() } } @@ -76,9 +78,9 @@ #[derive(Serialize, Debug, Clone)] pub struct MappingSum(Vec); -impl< M> MappingSum { +impl MappingSum { /// Construct from an iterator. - pub fn new>(iter : I) -> Self { + pub fn new>(iter: I) -> Self { MappingSum(iter.into_iter().collect()) } @@ -90,28 +92,26 @@ impl Mapping for MappingSum where - Domain : Space + Clone, - M : Mapping, - M::Codomain : std::iter::Sum + Clone + Domain: Space + Clone, + M: Mapping, + M::Codomain: std::iter::Sum + Clone, { type Codomain = M::Codomain; - fn apply>(&self, x : I) -> Self::Codomain { - let xr = x.ref_instance(); - self.0.iter().map(|c| c.apply(xr)).sum() + fn apply>(&self, x: I) -> Self::Codomain { + x.eval_ref(|xr| self.0.iter().map(|c| c.apply(xr)).sum()) } } -impl DifferentiableImpl for MappingSum< M> +impl DifferentiableImpl for MappingSum where - Domain : Space + Clone, - M : DifferentiableMapping, - M :: DerivativeDomain : std::iter::Sum + Domain: Space, + M: DifferentiableMapping, + M::DerivativeDomain: std::iter::Sum, { type Derivative = M::DerivativeDomain; - fn differential_impl>(&self, x : I) -> Self::Derivative { - let xr = x.ref_instance(); - self.0.iter().map(|c| c.differential(xr)).sum() + fn differential_impl>(&self, x: I) -> Self::Derivative { + x.eval_ref(|xr| self.0.iter().map(|c| c.differential(xr)).sum()) } } diff -r 1f19c6bbf07b -r 3868555d135c src/parallelism.rs --- a/src/parallelism.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/parallelism.rs Fri May 15 14:46:30 2026 -0500 @@ -7,24 +7,24 @@ For actually spawning scoped tasks in a thread pool, it currently uses [`rayon`]. */ -use std::sync::Once; +pub use rayon::{Scope, ThreadPool, ThreadPoolBuilder}; use std::num::NonZeroUsize; -use std::thread::available_parallelism; -pub use rayon::{Scope, ThreadPoolBuilder, ThreadPool}; use std::sync::atomic::{ AtomicUsize, - Ordering::{Release, Relaxed}, + Ordering::{Relaxed, Release}, }; +use std::sync::Once; +use std::thread::available_parallelism; #[cfg(feature = "use_custom_thread_pool")] type Pool = ThreadPool; #[cfg(not(feature = "use_custom_thread_pool"))] type Pool = GlobalPool; -const ONE : NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(1) }; -static mut TASK_OVERBUDGETING : AtomicUsize = AtomicUsize::new(1); -static mut N_THREADS : NonZeroUsize = ONE; -static mut POOL : Option = None; +const ONE: NonZeroUsize = unsafe { NonZeroUsize::new_unchecked(1) }; +static mut TASK_OVERBUDGETING: AtomicUsize = AtomicUsize::new(1); +static mut N_THREADS: NonZeroUsize = ONE; +static mut POOL: Option = None; static INIT: Once = Once::new(); #[cfg(not(feature = "use_custom_thread_pool"))] @@ -38,7 +38,8 @@ pub fn scope<'scope, OP, R>(&self, op: OP) -> R where OP: FnOnce(&rayon::Scope<'scope>) -> R + Send, - R: Send { + R: Send, + { rayon::scope(op) } } @@ -51,17 +52,22 @@ /// /// This routine can only be called once. /// It also calls [`set_task_overbudgeting`] with $m = (n + 1) / 2$. -pub fn set_num_threads(n : NonZeroUsize) { +pub fn set_num_threads(n: NonZeroUsize) { INIT.call_once(|| unsafe { N_THREADS = n; let n = n.get(); set_task_overbudgeting((n + 1) / 2); POOL = if n > 1 { - #[cfg(feature = "use_custom_thread_pool")] { + #[cfg(feature = "use_custom_thread_pool")] + { Some(ThreadPoolBuilder::new().num_threads(n).build().unwrap()) } - #[cfg(not(feature = "use_custom_thread_pool"))] { - ThreadPoolBuilder::new().num_threads(n).build_global().unwrap(); + #[cfg(not(feature = "use_custom_thread_pool"))] + { + ThreadPoolBuilder::new() + .num_threads(n) + .build_global() + .unwrap(); Some(GlobalPool) } } else { @@ -74,20 +80,21 @@ /// /// The initial value is 1. Calling [`set_num_threads`] sets this to $m = (n + 1) / 2$, where /// $n$ is the number of threads. -pub fn set_task_overbudgeting(m : usize) { +pub fn set_task_overbudgeting(m: usize) { #[allow(static_mut_refs)] - unsafe { TASK_OVERBUDGETING.store(m, Relaxed) } + unsafe { + TASK_OVERBUDGETING.store(m, Relaxed) + } } /// Set the number of threads to the minimum of `n` and [`available_parallelism`]. /// /// This routine can only be called once. -pub fn set_max_threads(n : NonZeroUsize) { +pub fn set_max_threads(n: NonZeroUsize) { let available = available_parallelism().unwrap_or(ONE); set_num_threads(available.min(n)); } - /// Get the number of threads pub fn num_threads() -> NonZeroUsize { unsafe { N_THREADS } @@ -99,7 +106,9 @@ /// The pool has [`num_threads`]` - 1` threads. pub fn thread_pool() -> Option<&'static Pool> { #[allow(static_mut_refs)] - unsafe { POOL.as_ref() } + unsafe { + POOL.as_ref() + } } /// Get the number of thread pool workers. @@ -119,25 +128,25 @@ /// Initial multi-threaded state MultiThreadedInitial { /// Thread budget counter - budget : AtomicUsize, + budget: AtomicUsize, /// Thread pool - pool : &'scheduler Pool, + pool: &'scheduler Pool, }, /// Nested multi-threaded state MultiThreadedZoom { /// Thread budget reference - budget : &'scope AtomicUsize, - scope : &'scheduler Scope<'scope>, - } + budget: &'scope AtomicUsize, + scope: &'scheduler Scope<'scope>, + }, } /// Task execution scope for [`TaskBudget`]. pub enum TaskBudgetScope<'scope, 'scheduler> { SingleThreaded, MultiThreaded { - budget : &'scope AtomicUsize, - scope : &'scheduler Scope<'scope>, - } + budget: &'scope AtomicUsize, + scope: &'scheduler Scope<'scope>, + }, } impl<'scope, 'b> TaskBudget<'scope, 'b> { @@ -146,7 +155,7 @@ /// The number of tasks [executed][TaskBudgetScope::execute] in [scopes][TaskBudget::zoom] /// created through the budget is limited to [`num_threads()`]` + overbudget`. If `overbudget` /// is `None`, the [global setting][set_task_overbudgeting] is used.§ - pub fn init(overbudget : Option) -> Self { + pub fn init(overbudget: Option) -> Self { let n = num_threads().get(); #[allow(static_mut_refs)] let m = overbudget.unwrap_or_else(|| unsafe { TASK_OVERBUDGETING.load(Relaxed) }); @@ -161,14 +170,18 @@ } /// Initialise single-threaded thread budgeting. - pub fn none() -> Self { Self::SingleThreaded } + pub fn none() -> Self { + Self::SingleThreaded + } } impl<'scope, 'scheduler> TaskBudget<'scope, 'scheduler> { /// Create a sub-scope for launching tasks - pub fn zoom<'smaller, F, R : Send>(&self, scheduler : F) -> R - where 'scope : 'smaller, - F : for<'a> FnOnce(TaskBudgetScope<'smaller, 'a>) -> R + Send + 'smaller { + pub fn zoom<'smaller, F, R: Send>(&self, scheduler: F) -> R + where + 'scope: 'smaller, + F: for<'a> FnOnce(TaskBudgetScope<'smaller, 'a>) -> R + Send + 'smaller, + { match self { &Self::SingleThreaded => scheduler(TaskBudgetScope::SingleThreaded), &Self::MultiThreadedInitial { ref budget, pool } => { @@ -191,15 +204,16 @@ impl<'scope, 'scheduler> TaskBudgetScope<'scope, 'scheduler> { /// Queue a task or execute it in this thread if the thread budget is exhausted. - pub fn execute(&self, job : F) - where F : for<'b> FnOnce(TaskBudget<'scope, 'b>) + Send + 'scope { + pub fn execute(&self, job: F) + where + F: for<'b> FnOnce(TaskBudget<'scope, 'b>) + Send + 'scope, + { match self { Self::SingleThreaded => job(TaskBudget::SingleThreaded), Self::MultiThreaded { scope, budget } => { - let spawn = budget.fetch_update(Release, - Relaxed, - |n| (n > 1).then_some(n - 1)) - .is_ok(); + let spawn = budget + .fetch_update(Release, Relaxed, |n| (n > 1).then_some(n - 1)) + .is_ok(); if spawn { scope.spawn(|scope| { let task_budget = TaskBudget::MultiThreadedZoom { scope, budget }; @@ -216,8 +230,10 @@ /// Runs `scheduler` with a [`TaskBudget`]. /// -/// This corresponds to calling `scheduler` with [`TaskBudget::init(None)`]. -pub fn with_task_budget<'scope, F, R>(scheduler : F) -> R -where F : for<'b> FnOnce(TaskBudget<'scope, 'b>) -> R + 'scope { +/// This corresponds to calling `scheduler` with [`TaskBudget::init`]`(None)`. +pub fn with_task_budget<'scope, F, R>(scheduler: F) -> R +where + F: for<'b> FnOnce(TaskBudget<'scope, 'b>) -> R + 'scope, +{ scheduler(TaskBudget::init(None)) } diff -r 1f19c6bbf07b -r 3868555d135c src/sets.rs --- a/src/sets.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/sets.rs Fri May 15 14:46:30 2026 -0500 @@ -2,51 +2,57 @@ This module provides various sets and traits for them. */ -use std::ops::{RangeFull,RangeFrom,Range,RangeInclusive,RangeTo,RangeToInclusive}; -use crate::types::*; +use crate::euclidean::Euclidean; +use crate::instance::{BasicDecomposition, Instance, Space}; use crate::loc::Loc; -use crate::euclidean::Euclidean; -use crate::instance::{Space, Instance}; +use crate::types::*; use serde::Serialize; +use std::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive}; pub mod cube; pub use cube::Cube; /// Trait for arbitrary sets. The parameter `U` is the element type. -pub trait Set where U : Space { +pub trait Set +where + U: Space, +{ /// Check for element containment - fn contains>(&self, item : I) -> bool; + fn contains>(&self, item: I) -> bool; } /// Additional ordering (besides [`PartialOrd`]) of a subfamily of sets: /// greatest lower bound and least upper bound. -pub trait SetOrd : Sized { - /// Returns the smallest set of same class contain both parameters. - fn common(&self, other : &Self) -> Self; +pub trait SetOrd: Sized { + /// Returns the smallest set of same class contain both parameters. + fn common(&self, other: &Self) -> Self; - /// Returns the greatest set of same class contaied by n both parameter sets. - fn intersect(&self, other : &Self) -> Option; + /// Returns the greatest set of same class contaied by n both parameter sets. + fn intersect(&self, other: &Self) -> Option; } -impl Set> -for Cube -where U : Num + PartialOrd + Sized { - fn contains>>(&self, item : I) -> bool { - self.0.iter().zip(item.ref_instance().iter()).all(|(s, x)| s.contains(x)) +impl Set> for Cube +where + U: Num + PartialOrd + Sized, +{ + fn contains>>(&self, item: I) -> bool { + item.eval_ref(|r| self.0.iter().zip(r.iter()).all(|(s, x)| s.contains(x))) } } -impl Set for RangeFull { - fn contains>(&self, _item : I) -> bool { true } +impl Set for RangeFull { + fn contains>(&self, _item: I) -> bool { + true + } } macro_rules! impl_ranges { ($($range:ident),*) => { $( impl Set for $range where - Idx : PartialOrd, - U : PartialOrd + Space, - Idx : PartialOrd + U : Space, + U::Principal : PartialOrd, + Idx : PartialOrd + PartialOrd, { #[inline] fn contains>(&self, item : I) -> bool { @@ -56,45 +62,56 @@ )* } } -impl_ranges!(RangeFrom,Range,RangeInclusive,RangeTo,RangeToInclusive); +impl_ranges!(RangeFrom, Range, RangeInclusive, RangeTo, RangeToInclusive); /// Halfspaces described by an orthogonal vector and an offset. /// /// The halfspace is $H = \\{ t v + a \mid a^⊤ v = 0 \\}$, where $v$ is the orthogonal /// vector and $t$ the offset. -#[derive(Clone,Copy,Debug,Serialize,Eq,PartialEq)] -pub struct Halfspace where A : Euclidean, F : Float { - pub orthogonal : A, - pub offset : F, +#[derive(Clone, Copy, Debug, Serialize, Eq, PartialEq)] +pub struct Halfspace +where + A: Euclidean, + F: Float, +{ + pub orthogonal: A, + pub offset: F, } -impl Halfspace where A : Euclidean, F : Float { +impl Halfspace +where + A: Euclidean, + F: Float, +{ #[inline] - pub fn new(orthogonal : A, offset : F) -> Self { - Halfspace{ orthogonal : orthogonal, offset : offset } + pub fn new(orthogonal: A, offset: F) -> Self { + Halfspace { orthogonal: orthogonal, offset: offset } } } /// Trait for generating a halfspace spanned by another set `Self` of elements of type `U`. -pub trait SpannedHalfspace where F : Float { +pub trait SpannedHalfspace +where + F: Float, +{ /// Type of the orthogonal vector describing the halfspace. - type A : Euclidean; + type A: Euclidean; /// Returns the halfspace spanned by this set. fn spanned_halfspace(&self) -> Halfspace; } // TODO: Gram-Schmidt for higher N. -impl SpannedHalfspace for [Loc; 2] { - type A = Loc; +impl SpannedHalfspace for [Loc<1, F>; 2] { + type A = Loc<1, F>; fn spanned_halfspace(&self) -> Halfspace { let (x0, x1) = (self[0], self[1]); - Halfspace::new(x1-x0, x0[0]) + Halfspace::new(x1 - x0, x0[0]) } } // TODO: Gram-Schmidt for higher N. -impl SpannedHalfspace for [Loc; 2] { - type A = Loc; +impl SpannedHalfspace for [Loc<2, F>; 2] { + type A = Loc<2, F>; fn spanned_halfspace(&self) -> Halfspace { let (x0, x1) = (&self[0], &self[1]); let d = x1 - x0; @@ -104,29 +121,30 @@ } } -impl Set for Halfspace +impl Set for Halfspace where - A : Euclidean, - F : Float, + A: Euclidean, + F: Float, { #[inline] - fn contains>(&self, item : I) -> bool { + fn contains>(&self, item: I) -> bool { self.orthogonal.dot(item) >= self.offset } } /// Polygons defined by `N` `Halfspace`s. -#[derive(Clone,Copy,Debug,Eq,PartialEq)] -pub struct NPolygon(pub [Halfspace; N]) -where A : Euclidean, F : Float; - -impl Set for NPolygon +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct NPolygon(pub [Halfspace; N]) where - A : Euclidean, - F : Float, + A: Euclidean, + F: Float; + +impl Set for NPolygon +where + A: Euclidean, + F: Float, { - fn contains>(&self, item : I) -> bool { - let r = item.ref_instance(); - self.0.iter().all(|halfspace| halfspace.contains(r)) + fn contains>(&self, item: I) -> bool { + item.eval_ref(|r| self.0.iter().all(|halfspace| halfspace.contains(r))) } } diff -r 1f19c6bbf07b -r 3868555d135c src/sets/cube.rs --- a/src/sets/cube.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/sets/cube.rs Fri May 15 14:46:30 2026 -0500 @@ -16,25 +16,19 @@ ``` */ -use serde::ser::{Serialize, Serializer, SerializeTupleStruct}; -use crate::types::*; use crate::loc::Loc; +use crate::maputil::{map1, map1_indexed, map2, FixedLength, FixedLengthMut}; use crate::sets::SetOrd; -use crate::maputil::{ - FixedLength, - FixedLengthMut, - map1, - map1_indexed, - map2, -}; +use crate::types::*; +use serde::ser::{Serialize, SerializeTupleStruct, Serializer}; /// A multi-dimensional cube $∏_{i=1}^N [a_i, b_i)$ with the starting and ending points /// along $a_i$ and $b_i$ along each dimension of type `U`. #[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub struct Cube(pub(super) [[U; 2]; N]); +pub struct Cube(pub(super) [[U; 2]; N]); // Need to manually implement as [F; N] serialisation is provided only for some N. -impl Serialize for Cube +impl Serialize for Cube where F: Serialize, { @@ -50,7 +44,7 @@ } } -impl FixedLength for Cube { +impl FixedLength for Cube { type Iter = std::array::IntoIter<[A; 2], N>; type Elem = [A; 2]; #[inline] @@ -59,7 +53,7 @@ } } -impl FixedLengthMut for Cube { +impl FixedLengthMut for Cube { type IterMut<'a> = std::slice::IterMut<'a, [A; 2]>; #[inline] fn fl_iter_mut(&mut self) -> Self::IterMut<'_> { @@ -67,7 +61,7 @@ } } -impl<'a, A : Num, const N : usize> FixedLength for &'a Cube { +impl<'a, A: Num, const N: usize> FixedLength for &'a Cube { type Iter = std::slice::Iter<'a, [A; 2]>; type Elem = &'a [A; 2]; #[inline] @@ -76,15 +70,14 @@ } } - /// Iterator for [`Cube`] corners. -pub struct CubeCornersIter<'a, U : Num, const N : usize> { - index : usize, - cube : &'a Cube, +pub struct CubeCornersIter<'a, U: Num, const N: usize> { + index: usize, + cube: &'a Cube, } -impl<'a, U : Num, const N : usize> Iterator for CubeCornersIter<'a, U, N> { - type Item = Loc; +impl<'a, U: Num, const N: usize> Iterator for CubeCornersIter<'a, U, N> { + type Item = Loc; #[inline] fn next(&mut self) -> Option { if self.index >= N { @@ -92,28 +85,30 @@ } else { let i = self.index; self.index += 1; - let arr = self.cube.map_indexed(|k, a, b| if (i>>k)&1 == 0 { a } else { b }); + let arr = self + .cube + .map_indexed(|k, a, b| if (i >> k) & 1 == 0 { a } else { b }); Some(arr.into()) } } } -impl Cube { +impl Cube { /// Maps `f` over the triples $\\{(i, a\_i, b\_i)\\}\_{i=1}^N$ /// of the cube $∏_{i=1}^N [a_i, b_i)$. #[inline] - pub fn map_indexed(&self, f : impl Fn(usize, U, U) -> T) -> [T; N] { + pub fn map_indexed(&self, f: impl Fn(usize, U, U) -> T) -> [T; N] { map1_indexed(self, |i, &[a, b]| f(i, a, b)) } /// Maps `f` over the tuples $\\{(a\_i, b\_i)\\}\_{i=1}^N$ /// of the cube $∏_{i=1}^N [a_i, b_i)$. #[inline] - pub fn map(&self, f : impl Fn(U, U) -> T) -> [T; N] { + pub fn map(&self, f: impl Fn(U, U) -> T) -> [T; N] { map1(self, |&[a, b]| f(a, b)) } - /// Iterates over the start and end coordinates $\{(a_i, b_i)\}_{i=1}^N$ of the cube along + /// Iterates over the start and end coordinates $\{(a_i, b_i)\}_{i=1}^N$ of the cube along /// each dimension. #[inline] pub fn iter_coords(&self) -> std::slice::Iter<'_, [U; 2]> { @@ -122,27 +117,27 @@ /// Returns the “start” coordinate $a_i$ of the cube $∏_{i=1}^N [a_i, b_i)$. #[inline] - pub fn start(&self, i : usize) -> U { + pub fn start(&self, i: usize) -> U { self.0[i][0] } /// Returns the end coordinate $a_i$ of the cube $∏_{i=1}^N [a_i, b_i)$. #[inline] - pub fn end(&self, i : usize) -> U { + pub fn end(&self, i: usize) -> U { self.0[i][1] } /// Returns the “start” $(a_1, … ,a_N)$ of the cube $∏_{i=1}^N [a_i, b_i)$ /// spanned between $(a_1, … ,a_N)$ and $(b_1, … ,b_N)$. #[inline] - pub fn span_start(&self) -> Loc { + pub fn span_start(&self) -> Loc { Loc::new(self.map(|a, _b| a)) } /// Returns the end $(b_1, … ,b_N)$ of the cube $∏_{i=1}^N [a_i, b_i)$ /// spanned between $(a_1, … ,a_N)$ and $(b_1, … ,b_N)$. #[inline] - pub fn span_end(&self) -> Loc { + pub fn span_end(&self) -> Loc { Loc::new(self.map(|_a, b| b)) } @@ -150,19 +145,22 @@ /// $∏_{i=1}^N [a_i, b_i)$. #[inline] pub fn iter_corners(&self) -> CubeCornersIter<'_, U, N> { - CubeCornersIter{ index : 0, cube : self } + CubeCornersIter { + index: 0, + cube: self, + } } /// Returns the width-`N`-tuple $(b_1-a_1, … ,b_N-a_N)$ of the cube $∏_{i=1}^N [a_i, b_i)$. #[inline] - pub fn width(&self) -> Loc { - Loc::new(self.map(|a, b| b-a)) + pub fn width(&self) -> Loc { + Loc::new(self.map(|a, b| b - a)) } /// Translates the cube $∏_{i=1}^N [a_i, b_i)$ by the `shift` $(s_1, … , s_N)$ to /// $∏_{i=1}^N [a_i+s_i, b_i+s_i)$. #[inline] - pub fn shift(&self, shift : &Loc) -> Self { + pub fn shift(&self, shift: &Loc) -> Self { let mut cube = self.clone(); for i in 0..N { cube.0[i][0] += shift[i]; @@ -173,144 +171,158 @@ /// Creates a new cube from an array. #[inline] - pub fn new(data : [[U; 2]; N]) -> Self { + pub fn new(data: [[U; 2]; N]) -> Self { Cube(data) } } -impl Cube { +impl Cube { /// Returns the centre of the cube - pub fn center(&self) -> Loc { + pub fn center(&self) -> Loc { map1(self, |&[a, b]| (a + b) / F::TWO).into() } } -impl Cube { +impl Cube<1, U> { /// Get the corners of the cube. /// /// TODO: generic implementation once const-generics can be involved in /// calculations. #[inline] - pub fn corners(&self) -> [Loc; 2] { + pub fn corners(&self) -> [Loc<1, U>; 2] { let [[a, b]] = self.0; [a.into(), b.into()] } } -impl Cube { +impl Cube<2, U> { /// Get the corners of the cube in counter-clockwise order. /// /// TODO: generic implementation once const-generics can be involved in /// calculations. #[inline] - pub fn corners(&self) -> [Loc; 4] { - let [[a1, b1], [a2, b2]]=self.0; - [[a1, a2].into(), - [b1, a2].into(), - [b1, b2].into(), - [a1, b2].into()] + pub fn corners(&self) -> [Loc<2, U>; 4] { + let [[a1, b1], [a2, b2]] = self.0; + [ + [a1, a2].into(), + [b1, a2].into(), + [b1, b2].into(), + [a1, b2].into(), + ] } } -impl Cube { +impl Cube<3, U> { /// Get the corners of the cube. /// /// TODO: generic implementation once const-generics can be involved in /// calculations. #[inline] - pub fn corners(&self) -> [Loc; 8] { - let [[a1, b1], [a2, b2], [a3, b3]]=self.0; - [[a1, a2, a3].into(), - [b1, a2, a3].into(), - [b1, b2, a3].into(), - [a1, b2, a3].into(), - [a1, b2, b3].into(), - [b1, b2, b3].into(), - [b1, a2, b3].into(), - [a1, a2, b3].into()] + pub fn corners(&self) -> [Loc<3, U>; 8] { + let [[a1, b1], [a2, b2], [a3, b3]] = self.0; + [ + [a1, a2, a3].into(), + [b1, a2, a3].into(), + [b1, b2, a3].into(), + [a1, b2, a3].into(), + [a1, b2, b3].into(), + [b1, b2, b3].into(), + [b1, a2, b3].into(), + [a1, a2, b3].into(), + ] } } // TODO: Implement Add and Sub of Loc to Cube, and Mul and Div by U : Num. -impl From<[[U; 2]; N]> for Cube { +impl From<[[U; 2]; N]> for Cube { #[inline] - fn from(data : [[U; 2]; N]) -> Self { + fn from(data: [[U; 2]; N]) -> Self { Cube(data) } } -impl From> for [[U; 2]; N] { +impl From> for [[U; 2]; N] { #[inline] - fn from(Cube(data) : Cube) -> Self { + fn from(Cube(data): Cube) -> Self { data } } - -impl Cube where U : Num + PartialOrd { +impl Cube +where + U: Num + PartialOrd, +{ /// Checks whether the cube is non-degenerate, i.e., the start coordinate /// of each axis is strictly less than the end coordinate. #[inline] pub fn nondegenerate(&self) -> bool { self.0.iter().all(|range| range[0] < range[1]) } - + /// Checks whether the cube intersects some `other` cube. /// Matching boundary points are not counted, so `U` is ideally a [`Float`]. #[inline] - pub fn intersects(&self, other : &Cube) -> bool { - self.iter_coords().zip(other.iter_coords()).all(|([a1, b1], [a2, b2])| { - a1 < b2 && a2 < b1 - }) + pub fn intersects(&self, other: &Cube) -> bool { + self.iter_coords() + .zip(other.iter_coords()) + .all(|([a1, b1], [a2, b2])| a1 < b2 && a2 < b1) } /// Checks whether the cube contains some `other` cube. - pub fn contains_set(&self, other : &Cube) -> bool { - self.iter_coords().zip(other.iter_coords()).all(|([a1, b1], [a2, b2])| { - a1 <= a2 && b1 >= b2 - }) + pub fn contains_set(&self, other: &Cube) -> bool { + self.iter_coords() + .zip(other.iter_coords()) + .all(|([a1, b1], [a2, b2])| a1 <= a2 && b1 >= b2) } /// Produces the point of minimum $ℓ^p$-norm within the cube `self` for any $p$-norm. /// This is the point where each coordinate is closest to zero. #[inline] - pub fn minnorm_point(&self) -> Loc { + pub fn minnorm_point(&self) -> Loc { let z = U::ZERO; // As always, we assume that a ≤ b. self.map(|a, b| { debug_assert!(a <= b); match (a < z, z < b) { - (false, _) => a, - (_, false) => b, - (true, true) => z + (false, _) => a, + (_, false) => b, + (true, true) => z, } - }).into() + }) + .into() } /// Produces the point of maximum $ℓ^p$-norm within the cube `self` for any $p$-norm. /// This is the point where each coordinate is furthest from zero. #[inline] - pub fn maxnorm_point(&self) -> Loc { + pub fn maxnorm_point(&self) -> Loc { let z = U::ZERO; // As always, we assume that a ≤ b. self.map(|a, b| { debug_assert!(a <= b); match (a < z, z < b) { - (false, _) => b, - (_, false) => a, + (false, _) => b, + (_, false) => a, // A this stage we must have a < 0 (so U must be signed), and want to check // whether |a| > |b|. We can do this without assuming U to actually implement // `Neg` by comparing whether 0 > a + b. - (true, true) => if z > a + b { a } else { b } + (true, true) => { + if z > a + b { + a + } else { + b + } + } } - }).into() + }) + .into() } } macro_rules! impl_common { ($($t:ty)*, $min:ident, $max:ident) => { $( - impl SetOrd for Cube<$t, N> { + impl SetOrd for Cube { #[inline] fn common(&self, other : &Self) -> Self { map2(self, other, |&[a1, b1], &[a2, b2]| { @@ -338,7 +350,7 @@ #[cfg(feature = "nightly")] impl_common!(f32 f64, minimum, maximum); -impl std::ops::Index for Cube { +impl std::ops::Index for Cube { type Output = [U; 2]; #[inline] fn index(&self, index: usize) -> &Self::Output { @@ -346,7 +358,7 @@ } } -impl std::ops::IndexMut for Cube { +impl std::ops::IndexMut for Cube { #[inline] fn index_mut(&mut self, index: usize) -> &mut Self::Output { &mut self.0[index] diff -r 1f19c6bbf07b -r 3868555d135c src/types.rs --- a/src/types.rs Sun Apr 27 20:29:43 2025 -0500 +++ b/src/types.rs Fri May 15 14:46:30 2026 -0500 @@ -10,15 +10,12 @@ */ //use trait_set::trait_set; +pub use num_traits::cast::AsPrimitive; pub use num_traits::Float as NumTraitsFloat; // needed to re-export functions. -pub use num_traits::cast::AsPrimitive; pub use simba::scalar::{ - ClosedAdd, ClosedAddAssign, + ClosedAdd, ClosedAddAssign, ClosedDiv, ClosedDivAssign, ClosedMul, ClosedMulAssign, ClosedNeg, ClosedSub, ClosedSubAssign, - ClosedMul, ClosedMulAssign, - ClosedDiv, ClosedDivAssign, - ClosedNeg }; /// Typical integer type @@ -34,8 +31,8 @@ pub type float = f64; /// Casts of abstract numerical types to others via the standard `as` keyword. -pub trait CastFrom : num_traits::cast::AsPrimitive { - fn cast_from(other : T) -> Self; +pub trait CastFrom: num_traits::cast::AsPrimitive { + fn cast_from(other: T) -> Self; } macro_rules! impl_casts { @@ -58,53 +55,71 @@ f32 f64); /// Trait for general numeric types -pub trait Num : 'static + Copy + Sync + Send + num::Num + num_traits::NumAssign - + std::iter::Sum + std::iter::Product - + std::fmt::Debug + std::fmt::Display + serde::Serialize - + CastFrom + CastFrom + CastFrom + CastFrom - + CastFrom + CastFrom - + CastFrom + CastFrom + CastFrom + CastFrom - + CastFrom + CastFrom - + CastFrom + CastFrom - + crate::instance::Space { - - const ZERO : Self; - const ONE : Self; - const TWO : Self; +pub trait Num: + 'static + + Copy + + Sync + + Send + + num::Num + + num_traits::NumAssign + + std::iter::Sum + + std::iter::Product + + std::fmt::Debug + + std::fmt::Display + + serde::Serialize + + CastFrom + + CastFrom + + CastFrom + + CastFrom + + CastFrom + + CastFrom + + CastFrom + + CastFrom + + CastFrom + + CastFrom + + CastFrom + + CastFrom + + CastFrom + + CastFrom + + crate::instance::ClosedSpace +{ + const ZERO: Self; + const ONE: Self; + const TWO: Self; /// Generic version of `Self::MAX` - const RANGE_MAX : Self; + const RANGE_MAX: Self; /// Generic version of `Self::MIN` - const RANGE_MIN : Self; + const RANGE_MIN: Self; } /// Trait for signed numeric types -pub trait SignedNum : Num + num::Signed + std::ops::Neg {} -impl> SignedNum for U { } +pub trait SignedNum: Num + num::Signed + std::ops::Neg {} +impl> SignedNum for U {} /// Trait for floating point numbers -pub trait Float : SignedNum + num::Float /*+ From*/ { +pub trait Float: SignedNum + std::fmt::LowerExp + num::Float /*+ From*/ { // An unsigned integer that can be used for indexing operations and // converted to F without loss. //type CompatibleSize : CompatibleUnsigned; - const PI : Self; - const E : Self; - const EPSILON : Self; - const SQRT_2 : Self; - const INFINITY : Self; - const NEG_INFINITY : Self; - const NAN : Self; - const FRAC_2_SQRT_PI : Self; + const PI: Self; + const E: Self; + const EPSILON: Self; + const SQRT_2: Self; + const INFINITY: Self; + const NEG_INFINITY: Self; + const NAN: Self; + const FRAC_2_SQRT_PI: Self; } /// Trait for integers -pub trait Integer : Num + num::Integer {} +pub trait Integer: Num + num::Integer {} /// Trait for unsigned integers -pub trait Unsigned : Num + Integer + num::Unsigned {} +pub trait Unsigned: Num + Integer + num::Unsigned {} /// Trait for signed integers -pub trait Signed : SignedNum + Integer {} +pub trait Signed: SignedNum + Integer {} macro_rules! impl_num_consts { ($($type:ty)*) => { $( @@ -137,14 +152,14 @@ #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] type CompatibleSize = usize;*/ - const PI : Self = std::f64::consts::PI; - const E : Self = std::f64::consts::E; - const EPSILON : Self = std::f64::EPSILON; - const SQRT_2 : Self = std::f64::consts::SQRT_2; - const INFINITY : Self = std::f64::INFINITY; - const NEG_INFINITY : Self = std::f64::NEG_INFINITY; - const NAN : Self = std::f64::NAN; - const FRAC_2_SQRT_PI : Self = std::f64::consts::FRAC_2_SQRT_PI; + const PI: Self = std::f64::consts::PI; + const E: Self = std::f64::consts::E; + const EPSILON: Self = std::f64::EPSILON; + const SQRT_2: Self = std::f64::consts::SQRT_2; + const INFINITY: Self = std::f64::INFINITY; + const NEG_INFINITY: Self = std::f64::NEG_INFINITY; + const NAN: Self = std::f64::NAN; + const FRAC_2_SQRT_PI: Self = std::f64::consts::FRAC_2_SQRT_PI; } impl Float for f32 { @@ -155,14 +170,14 @@ type CompatibleSize = usize; */ - const PI : Self = std::f32::consts::PI; - const E : Self = std::f32::consts::E; - const EPSILON : Self = std::f32::EPSILON; - const SQRT_2 : Self = std::f32::consts::SQRT_2; - const INFINITY : Self = std::f32::INFINITY; - const NEG_INFINITY : Self = std::f32::NEG_INFINITY; - const NAN : Self = std::f32::NAN; - const FRAC_2_SQRT_PI : Self = std::f32::consts::FRAC_2_SQRT_PI; + const PI: Self = std::f32::consts::PI; + const E: Self = std::f32::consts::E; + const EPSILON: Self = std::f32::EPSILON; + const SQRT_2: Self = std::f32::consts::SQRT_2; + const INFINITY: Self = std::f32::INFINITY; + const NEG_INFINITY: Self = std::f32::NEG_INFINITY; + const NAN: Self = std::f32::NAN; + const FRAC_2_SQRT_PI: Self = std::f32::consts::FRAC_2_SQRT_PI; } /* @@ -171,4 +186,3 @@ pub trait CompatibleSigned = Signed + Into; } */ -